quocanh34 commited on
Commit
3773ad2
·
1 Parent(s): 9ab61ae

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. CodeFormer/CodeFormer/weights/CodeFormer/your_models.txt +0 -0
  2. CodeFormer/CodeFormer/weights/facelib/your_models.txt +0 -0
  3. CodeFormer/CodeFormer/weights/realesrgan/your_models.txt +0 -0
  4. checkpoints/models/your_models.txt +0 -0
  5. post_process/checkpoints/your_models.txt +0 -0
  6. post_process/face_inpaint/inpaint.py +479 -0
  7. post_process/inswapper/CodeFormer/CodeFormer/README.md +123 -0
  8. post_process/inswapper/CodeFormer/CodeFormer/assets/CodeFormer_logo.png +0 -0
  9. post_process/inswapper/CodeFormer/CodeFormer/assets/color_enhancement_result1.png +0 -0
  10. post_process/inswapper/CodeFormer/CodeFormer/assets/color_enhancement_result2.png +0 -0
  11. post_process/inswapper/CodeFormer/CodeFormer/assets/inpainting_result1.png +0 -0
  12. post_process/inswapper/CodeFormer/CodeFormer/assets/inpainting_result2.png +0 -0
  13. post_process/inswapper/CodeFormer/CodeFormer/assets/network.jpg +0 -0
  14. post_process/inswapper/CodeFormer/CodeFormer/assets/restoration_result1.png +0 -0
  15. post_process/inswapper/CodeFormer/CodeFormer/assets/restoration_result2.png +0 -0
  16. post_process/inswapper/CodeFormer/CodeFormer/assets/restoration_result3.png +0 -0
  17. post_process/inswapper/CodeFormer/CodeFormer/assets/restoration_result4.png +0 -0
  18. post_process/inswapper/CodeFormer/CodeFormer/basicsr/VERSION +1 -0
  19. post_process/inswapper/CodeFormer/CodeFormer/basicsr/__init__.py +11 -0
  20. post_process/inswapper/CodeFormer/CodeFormer/basicsr/__pycache__/__init__.cpython-310.pyc +0 -0
  21. post_process/inswapper/CodeFormer/CodeFormer/basicsr/__pycache__/train.cpython-310.pyc +0 -0
  22. post_process/inswapper/CodeFormer/CodeFormer/basicsr/__pycache__/version.cpython-310.pyc +0 -0
  23. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__init__.py +25 -0
  24. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/__init__.cpython-310.pyc +0 -0
  25. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/arcface_arch.cpython-310.pyc +0 -0
  26. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/arch_util.cpython-310.pyc +0 -0
  27. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/codeformer_arch.cpython-310.pyc +0 -0
  28. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/rrdbnet_arch.cpython-310.pyc +0 -0
  29. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/vgg_arch.cpython-310.pyc +0 -0
  30. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/vqgan_arch.cpython-310.pyc +0 -0
  31. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/arcface_arch.py +245 -0
  32. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/arch_util.py +318 -0
  33. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/codeformer_arch.py +276 -0
  34. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/rrdbnet_arch.py +119 -0
  35. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/vgg_arch.py +161 -0
  36. post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/vqgan_arch.py +435 -0
  37. post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/__init__.py +100 -0
  38. post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/__pycache__/__init__.cpython-310.pyc +0 -0
  39. post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/__pycache__/data_sampler.cpython-310.pyc +0 -0
  40. post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc +0 -0
  41. post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/data_sampler.py +48 -0
  42. post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/data_util.py +305 -0
  43. post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/prefetch_dataloader.py +125 -0
  44. post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/transforms.py +165 -0
  45. post_process/inswapper/CodeFormer/CodeFormer/basicsr/losses/__init__.py +26 -0
  46. post_process/inswapper/CodeFormer/CodeFormer/basicsr/losses/__pycache__/__init__.cpython-310.pyc +0 -0
  47. post_process/inswapper/CodeFormer/CodeFormer/basicsr/losses/__pycache__/loss_util.cpython-310.pyc +0 -0
  48. post_process/inswapper/CodeFormer/CodeFormer/basicsr/losses/__pycache__/losses.cpython-310.pyc +0 -0
  49. post_process/inswapper/CodeFormer/CodeFormer/basicsr/losses/loss_util.py +95 -0
  50. post_process/inswapper/CodeFormer/CodeFormer/basicsr/losses/losses.py +455 -0
CodeFormer/CodeFormer/weights/CodeFormer/your_models.txt ADDED
File without changes
CodeFormer/CodeFormer/weights/facelib/your_models.txt ADDED
File without changes
CodeFormer/CodeFormer/weights/realesrgan/your_models.txt ADDED
File without changes
checkpoints/models/your_models.txt ADDED
File without changes
post_process/checkpoints/your_models.txt ADDED
File without changes
post_process/face_inpaint/inpaint.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Union, Optional, Tuple
3
+ from enum import IntEnum
4
+ import os
5
+ import cv2
6
+ import torch
7
+ import numpy as np
8
+ from PIL import Image, ImageDraw, ImageFilter, ImageOps
9
+ from torchvision.transforms.functional import to_pil_image
10
+ # import math
11
+ from diffusers import StableDiffusionInpaintPipeline
12
+ # from post_process.yoloface.face_detector import YoloDetector
13
+
14
+
15
+ MASK_MERGE_INVERT = ["None", "Merge", "Merge and Invert"]
16
+
17
+
18
def adetailer(sd_pipeline, yolodetector, images: list[Image.Image], prompt, negative_prompt, seed=42):
    """Detect faces in each image and re-generate every face with an inpainting pipeline.

    For each mask produced by ``pred_preprocessing`` the mask is blurred, a square
    region around it is cropped from the source image, resized to 512x512 and
    inpainted. The inpainted patch is then pasted back and composited with the
    untouched (un-masked) part of the source image.

    Parameters
    ----------
    sd_pipeline
        Pipeline whose ``vae``, ``text_encoder``, ``tokenizer``, ``unet``,
        ``scheduler`` and ``feature_extractor`` are reused to build a
        ``StableDiffusionInpaintPipeline``.
    yolodetector
        Face detector handed to ``ultralytics_predict``.
    images: list[Image.Image]
        Source images to process.
    prompt, negative_prompt
        Text prompts forwarded to the inpainting pipeline.
    seed: int
        Seed of the CUDA random generator shared by all inpainting calls.

    Returns
    -------
    list[Image.Image]
        One RGB image per detected face (not one per input image), in detection order.
    """
    resolution = 512
    denoising_strength = 0.4

    # One entry per detected face: (patch, patch mask, paste location, overlay).
    jobs = []
    for input_image in images:
        # The detector is passed positionally: the receiving parameter is spelled
        # `yolodector_model`, so the previous `yolodetector_model=` keyword
        # raised TypeError.
        pred = ultralytics_predict(yolodetector, image=input_image)
        for mask in pred_preprocessing(pred):
            blurred_mask = mask.filter(ImageFilter.GaussianBlur(8))
            crop_region = get_crop_region(np.array(blurred_mask))
            crop_region = expand_crop_region(crop_region, resolution, resolution, mask.width, mask.height)
            x1, y1, x2, y2 = crop_region
            paste_to = (x1, y1, x2 - x1, y2 - y1)

            patch_mask = blurred_mask.crop(crop_region)
            patch_mask = patch_mask.resize((resolution, resolution), Image.LANCZOS)

            # Source image with the (blurred) face area made transparent; it is
            # composited over the inpainted result in apply_overlay().
            image_masked = Image.new('RGBa', (input_image.width, input_image.height))
            image_masked.paste(input_image.convert("RGBA"), mask=ImageOps.invert(blurred_mask.convert('L')))
            overlay_image = image_masked.convert('RGBA')

            patch_input_img = input_image.crop(crop_region)
            patch_input_img = patch_input_img.resize((resolution, resolution), Image.LANCZOS)

            # Keep the mask together with its patch. Previously only the mask of
            # the last detected face survived the loop and was used for every patch.
            jobs.append((patch_input_img, patch_mask, paste_to, overlay_image))

    pipe = StableDiffusionInpaintPipeline(
        vae=sd_pipeline.vae,
        text_encoder=sd_pipeline.text_encoder,
        tokenizer=sd_pipeline.tokenizer,
        unet=sd_pipeline.unet,
        scheduler=sd_pipeline.scheduler,
        requires_safety_checker=False,
        safety_checker=None,
        feature_extractor=sd_pipeline.feature_extractor,
    ).to('cuda')

    generator = torch.Generator(device="cuda").manual_seed(seed)

    inpaint_images = []
    for patch_input_img, patch_mask, paste_to, overlay_image in jobs:
        # NOTE(review): `controlnet_conditioning_scale` is forwarded unchanged from
        # the original call; confirm the inpaint pipeline accepts this argument.
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=[patch_input_img],
            mask_image=patch_mask,
            num_inference_steps=30,
            strength=denoising_strength,
            controlnet_conditioning_scale=1.0,
            generator=generator
        ).images[0]

        inpaint_images.append(apply_overlay(out, paste_to, overlay_image))

    return inpaint_images
80
+
81
+
82
def get_crop_region(mask, pad=0):
    """Return the rectangle ``(x1, y1, x2, y2)`` enclosing every non-zero pixel of ``mask``.

    ``mask`` is a 2-D array. ``pad`` widens the rectangle on every side, clamped
    to the array bounds. For example, if the top-right quarter of a 512x512 mask
    is painted, the result is ``(256, 0, 512, 256)``. A mask without any non-zero
    pixel yields ``(w, h, 0, 0)`` when ``pad`` is 0.
    """
    h, w = mask.shape

    occupied = mask != 0
    col_flags = occupied.any(axis=0)
    row_flags = occupied.any(axis=1)

    def _empty_run(flags):
        # Count of leading entries without content; the full length when none is set.
        hits = np.flatnonzero(flags)
        return int(hits[0]) if hits.size else len(flags)

    crop_left = _empty_run(col_flags)
    crop_right = _empty_run(col_flags[::-1])
    crop_top = _empty_run(row_flags)
    crop_bottom = _empty_run(row_flags[::-1])

    x1 = int(max(crop_left - pad, 0))
    y1 = int(max(crop_top - pad, 0))
    x2 = int(min(w - crop_right + pad, w))
    y2 = int(min(h - crop_bottom + pad, h))
    return (x1, y1, x2, y2)
118
+
119
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
    """Grow a crop rectangle so its aspect ratio matches the processing resolution.

    ``crop_region`` is ``(x1, y1, x2, y2)`` as returned by ``get_crop_region``.
    Only one axis is enlarged: the height when the region is wider than the
    processing ratio, otherwise the width. The enlarged span is shifted back
    inside ``[0, image size]`` and clipped if it still does not fit. For example,
    a 128x32 region processed at 512x512 becomes 128x128.

    Returns the expanded ``(x1, y1, x2, y2)``.
    """

    def _fit_span(lo, hi, wanted_length, limit):
        # Enlarge [lo, hi) to `wanted_length`, splitting the growth between both ends.
        growth = int(wanted_length - (hi - lo))
        lo -= growth // 2
        hi += growth - growth // 2
        # Shift back when the span runs past the far edge of the image ...
        if hi >= limit:
            overshoot = hi - limit
            hi -= overshoot
            lo -= overshoot
        # ... or past the near edge ...
        if lo < 0:
            hi -= lo
            lo = 0
        # ... and clip when it is larger than the image itself.
        if hi >= limit:
            hi = limit
        return lo, hi

    x1, y1, x2, y2 = crop_region

    ratio_crop_region = (x2 - x1) / (y2 - y1)
    ratio_processing = processing_width / processing_height

    if ratio_crop_region > ratio_processing:
        y1, y2 = _fit_span(y1, y2, (x2 - x1) / ratio_processing, image_height)
    else:
        x1, x2 = _fit_span(x1, x2, (y2 - y1) * ratio_processing, image_width)

    return x1, y1, x2, y2
158
+
159
@dataclass
class PredictOutput:
    """Detection result for a single image; every field has an empty default."""

    # Bounding boxes, each indexed as [x1, y1, x2, y2].
    bboxes: List[List[Union[int, float]]] = field(default_factory=list)
    # One mask image per bounding box, kept in the same order as `bboxes`.
    masks: List[Image.Image] = field(default_factory=list)
    # Image whose `.size` is used as the reference frame when filtering/sorting boxes.
    preview: Optional[Image.Image] = None
164
+
165
def create_mask_from_bbox(
    bboxes: List[List[float]], shape: Tuple[int, int]
) -> List[Image.Image]:
    """
    Draw one filled-rectangle mask per bounding box.

    Parameters
    ----------
    bboxes: List[List[float]]
        bounding boxes, each given as [x1, y1, x2, y2]
    shape: Tuple[int, int]
        size of every mask image (width, height)

    Returns
    -------
    masks: List[Image.Image]
        "L"-mode images, black (0) except for the box area which is white (255)
    """

    def _rect_mask(box):
        canvas = Image.new("L", shape, 0)
        ImageDraw.Draw(canvas).rectangle(box, fill=255)
        return canvas

    return [_rect_mask(box) for box in bboxes]
190
+
191
def ultralytics_predict(
    yolodector_model,
    image: Image.Image,
    confidence: float = 0.5,
    device: str = "cuda",
) -> PredictOutput:
    """
    Run the face detector on one image and wrap the result in a PredictOutput.

    Parameters
    ----------
    yolodector_model
        Detector exposing ``predict(array, conf_thres=..., iou_thres=...)``;
        the first element of its return value is indexed with ``[0]`` to get
        the boxes (presumably the boxes of the first image in a batch; verify
        against the detector).
    image: Image.Image
        Image to run detection on; it is also stored as the ``preview``.
    confidence: float
        Forwarded as ``conf_thres``; the IoU threshold is fixed at 0.5.
    device: str
        Currently unused; kept so existing callers keep working.

    Returns
    -------
    PredictOutput
        Boxes, one rectangular mask per box, and the input image as preview.
    """
    detections, _ = yolodector_model.predict(np.array(image), conf_thres=confidence, iou_thres=0.5)
    face_boxes = detections[0]
    face_masks = create_mask_from_bbox(face_boxes, image.size)
    return PredictOutput(bboxes=face_boxes, masks=face_masks, preview=image)
218
+
219
def mask_to_pil(masks, shape: Tuple[int, int]) -> List[Image.Image]:
    """
    Convert a stack of masks into a list of "L"-mode PIL images.

    Parameters
    ----------
    masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
        The device can be CUDA, but `to_pil_image` takes care of that.

    shape: Tuple[int, int]
        (width, height) every resulting image is resized to

    Returns
    -------
    List[Image.Image]
        One resized image per entry along the first axis of `masks`
    """
    pil_masks = []
    for index in range(masks.shape[0]):
        as_image = to_pil_image(masks[index], mode="L")
        pil_masks.append(as_image.resize(shape))
    return pil_masks
231
+
232
class MergeInvert(IntEnum):
    """Post-processing mode applied to the list of masks by mask_merge_invert()."""

    NONE = 0          # leave the masks untouched
    MERGE = 1         # merge the masks
    MERGE_INVERT = 2  # merge the masks, then invert the result
236
+
237
def offset(img: Image.Image, x: int = 0, y: int = 0) -> Image.Image:
    """
    Shift an image by x pixels to the right (→) and y pixels upwards (↑).

    Parameters
    ----------
    img: Image.Image
        The image (typically a mask) to shift
    x: int
        Horizontal offset; positive values move the content to the right
    y: int
        Vertical offset; positive values move the content up, which is why
        it is negated before being handed to ``ImageChops.offset``

    Returns
    -------
    PIL.Image.Image
        A new image that is offset by x and y
    """
    # Imported here because the module-level PIL import does not include
    # ImageChops, which made this function fail with NameError.
    from PIL import ImageChops

    return ImageChops.offset(img, x, -y)
256
+
257
+
258
def is_all_black(img: Image.Image) -> bool:
    """Return True when the image contains no non-zero pixel."""
    pixels = np.array(img)
    nonzero_pixels = cv2.countNonZero(pixels)
    return nonzero_pixels == 0
261
+
262
def _dilate(arr: np.ndarray, value: int) -> np.ndarray:
    """Dilate `arr` once with a `value` x `value` rectangular structuring element."""
    rect_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
    dilated = cv2.dilate(arr, rect_kernel, iterations=1)
    return dilated
265
+
266
+
267
def _erode(arr: np.ndarray, value: int) -> np.ndarray:
    """Erode `arr` once with a `value` x `value` rectangular structuring element."""
    rect_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
    eroded = cv2.erode(arr, rect_kernel, iterations=1)
    return eroded
270
+
271
def dilate_erode(img: Image.Image, value: int) -> Image.Image:
    """
    Grow or shrink the bright areas of an image.

    A positive `value` dilates with a kernel of that size, a negative `value`
    erodes with a kernel of size `-value`, and 0 returns the input unchanged
    (the very same object, not a copy).

    Parameters
    ----------
    img: PIL.Image.Image
        the image to be processed
    value: int
        signed kernel size of the dilation (> 0) or erosion (< 0)

    Returns
    -------
    PIL.Image.Image
        The image that has been dilated or eroded
    """
    if value == 0:
        return img

    pixels = np.array(img)
    if value > 0:
        processed = _dilate(pixels, value)
    else:
        processed = _erode(pixels, -value)

    return Image.fromarray(processed)
296
+
297
def mask_preprocess(
    masks: List[Image.Image],
    kernel: int = 0,
    x_offset: int = 0,
    y_offset: int = 0,
    merge_invert: Union[int, 'MergeInvert', str] = MergeInvert.NONE,
) -> List[Image.Image]:
    """
    Shift, dilate/erode and optionally merge/invert a list of masks.

    Steps, in order: offset every mask (only when an offset is non-zero),
    dilate or erode every mask and drop those that ended up completely black
    (only when `kernel` is non-zero), then apply `mask_merge_invert`.

    Parameters
    ----------
    masks: List[Image.Image]
        A list of masks
    kernel: int
        signed kernel size handed to `dilate_erode`
    x_offset: int
        horizontal shift handed to `offset`
    y_offset: int
        vertical shift handed to `offset`
    merge_invert: int, MergeInvert or str
        mode handed to `mask_merge_invert`

    Returns
    -------
    List[Image.Image]
        A list of processed masks; empty when `masks` is empty
    """
    if not masks:
        return []

    shifted = masks
    if not (x_offset == 0 and y_offset == 0):
        shifted = [offset(m, x_offset, y_offset) for m in masks]

    prepared = shifted
    if kernel != 0:
        prepared = []
        for m in shifted:
            resized = dilate_erode(m, kernel)
            # Erosion can wipe a small mask out entirely; such masks are discarded.
            if not is_all_black(resized):
                prepared.append(resized)

    return mask_merge_invert(prepared, mode=merge_invert)
335
+
336
def mask_merge_invert(
    masks: List[Image.Image], mode: Union[int, 'MergeInvert', str]
) -> List[Image.Image]:
    """
    Optionally merge all masks into one and invert the merged mask.

    The merge and invert steps are implemented inline because the
    `mask_merge` / `mask_invert` helpers this function used to call are not
    defined in this module, so both merge modes failed with NameError.

    Parameters
    ----------
    masks: List[Image.Image]
        Masks of identical size
    mode: int, MergeInvert or str
        A MergeInvert value, or one of the labels in MASK_MERGE_INVERT
        (a label is translated to its list index)

    Returns
    -------
    List[Image.Image]
        `masks` itself for MergeInvert.NONE or an empty input; otherwise a
        single-element list holding the merged (and possibly inverted) mask

    Raises
    ------
    ValueError
        If `mode` is a string that is not in MASK_MERGE_INVERT
    RuntimeError
        If `mode` is not a known MergeInvert value
    """
    if isinstance(mode, str):
        mode = MASK_MERGE_INVERT.index(mode)

    if mode == MergeInvert.NONE or not masks:
        return masks

    if mode not in (MergeInvert.MERGE, MergeInvert.MERGE_INVERT):
        raise RuntimeError(f"unknown merge/invert mode: {mode!r}")

    # Pixel-wise OR of all masks: a pixel is set when it is set in any mask.
    merged_arr = np.bitwise_or.reduce([np.array(m) for m in masks])
    merged = Image.fromarray(merged_arr)

    if mode == MergeInvert.MERGE:
        return [merged]
    return [ImageOps.invert(merged)]
353
+
354
def bbox_area(bbox: List[float]):
    """Return the area of a box given as [x1, y1, x2, y2] (extra entries are ignored)."""
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    return width * height
356
+
357
def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutput:
    """
    Keep only the boxes whose area, relative to the preview image, lies in [low, high].

    `pred` is modified in place (its `bboxes` and `masks` are replaced by the
    filtered lists) and returned. A prediction without boxes is returned untouched.
    """
    if not pred.bboxes:
        return pred

    width, height = pred.preview.size
    image_area = width * height

    kept = []
    for i in range(len(pred.bboxes)):
        share = bbox_area(pred.bboxes[i]) / image_area
        if low <= share <= high:
            kept.append(i)

    pred.bboxes = [pred.bboxes[i] for i in kept]
    pred.masks = [pred.masks[i] for i in kept]
    return pred
372
+
373
class SortBy(IntEnum):
    """Ordering applied to the detected boxes by sort_bboxes()."""

    NONE = 0            # keep the detection order
    LEFT_TO_RIGHT = 1   # ascending x1
    CENTER_TO_EDGE = 2  # ascending distance from the image centre
    AREA = 3            # largest box first
378
+
379
+ # Bbox sorting
380
+ def _key_left_to_right(bbox: List[float]) -> float:
381
+ """
382
+ Left to right
383
+
384
+ Parameters
385
+ ----------
386
+ bbox: list[float]
387
+ list of [x1, y1, x2, y2]
388
+ """
389
+ return bbox[0]
390
+
391
+
392
+ def _key_center_to_edge(bbox: List[float], *, center: Tuple[float, float]) -> float:
393
+ """
394
+ Center to edge
395
+
396
+ Parameters
397
+ ----------
398
+ bbox: list[float]
399
+ list of [x1, y1, x2, y2]
400
+ image: Image.Image
401
+ the image
402
+ """
403
+ bbox_center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
404
+ return dist(center, bbox_center)
405
+
406
+
407
def _key_area(bbox: List[float]) -> float:
    """
    Sort key for large-to-small ordering: the negated box area.

    Parameters
    ----------
    bbox: list[float]
        list of [x1, y1, x2, y2]
    """
    area = bbox_area(bbox)
    return -area
417
+
418
def sort_bboxes(
    pred: PredictOutput, order: Union[int, 'SortBy'] = SortBy.NONE
) -> PredictOutput:
    """
    Reorder the boxes (and their masks) of a prediction.

    `pred` is modified in place and returned. Nothing is done for
    SortBy.NONE or when there is at most one box.

    Parameters
    ----------
    pred: PredictOutput
        Prediction to sort; `pred.preview.size` supplies the image centre
        for SortBy.CENTER_TO_EDGE
    order: int or SortBy
        Requested ordering

    Raises
    ------
    RuntimeError
        If `order` is not a known SortBy value
    """
    # Imported here because `partial` is not imported at module level, which
    # made the CENTER_TO_EDGE branch fail with NameError.
    from functools import partial

    if order == SortBy.NONE or len(pred.bboxes) <= 1:
        return pred

    if order == SortBy.LEFT_TO_RIGHT:
        key = _key_left_to_right
    elif order == SortBy.CENTER_TO_EDGE:
        width, height = pred.preview.size
        center = (width / 2, height / 2)
        key = partial(_key_center_to_edge, center=center)
    elif order == SortBy.AREA:
        key = _key_area
    else:
        raise RuntimeError(f"unknown sort order: {order!r}")

    # Sort indices rather than boxes so the masks can be reordered identically.
    idx = sorted(range(len(pred.bboxes)), key=lambda i: key(pred.bboxes[i]))
    pred.bboxes = [pred.bboxes[i] for i in idx]
    pred.masks = [pred.masks[i] for i in idx]
    return pred
440
+
441
def filter_k_largest(pred: PredictOutput, k: int = 0) -> PredictOutput:
    """
    Keep only the `k` boxes with the largest area, together with their masks.

    `pred` is modified in place and returned; the kept boxes end up ordered by
    ascending area. `k == 0` or a prediction without boxes leaves `pred` unchanged.
    """
    if not pred.bboxes or k == 0:
        return pred

    areas = list(map(bbox_area, pred.bboxes))
    # argsort is ascending, so the last k indices belong to the k largest boxes.
    largest = np.argsort(areas)[-k:]
    pred.bboxes = [pred.bboxes[i] for i in largest]
    pred.masks = [pred.masks[i] for i in largest]
    return pred
449
+
450
def pred_preprocessing(pred: PredictOutput) -> List[Image.Image]:
    """
    Turn a raw prediction into the final list of masks to inpaint.

    With the fixed settings used here, the ratio filter accepts relative areas
    from 0.0 to 1.0 and the k-largest filter is disabled (k=0). The boxes are
    then sorted largest first and the masks are dilated with a kernel of 4,
    without any offset or merging.
    """
    ratio_filtered = filter_by_ratio(pred, low=0.0, high=1.0)
    top_k = filter_k_largest(ratio_filtered, k=0)
    ordered = sort_bboxes(top_k, SortBy.AREA)
    processed_masks = mask_preprocess(
        ordered.masks,
        kernel=4,
        x_offset=0,
        y_offset=0,
        merge_invert="None",
    )
    return processed_masks
463
+
464
def apply_overlay(image, paste_loc, overlay):
    """
    Paste a processed patch back into place and composite the overlay on top.

    Parameters
    ----------
    image
        The processed patch (or full image when `paste_loc` is None)
    paste_loc
        `(x, y, w, h)`: the patch is resized to `w` x `h` and pasted at
        `(x, y)` on a blank canvas of the overlay's size; None skips this step
    overlay
        RGBA image alpha-composited over the result; None returns `image` as is

    Returns
    -------
    The composited image converted to RGB, or `image` unchanged when `overlay` is None
    """
    if overlay is None:
        return image

    if paste_loc is not None:
        left, top, patch_w, patch_h = paste_loc
        canvas = Image.new('RGBA', (overlay.width, overlay.height))
        resized_patch = image.resize((patch_w, patch_h), Image.LANCZOS)
        canvas.paste(resized_patch, (left, top))
        image = canvas

    composed = image.convert('RGBA')
    composed.alpha_composite(overlay)
    return composed.convert('RGB')
post_process/inswapper/CodeFormer/CodeFormer/README.md ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img src="assets/CodeFormer_logo.png" height=110>
3
+ </p>
4
+
5
+ ## Towards Robust Blind Face Restoration with Codebook Lookup Transformer
6
+
7
+ [Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
8
+
9
+
10
+ <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) ![visitors](https://visitor-badge.glitch.me/badge?page_id=sczhou/CodeFormer)
11
+
12
+ [Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
13
+
14
+ S-Lab, Nanyang Technological University
15
+
16
+ <img src="assets/network.jpg" width="800px"/>
17
+
18
+
19
+ :star: If CodeFormer is helpful to your images or projects, please help star this repo. Thanks! :hugs:
20
+
21
+ ### Update
22
+
23
+ - **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/). Try out online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
24
+ - **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement.
25
+ - **2022.08.23**: Some modifications on face detection and fusion for better AI-created face enhancement.
26
+ - **2022.08.07**: Integrate [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement.
27
+ - **2022.07.29**: Integrate new face detectors of `['RetinaFace'(default), 'YOLOv5']`.
28
+ - **2022.07.17**: Add Colab demo of CodeFormer. <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
29
+ - **2022.07.16**: Release inference code for face restoration. :blush:
30
+ - **2022.06.21**: This repo is created.
31
+
32
+ ### TODO
33
+ - [ ] Add checkpoint for face inpainting
34
+ - [ ] Add training code and config files
35
+ - [x] ~~Add background image enhancement~~
36
+
37
+ #### Face Restoration
38
+
39
+ <img src="assets/restoration_result1.png" width="400px"/> <img src="assets/restoration_result2.png" width="400px"/>
40
+ <img src="assets/restoration_result3.png" width="400px"/> <img src="assets/restoration_result4.png" width="400px"/>
41
+
42
+ #### Face Color Enhancement and Restoration
43
+
44
+ <img src="assets/color_enhancement_result1.png" width="400px"/> <img src="assets/color_enhancement_result2.png" width="400px"/>
45
+
46
+ #### Face Inpainting
47
+
48
+ <img src="assets/inpainting_result1.png" width="400px"/> <img src="assets/inpainting_result2.png" width="400px"/>
49
+
50
+
51
+
52
+ ### Dependencies and Installation
53
+
54
+ - Pytorch >= 1.7.1
55
+ - CUDA >= 10.1
56
+ - Other required packages in `requirements.txt`
57
+ ```
58
+ # git clone this repository
59
+ git clone https://github.com/sczhou/CodeFormer
60
+ cd CodeFormer
61
+
62
+ # create new anaconda env
63
+ conda create -n codeformer python=3.8 -y
64
+ conda activate codeformer
65
+
66
+ # install python dependencies
67
+ pip3 install -r requirements.txt
68
+ python basicsr/setup.py develop
69
+ ```
70
+ <!-- conda install -c conda-forge dlib -->
71
+
72
+ ### Quick Inference
73
+
74
+ ##### Download Pre-trained Models:
75
+ Download the facelib pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can manually download the pretrained models OR download them by running the following command.
76
+ ```
77
+ python scripts/download_pretrained_models.py facelib
78
+ ```
79
+
80
+ Download the CodeFormer pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can manually download the pretrained models OR download them by running the following command.
81
+ ```
82
+ python scripts/download_pretrained_models.py CodeFormer
83
+ ```
84
+
85
+ ##### Prepare Testing Data:
86
+ You can put the testing images in the `inputs/TestWhole` folder. If you would like to test on cropped and aligned faces, you can put them in the `inputs/cropped_faces` folder.
87
+
88
+
89
+ ##### Testing on Face Restoration:
90
+ ```
91
+ # For cropped and aligned faces
92
+ python inference_codeformer.py --w 0.5 --has_aligned --test_path [input folder]
93
+
94
+ # For the whole images
95
+ # Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
96
+ # Add '--face_upsample' to further upsample the restored faces with Real-ESRGAN
97
+ python inference_codeformer.py --w 0.7 --test_path [input folder]
98
+ ```
99
+
100
+ NOTE that *w* is in [0, 1]. Generally, smaller *w* tends to produce a higher-quality result, while larger *w* yields a higher-fidelity result.
101
+
102
+ The results will be saved in the `results` folder.
103
+
104
+ ### Citation
105
+ If our work is useful for your research, please consider citing:
106
+
107
+ @article{zhou2022codeformer,
108
+ author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
109
+ title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
110
+ journal = {arXiv preprint arXiv:2206.11253},
111
+ year = {2022}
112
+ }
113
+
114
+ ### License
115
+
116
+ <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc-sa/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/">Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License</a>.
117
+
118
+ ### Acknowledgement
119
+
120
+ This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). We also borrow some codes from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). Thanks for their awesome works.
121
+
122
+ ### Contact
123
+ If you have any questions, please feel free to reach out to me at `[email protected]`.
post_process/inswapper/CodeFormer/CodeFormer/assets/CodeFormer_logo.png ADDED
post_process/inswapper/CodeFormer/CodeFormer/assets/color_enhancement_result1.png ADDED
post_process/inswapper/CodeFormer/CodeFormer/assets/color_enhancement_result2.png ADDED
post_process/inswapper/CodeFormer/CodeFormer/assets/inpainting_result1.png ADDED
post_process/inswapper/CodeFormer/CodeFormer/assets/inpainting_result2.png ADDED
post_process/inswapper/CodeFormer/CodeFormer/assets/network.jpg ADDED
post_process/inswapper/CodeFormer/CodeFormer/assets/restoration_result1.png ADDED
post_process/inswapper/CodeFormer/CodeFormer/assets/restoration_result2.png ADDED
post_process/inswapper/CodeFormer/CodeFormer/assets/restoration_result3.png ADDED
post_process/inswapper/CodeFormer/CodeFormer/assets/restoration_result4.png ADDED
post_process/inswapper/CodeFormer/CodeFormer/basicsr/VERSION ADDED
@@ -0,0 +1 @@
 
 
1
+ 1.3.2
post_process/inswapper/CodeFormer/CodeFormer/basicsr/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/xinntao/BasicSR
2
+ # flake8: noqa
3
+ from .archs import *
4
+ from .data import *
5
+ from .losses import *
6
+ from .metrics import *
7
+ from .models import *
8
+ from .ops import *
9
+ from .train import *
10
+ from .utils import *
11
+ from .version import __gitsha__, __version__
post_process/inswapper/CodeFormer/CodeFormer/basicsr/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (371 Bytes). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/__pycache__/train.cpython-310.pyc ADDED
Binary file (6.34 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/__pycache__/version.cpython-310.pyc ADDED
Binary file (249 Bytes). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from copy import deepcopy
3
+ from os import path as osp
4
+
5
+ from basicsr.utils import get_root_logger, scandir
6
+ from basicsr.utils.registry import ARCH_REGISTRY
7
+
8
__all__ = ['build_network']

# Import every `*_arch.py` module located next to this file, so that importing
# this package is enough to load all architecture modules for the registry.
arch_folder = osp.dirname(osp.abspath(__file__))
# Module names: the file names of the matching files, without directory or extension.
arch_filenames = [
    osp.splitext(osp.basename(path))[0]
    for path in scandir(arch_folder)
    if path.endswith('_arch.py')
]
# The imported modules are kept in a list; only the import side effect is needed here.
_arch_modules = [
    importlib.import_module(f'basicsr.archs.{module_name}')
    for module_name in arch_filenames
]
17
+
18
+
19
def build_network(opt):
    """Instantiate the network described by `opt` and log its class name.

    `opt` is a dict-like whose 'type' entry names the architecture to look up
    in ARCH_REGISTRY; all remaining entries are passed to it as keyword
    arguments. `opt` itself is not modified because a deep copy is used.
    """
    options = deepcopy(opt)
    arch_name = options.pop('type')
    arch_cls = ARCH_REGISTRY.get(arch_name)
    net = arch_cls(**options)
    get_root_logger().info(f'Network [{net.__class__.__name__}] is created.')
    return net
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.16 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/arcface_arch.cpython-310.pyc ADDED
Binary file (7.38 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/arch_util.cpython-310.pyc ADDED
Binary file (10.8 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/codeformer_arch.cpython-310.pyc ADDED
Binary file (9.22 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/rrdbnet_arch.cpython-310.pyc ADDED
Binary file (4.46 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/vgg_arch.cpython-310.pyc ADDED
Binary file (4.86 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/__pycache__/vqgan_arch.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/arcface_arch.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from basicsr.utils.registry import ARCH_REGISTRY
3
+
4
+
5
def conv3x3(inplanes, outplanes, stride=1):
    """Build a bias-free 3x3 convolution with a padding of 1.

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(inplanes, outplanes, **conv_kwargs)
14
+
15
+
16
class BasicBlock(nn.Module):
    """Two-convolution residual block for the ResNetArcFace architecture.

    The main branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The shortcut
    is added to it and a final ReLU is applied.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride of the first convolution. Default: 1.
        downsample (nn.Module): Optional module applied to the input to build
            the shortcut branch. Default: None.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The shortcut is the raw input unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
54
+
55
+
56
class IRBlock(nn.Module):
    """Improved residual block (IR Block) used in the ResNetArcFace architecture.

    The main branch is BN -> conv3x3 -> BN -> PReLU -> conv3x3 -> BN,
    optionally followed by an SEBlock. The shortcut is then added and the
    same PReLU module is applied once more.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride of the second convolution. Default: 1.
        downsample (nn.Module): Optional module applied to the input to build
            the shortcut branch. Default: None.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super().__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        out = self.conv1(self.bn0(x))
        out = self.prelu(self.bn1(out))

        out = self.bn2(self.conv2(out))
        if self.use_se:
            out = self.se(out)

        # The shortcut is the raw input unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.prelu(out)
101
+
102
+
103
class Bottleneck(nn.Module):
    """Bottleneck block used in the ResNetArcFace architecture.

    The main branch is 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU ->
    1x1 conv -> BN, where the last convolution outputs
    ``planes * expansion`` channels. The shortcut is added and a final ReLU
    is applied.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of the intermediate features.
        stride (int): Stride of the 3x3 convolution. Default: 1.
        downsample (nn.Module): Optional module applied to the input to build
            the shortcut branch. Default: None.
    """
    expansion = 4  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        expanded = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The shortcut is the raw input unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
147
+
148
+
149
class SEBlock(nn.Module):
    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.

    Each channel is globally average-pooled to one value, passed through a
    two-layer bottleneck MLP ending in a sigmoid, and the resulting
    per-channel gate rescales the input.

    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio. Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
        self.fc = nn.Sequential(
            nn.Linear(channel, squeezed),
            nn.PReLU(),
            nn.Linear(squeezed, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate
169
+
170
+
171
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    The stem convolution takes single-channel inputs. The final linear layer
    expects ``512 * 8 * 8`` flattened features.

    Args:
        block (str): Block used in the ArcFace architecture. The string
            'IRBlock' is resolved to the IRBlock class; any other value is
            used as given.
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """

    def __init__(self, block, layers, use_se=True):
        block = IRBlock if block == 'IRBlock' else block
        # Plain attributes read by _make_layer while the stages are built.
        self.inplanes = 64
        self.use_se = use_se
        super().__init__()

        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization: Xavier-normal weights for conv/linear layers, unit
        # scale and zero shift for batch norm, zero bias for linear layers
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)
            if isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            # 1x1 projection so the shortcut matches the main branch.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        first = block(self.inplanes, planes, stride, downsample, use_se=self.use_se)
        self.inplanes = planes
        rest = [block(self.inplanes, planes, use_se=self.use_se) for _ in range(1, num_blocks)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        x = self.prelu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.dropout(self.bn4(x))
        x = x.view(x.size(0), -1)
        return self.bn5(self.fc5(x))
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/arch_util.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc
2
+ import math
3
+ import torch
4
+ import torchvision
5
+ import warnings
6
+ from distutils.version import LooseVersion
7
+ from itertools import repeat
8
+ from torch import nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.nn import init as init
11
+ from torch.nn.modules.batchnorm import _BatchNorm
12
+
13
+ from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
14
+ from basicsr.utils import get_root_logger
15
+
16
+
17
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Conv2d and Linear weights receive Kaiming-normal values multiplied by
    ``scale``; batch-norm weights are set to 1. Every existing bias of those
    layers is filled with ``bias_fill``.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    roots = module_list if isinstance(module_list, list) else [module_list]
    for root in roots:
        for layer in root.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(layer.weight, **kwargs)
                layer.weight.data *= scale
            elif isinstance(layer, _BatchNorm):
                init.constant_(layer.weight, 1)
            else:
                continue
            if layer.bias is not None:
                layer.bias.data.fill_(bias_fill)
46
+
47
+
48
def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.
        kwarg (dict): Keyword arguments passed to every block constructor.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    blocks = [basic_block(**kwarg) for _ in range(num_basic_block)]
    return nn.Sequential(*blocks)
62
+
63
+
64
class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    The residual branch is multiplied by ``res_scale`` before it is added to
    the identity path.

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if pytorch_init:
            return
        default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        residual = self.conv2(hidden)
        return x + residual * self.res_scale
93
+
94
+
95
class Upsample(nn.Sequential):
    """Upsample module.

    For a power-of-two scale, log2(scale) stages of (Conv2d -> PixelShuffle(2))
    are stacked; ``scale == 1`` therefore yields an empty (identity)
    sequence. For ``scale == 3`` a single (Conv2d -> PixelShuffle(3)) stage
    is used.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        stages = []
        # The bit trick alone also accepts 0, so require a positive scale to
        # report the intended error instead of a math domain error.
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # log2 is exact for powers of two, unlike truncating log(scale, 2).
            num_stages = int(round(math.log2(scale)))
            for _ in range(num_stages):
                stages.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                stages.append(nn.PixelShuffle(2))
        elif scale == 3:
            stages.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            stages.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*stages)
115
+
116
+
117
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    A pixel-coordinate grid is built, offset by ``flow``, rescaled to the
    [-1, 1] range and used to resample ``x`` with ``F.grid_sample``.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value. The last
            dimension holds the (x, y) displacement.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use the True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid of pixel coordinates
    rows = torch.arange(0, h).type_as(x)
    cols = torch.arange(0, w).type_as(x)
    grid_y, grid_x = torch.meshgrid(rows, cols)
    base_grid = torch.stack((grid_x, grid_y), 2).float()  # (h, w, 2), last dim is (x, y)
    base_grid.requires_grad = False

    sample_pos = base_grid + flow
    # scale grid to [-1,1]
    norm_x = 2.0 * sample_pos[:, :, :, 0] / max(w - 1, 1) - 1.0
    norm_y = 2.0 * sample_pos[:, :, :, 1] / max(h - 1, 1) - 1.0
    norm_grid = torch.stack((norm_x, norm_y), dim=3)

    # TODO, what if align_corners=False
    return F.grid_sample(
        x, norm_grid, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
149
+
150
+
151
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.

    The flow values are rescaled along with the spatial size: channel 0 is
    multiplied by the width ratio and channel 1 by the height ratio. The
    input tensor is not modified because a clone is rescaled.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.

    Raises:
        ValueError: If ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        target = (int(flow_h * sizes[0]), int(flow_w * sizes[1]))
    elif size_type == 'shape':
        target = (sizes[0], sizes[1])
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
    output_h, output_w = target

    scaled = flow.clone()
    scaled[:, 0, :, :] *= output_w / flow_w
    scaled[:, 1, :, :] *= output_h / flow_h
    return F.interpolate(
        input=scaled, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
187
+
188
+
189
+ # TODO: may write a cpp file
190
def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.

    Every ``scale`` x ``scale`` spatial patch is moved into the channel
    dimension, so the spatial size shrinks by ``scale`` and the channel
    count grows by ``scale**2``.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw). Both hh and hw
            must be divisible by ``scale``.
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    batch, chans, in_h, in_w = x.size()
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale
    patches = x.view(batch, chans, out_h, scale, out_w, scale)
    # Bring the two intra-patch offsets next to the channel axis.
    patches = patches.permute(0, 1, 3, 5, 2, 4)
    return patches.reshape(batch, chans * (scale**2), out_h, out_w)
207
+
208
+
209
class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Different from the official DCNv2Pack, which generates offsets and masks
    from the preceding features, this DCNv2Pack takes another different
    features to generate offsets and masks.

    Ref:
        Delving Deep into Deformable Alignment in Video Super-Resolution.
    """

    def forward(self, x, feat):
        # conv_offset output is split into two offset chunks and one mask chunk.
        o1, o2, mask = torch.chunk(self.conv_offset(feat), 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        # Unusually large offsets are only reported, not corrected.
        offset_absmean = torch.mean(torch.abs(offset))
        if offset_absmean > 50:
            get_root_logger().warning(f'Offset abs mean is {offset_absmean}, larger than 50.')

        # Use torchvision's operator from 0.9.0 on, the project op otherwise.
        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                                 self.dilation, mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                     self.dilation, self.groups, self.deformable_groups)
237
+
238
+
239
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
240
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
241
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
242
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
243
+ def norm_cdf(x):
244
+ # Computes standard normal cumulative distribution function
245
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
246
+
247
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
248
+ warnings.warn(
249
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
250
+ 'The distribution of values may be incorrect.',
251
+ stacklevel=2)
252
+
253
+ with torch.no_grad():
254
+ # Values are generated by using a truncated uniform distribution and
255
+ # then using the inverse CDF for the normal distribution.
256
+ # Get upper and lower cdf values
257
+ low = norm_cdf((a - mean) / std)
258
+ up = norm_cdf((b - mean) / std)
259
+
260
+ # Uniformly fill tensor with values from [low, up], then translate to
261
+ # [2l-1, 2u-1].
262
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
263
+
264
+ # Use inverse cdf transform for normal distribution to get truncated
265
+ # standard normal
266
+ tensor.erfinv_()
267
+
268
+ # Transform to proper mean, std
269
+ tensor.mul_(std * math.sqrt(2.))
270
+ tensor.add_(mean)
271
+
272
+ # Clamp to ensure it's in the proper range
273
+ tensor.clamp_(min=a, max=b)
274
+ return tensor
275
+
276
+
277
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution.

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    restricted to :math:`[a, b]`. This function is a thin wrapper that
    forwards all arguments to ``_no_grad_trunc_normal_``, which samples via
    the inverse CDF and clamps the result to the bounds rather than
    redrawing. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The value returned by ``_no_grad_trunc_normal_``, i.e. the input
        tensor after it has been filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
301
+
302
+
303
+ # From PyTorch
304
def _ntuple(n):
    """Create a converter that repeats a non-iterable value ``n`` times.

    The returned function passes any iterable (strings included) through
    unchanged and turns every other value into a tuple of ``n`` copies.
    """

    def parse(x):
        is_iterable = isinstance(x, collections.abc.Iterable)
        return x if is_iterable else tuple(repeat(x, n))

    return parse


# Ready-made converters for the common tuple lengths.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/codeformer_arch.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn, Tensor
5
+ import torch.nn.functional as F
6
+ from typing import Optional, List
7
+
8
+ from basicsr.archs.vqgan_arch import *
9
+ from basicsr.utils import get_root_logger
10
+ from basicsr.utils.registry import ARCH_REGISTRY
11
+
12
def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Statistics are computed per sample and per channel over all spatial
    positions.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.

    Returns:
        tuple(Tensor, Tensor): Mean and std, each reshaped to (b, c, 1, 1).
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    flat = feat.view(b, c, -1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(b, c, 1, 1)
    feat_mean = flat.mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std
27
+
28
+
29
def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have similar color and illumination
    to the degraded features: the content features are normalized with
    their own per-channel statistics and then rescaled and shifted with the
    statistics of the style features.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degraded features.

    Returns:
        Tensor: The re-normalized features, same size as ``content_feat``.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    centered = content_feat - content_mean.expand(size)
    normalized_feat = centered / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
44
+
45
+
46
class PositionEmbeddingSine(nn.Module):
    """
    Sine/cosine position embedding, very similar to the one used by the
    Attention is all you need paper, generalized to work on images.

    The output has ``2 * num_pos_feats`` channels: the first half encodes the
    row position and the second half the column position.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask=None):
        if mask is None:
            # No mask given: treat every position as valid.
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        valid = ~mask
        y_embed = valid.cumsum(1, dtype=torch.float32)
        x_embed = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        freq_idx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (freq_idx // 2) / self.num_pos_feats)

        def encode(embed):
            # Sine on even feature indices, cosine on odd ones, interleaved.
            scaled = embed[:, :, :, None] / dim_t
            return torch.stack(
                (scaled[:, :, :, 0::2].sin(), scaled[:, :, :, 1::2].cos()), dim=4
            ).flatten(3)

        pos_x = encode(x_embed)
        pos_y = encode(y_embed)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
87
+
88
+ def _get_activation_fn(activation):
89
+ """Return an activation function given a string"""
90
+ if activation == "relu":
91
+ return F.relu
92
+ if activation == "gelu":
93
+ return F.gelu
94
+ if activation == "glu":
95
+ return F.glu
96
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
97
+
98
+
99
class TransformerSALayer(nn.Module):
    """Pre-norm transformer layer: self-attention followed by an MLP.

    Both sub-layers normalize their input with LayerNorm first and add their
    (dropout-regularized) output back onto the unnormalized input.

    Args:
        embed_dim (int): Embedding dimension of the tokens.
        nhead (int): Number of attention heads. Default: 8.
        dim_mlp (int): Hidden dimension of the MLP. Default: 2048.
        dropout (float): Dropout probability. Default: 0.0.
        activation (str): Name of the MLP activation. Default: "gelu".
    """

    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of Feedforward model - MLP
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add the positional embedding to ``tensor`` when one is given."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):

        # self attention: positions are added to queries and keys only,
        # the values use the normalized tokens as they are
        normed = self.norm1(tgt)
        q = k = self.with_pos_embed(normed, query_pos)
        attn_result = self.self_attn(q, k, value=normed, attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(attn_result[0])

        # ffn
        hidden = self.activation(self.linear1(self.norm2(tgt)))
        ffn_out = self.linear2(self.dropout(hidden))
        return tgt + self.dropout2(ffn_out)
135
+
136
class Fuse_sft_block(nn.Module):
    """Fuse encoder features into decoder features with a scale and a shift.

    The concatenated encoder/decoder features are passed through a ResBlock;
    two small conv branches then predict a per-pixel scale and shift that
    modulate the decoder features. The modulation is weighted by ``w`` and
    added back onto the decoder features.

    Args:
        in_ch (int): Channel number of each input feature.
        out_ch (int): Channel number of the predicted scale and shift.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2*in_ch, out_ch)

        def modulation_branch():
            # Conv -> LeakyReLU -> Conv, shared layout of both branches.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.scale = modulation_branch()
        self.shift = modulation_branch()

    def forward(self, enc_feat, dec_feat, w=1):
        fused = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(fused)
        shift = self.shift(fused)
        return dec_feat + w * (dec_feat * scale + shift)
158
+
159
+
160
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    """VQAutoEncoder extended with a transformer that predicts codebook indices.

    The encoder features are embedded and passed through a stack of
    TransformerSALayer modules; a linear head turns each token into logits
    over the codebook. The top-1 code per token is looked up in the
    quantizer and decoded by the generator, optionally fusing encoder
    features back in through Fuse_sft_block modules.

    Args:
        dim_embd (int): Embedding dimension of the transformer. Default: 512.
        n_head (int): Number of attention heads. Default: 8.
        n_layers (int): Number of transformer layers. Default: 9.
        codebook_size (int): Number of codebook entries, also the size of
            the logits. Default: 1024.
        latent_size (int): Number of rows of the learned position embedding.
            Default: 256.
        connect_list (list[str]): Feature sizes (as strings) at which encoder
            features are fused into the generator.
        fix_modules (list[str] | None): Attribute names of submodules whose
            parameters get ``requires_grad = False``.
    """

    # NOTE(review): the list defaults are shared between calls; in this class
    # they are only read and stored, never mutated.
    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
                 connect_list=['32', '64', '128', '256'],
                 fix_modules=['quantize','generator']):
        # Positional arguments follow VQAutoEncoder's signature, which is not
        # visible here -- TODO confirm their meaning against vqgan_arch.
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

        # freeze the parameters of the listed submodules
        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd*2

        # learned position embedding, one row per latent token
        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        # projects the 256-channel encoder features to the embedding size
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                    for _ in range(self.n_layers)])

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        # channel count of the features at each feature size
        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # encoder block indices whose output is stored, keyed by feature size:
        # after second residual block for > 16, before attn layer for ==16
        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
        # generator block indices after which fusion happens, keyed by feature size:
        # after first residual block for > 16, before attn layer for ==16
        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}

        # fuse_convs_dict: one fusion block per connected feature size
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        """Initialize one module: normal(0, 0.02) weights for Linear/Embedding,
        zero Linear bias, and unit weight / zero bias for LayerNorm.

        Not called anywhere in this class as shown.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        """Run the encoder, the code-prediction transformer and the generator.

        Args:
            x (Tensor): Input image batch.
            w (float): Weight of the encoder-feature fusion; fusion is
                skipped unless ``w > 0``. Default: 0.
            detach_16 (bool): Detach the quantized features before the
                generator. Default: True.
            code_only (bool): Return right after the logits are predicted.
                Default: False.
            adain (bool): Apply adaptive_instance_normalization to the
                quantized features using the encoder output. Default: False.

        Returns:
            ``(logits, lq_feat)`` when ``code_only`` is True, otherwise
            ``(out, logits, lq_feat)``.
        """
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                # keyed by the size of the last dimension, as a string
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        # (latent_size, dim_embd) -> (latent_size, batch, dim_embd)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb) # (hw)bn
        logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n

        if code_only: # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
        # ------------
        soft_one_hot = F.softmax(logits, dim=2)
        # index of the most likely code for every token
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        # NOTE(review): the shape is hard-coded; presumably
        # (batch, height, width, channels) -- confirm against the quantizer.
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach() # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list: # fuse after i-th block
                f_size = str(x.shape[-1])
                if w>0:
                    # the stored encoder features are detached before fusion
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/rrdbnet_arch.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from basicsr.utils.registry import ARCH_REGISTRY
6
+ from .arch_util import default_init_weights, make_layer, pixel_unshuffle
7
+
8
+
9
class ResidualDenseBlock(nn.Module):
    """Densely connected five-convolution block with a scaled residual.

    Used inside the RRDB block of ESRGAN. Every convolution after the first
    receives the block input concatenated with the outputs of all earlier
    convolutions.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        # conv_k consumes the input plus the (k - 1) earlier growth outputs.
        self.conv1 = nn.Conv2d(num_feat + 0 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + 1 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        # The last convolution maps back to num_feat so the residual can be added.
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization (scale factor 0.1 is forwarded to the project helper)
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        """Return ``conv5(dense features) * 0.2 + x``."""
        dense_feats = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            dense_feats.append(self.lrelu(conv(torch.cat(dense_feats, 1))))
        residual = self.conv5(torch.cat(dense_feats, 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return residual * 0.2 + x
40
+
41
+
42
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Chains three ResidualDenseBlock modules and adds the scaled result back
    onto the block input. Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        """Apply the three dense blocks in sequence, then the outer residual."""
        out = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            out = dense_block(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x
64
+
65
+
66
@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
    """Network built from Residual in Residual Dense Blocks, as used in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    The trunk always ends with two 2x nearest-neighbour upsampling stages.
    For ``scale`` 2 and ``scale`` 1 the input is first pixel-unshuffled (the
    inverse of pixel-shuffle) by 2 and 4 respectively, which shrinks the
    spatial size and enlarges the channel count before the main ESRGAN
    architecture. Any other ``scale`` value feeds the input in unchanged.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Overall scale; 2 and 1 enable pixel-unshuffle. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64
        num_block (int): Block number in the trunk network. Defaults: 23
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        # Pixel-unshuffle by r multiplies the channel count by r * r.
        if scale == 2:
            num_in_ch *= 4
        elif scale == 1:
            num_in_ch *= 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Run the optional pixel-unshuffle, the RRDB trunk and both upsamplers."""
        if self.scale == 2:
            shallow = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            shallow = pixel_unshuffle(x, scale=4)
        else:
            shallow = x
        shallow = self.conv_first(shallow)
        trunk = self.conv_body(self.body(shallow))
        feat = shallow + trunk
        # upsample: two nearest-neighbour 2x stages, each followed by conv + lrelu
        for conv_up in (self.conv_up1, self.conv_up2):
            feat = F.interpolate(feat, scale_factor=2, mode='nearest')
            feat = self.lrelu(conv_up(feat))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/vgg_arch.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from torch import nn as nn
5
+ from torchvision.models import vgg as vgg
6
+
7
+ from basicsr.utils.registry import ARCH_REGISTRY
8
+
9
# Local checkpoint path (the file name refers to VGG-19 weights).
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'


def _vgg_layer_names(convs_per_stage):
    """Return ordered conv/relu/pool layer names for the given per-stage conv counts."""
    layer_names = []
    for stage, num_convs in enumerate(convs_per_stage, start=1):
        for conv_idx in range(1, num_convs + 1):
            layer_names.append(f'conv{stage}_{conv_idx}')
            layer_names.append(f'relu{stage}_{conv_idx}')
        layer_names.append(f'pool{stage}')
    return layer_names


# Layer names per VGG variant; every stage is a run of conv/relu pairs
# followed by one pooling layer.
NAMES = {
    'vgg11': _vgg_layer_names((1, 1, 2, 2, 2)),
    'vgg13': _vgg_layer_names((2, 2, 2, 2, 2)),
    'vgg16': _vgg_layer_names((2, 2, 3, 3, 3)),
    'vgg19': _vgg_layer_names((2, 2, 4, 4, 4)),
}
34
+
35
+
36
def insert_bn(names):
    """Insert a bn layer name after each conv layer name.

    A name containing ``'conv'`` is followed by the same name with ``'conv'``
    replaced by ``'bn'`` (e.g. ``'conv1_1'`` -> ``'bn1_1'``).

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    names_bn = []
    for layer_name in names:
        if 'conv' in layer_name:
            names_bn.extend((layer_name, 'bn' + layer_name.replace('conv', '')))
        else:
            names_bn.append(layer_name)
    return names_bn
52
+
53
+
54
@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    Users can choose whether the input is normalized and which type of vgg
    network is used. Note that the pretrained path must fit the vgg type.
    Only the layers up to the deepest requested one are kept.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must be in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        self.names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            self.names = insert_bn(self.names)

        # only borrow layers that will be used to avoid unused params
        max_idx = max((self.names.index(layer_name) for layer_name in layer_name_list), default=0)

        # Prefer the local checkpoint; otherwise let torchvision fetch the weights.
        if os.path.exists(VGG_PRETRAIN_PATH):
            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = getattr(vgg, vgg_type)(pretrained=True)

        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for layer_name, layer in zip(self.names, features):
            if 'pool' not in layer_name:
                modified_net[layer_name] = layer
            elif not remove_pooling:
                # in some cases, we may want to change the default stride
                modified_net[layer_name] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            # otherwise remove_pooling is true and the pooling layer is dropped

        self.vgg_net = nn.Sequential(modified_net)

        if requires_grad:
            self.vgg_net.train()
            for param in self.parameters():
                param.requires_grad = True
        else:
            self.vgg_net.eval()
            for param in self.parameters():
                param.requires_grad = False

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict: Maps each requested layer name to a clone of its output.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        output = {}
        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                output[key] = x.clone()

        return output
post_process/inswapper/CodeFormer/CodeFormer/basicsr/archs/vqgan_arch.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ VQGAN code, adapted from the original created by the Unleashing Transformers authors:
3
+ https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
4
+
5
+ '''
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import copy
11
+ from basicsr.utils import get_root_logger
12
+ from basicsr.utils.registry import ARCH_REGISTRY
13
+
14
def normalize(in_channels):
    """Return an affine GroupNorm over ``in_channels`` with 32 groups and eps 1e-6."""
    return nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
16
+
17
+
18
@torch.jit.script
def swish(x):
    """Swish activation: the input multiplied by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
21
+
22
+
23
+ # Define VQVAE classes
24
class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook quantizer with a straight-through gradient.

    Args:
        codebook_size (int): Number of embeddings in the codebook.
        emb_dim (int): Dimension of each embedding.
        beta (float): Commitment cost used in the loss term,
            beta * ||z_e(x)-sg[e]||^2.
    """

    def __init__(self, codebook_size, emb_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size
        self.emb_dim = emb_dim
        self.beta = beta
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        # Uniform init in [-1/K, 1/K] where K is the codebook size.
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

    def forward(self, z):
        """Quantize ``z`` to its nearest codebook entries.

        Args:
            z (Tensor): Features laid out as (batch, channel, height, width),
                with channel equal to ``emb_dim``.

        Returns:
            tuple: ``(z_q, loss, stats)``. ``z_q`` has the layout of ``z``,
            ``loss`` is the codebook plus commitment loss, and ``stats`` holds
            "perplexity", "min_encodings", "min_encoding_indices",
            "min_encoding_scores" and "mean_distance".
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.emb_dim)
        codebook = self.embedding.weight

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (codebook ** 2).sum(1) - \
            2 * torch.matmul(z_flattened, codebook.t())
        mean_distance = torch.mean(d)

        # find closest encodings (smallest distance per row)
        nearest_dist, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
        # [0-1], higher score, higher confidence
        min_encoding_scores = torch.exp(-nearest_dist/10)

        # one-hot rows selecting the chosen codebook entry
        num_tokens = min_encoding_indices.shape[0]
        min_encodings = torch.zeros(num_tokens, self.codebook_size).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, codebook).view(z.shape)
        # compute loss for embedding
        codebook_term = torch.mean((z_q.detach()-z)**2)
        commitment_term = torch.mean((z_q - z.detach()) ** 2)
        loss = codebook_term + self.beta * commitment_term
        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        # perplexity of the code usage distribution
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        stats = {
            "perplexity": perplexity,
            "min_encodings": min_encodings,
            "min_encoding_indices": min_encoding_indices,
            "min_encoding_scores": min_encoding_scores,
            "mean_distance": mean_distance
        }
        return z_q, loss, stats

    def get_codebook_feat(self, indices, shape):
        """Look up codebook vectors for ``indices``.

        Args:
            indices (Tensor): Integer code indices; flattened to one column.
            shape (list | None): (batch, height, width, channel) target shape.
                When None, the flat (num_indices, emb_dim) matrix is returned.

        Returns:
            Tensor: Codebook features, permuted to (batch, channel, height,
            width) when ``shape`` is given.
        """
        # input indices: batch*token_num -> (batch*token_num)*1
        indices = indices.view(-1,1)
        one_hot = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        one_hot.scatter_(1, indices, 1)
        # get quantized latent vectors
        z_q = torch.matmul(one_hot.float(), self.embedding.weight)

        if shape is None:
            return z_q
        # reshape back to match original input shape
        return z_q.view(shape).permute(0, 3, 1, 2).contiguous()
86
+
87
+
88
class GumbelQuantizer(nn.Module):
    """Codebook quantizer that selects entries with Gumbel-softmax.

    Args:
        codebook_size (int): Number of embeddings.
        emb_dim (int): Dimension of each embedding.
        num_hiddens (int): Channels of the incoming feature map.
        straight_through (bool): Use hard samples while training.
            Default: False.
        kl_weight (float): Weight of the KL term returned as the loss.
            Default: 5e-4.
        temp_init (float): Gumbel-softmax temperature. Default: 1.0.
    """

    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
        super().__init__()
        self.codebook_size = codebook_size
        self.emb_dim = emb_dim
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        # projects last encoder layer to quantized logits
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        """Return ``(z_q, kl_loss, {"min_encoding_indices": ...})`` for ``z``."""
        # Outside training the sample is always hard (one-hot).
        if self.training:
            hard = self.straight_through
        else:
            hard = True

        logits = self.proj(z)
        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        kl_per_position = torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1)
        diff = self.kl_weight * kl_per_position.mean()

        stats = {
            "min_encoding_indices": soft_one_hot.argmax(dim=1)
        }
        return z_q, diff, stats
116
+
117
+
118
class Downsample(nn.Module):
    """Stride-2 3x3 convolution applied after zero-padding the right and bottom edges."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # (left, right, top, bottom): one extra zero column and row
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
128
+
129
+
130
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
140
+
141
+
142
class ResBlock(nn.Module):
    """Pre-activation residual block: two (norm -> swish -> 3x3 conv) stages.

    When the input and output channel counts differ, the skip path goes
    through a 1x1 convolution so the two branches can be added.

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int | None): Channels of the output feature map.
            ``None`` means "same as ``in_channels``". Default: None.
    """

    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Build every layer from the resolved self.out_channels: the raw
        # ``out_channels`` argument may be None, which used to crash here.
        self.norm1 = normalize(self.in_channels)
        self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        """Return the residual branch output added to the (projected) input."""
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            # project the skip path to the output channel count
            x_in = self.conv_out(x_in)

        return x + x_in
166
+
167
+
168
class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Query, key, value and output projections are 1x1 convolutions; the
    attention output is added back onto the input.

    Args:
        in_channels (int): Channels of the input (and output) feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalize(in_channels)
        # All four projections share the same 1x1 convolution geometry.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w).permute(0, 2, 1)  # b, hw, c
        k = k.reshape(b, c, h*w)  # b, c, hw
        attn = torch.bmm(q, k)  # b, hw(query), hw(key)
        attn = attn * (int(c)**(-0.5))
        attn = F.softmax(attn, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        attn = attn.permute(0, 2, 1)  # b, hw(key), hw(query)
        attended = torch.bmm(v, attn).reshape(b, c, h, w)

        return x+self.proj_out(attended)
228
+
229
+
230
class Encoder(nn.Module):
    """Convolutional encoder: residual stages with downsampling, then a latent projection.

    Args:
        in_channels (int): Channels of the input image.
        nf (int): Base number of feature channels.
        emb_dim (int): Channels of the produced latent map.
        ch_mult (sequence[int]): Channel multiplier for each resolution level.
        num_res_blocks (int): Residual blocks per resolution level.
        resolution (int): Resolution used to track when attention applies.
        attn_resolutions (sequence[int]): Resolutions at which an attention
            block follows each residual block.
    """

    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        # Level i maps nf * in_ch_mult[i] channels to nf * ch_mult[i] channels.
        in_ch_mult = (1,)+tuple(ch_mult)

        # initial convolution
        blocks = [nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)]

        # residual and downsampling blocks, with attention on smaller res (16x16)
        last_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[level]
            block_out_ch = nf * ch_mult[level]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            # every level except the last halves the resolution
            if level != last_level:
                blocks.append(Downsample(block_in_ch))
                curr_res = curr_res // 2

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        # normalise and convert to latent size
        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Apply every block in registration order."""
        feat = x
        for block in self.blocks:
            feat = block(feat)
        return feat
275
+
276
+
277
class Generator(nn.Module):
    """Convolutional decoder: latent map to a 3-channel image via residual stages and upsampling.

    Args:
        nf (int): Base number of feature channels.
        emb_dim (int): Channels of the incoming latent map.
        ch_mult (sequence[int]): Channel multiplier for each resolution level.
        res_blocks (int): Residual blocks per resolution level.
        img_size (int): Output resolution; used to track when attention applies.
        attn_resolutions (sequence[int]): Resolutions at which an attention
            block follows each residual block.
    """

    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.ch_mult = ch_mult
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3
        # Start at the coarsest level: widest channels, smallest resolution.
        block_in_ch = self.nf * self.ch_mult[-1]
        curr_res = self.resolution // 2 ** (self.num_resolutions-1)

        # initial conv
        blocks = [nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)]

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        for level in reversed(range(self.num_resolutions)):
            block_out_ch = self.nf * self.ch_mult[level]

            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

                if curr_res in self.attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            # every level except the finest doubles the resolution
            if level != 0:
                blocks.append(Upsample(block_in_ch))
                curr_res = curr_res * 2

        # normalise and project to image channels
        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Apply every block in registration order."""
        feat = x
        for block in self.blocks:
            feat = block(feat)
        return feat
325
+
326
+
327
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    """VQGAN auto-encoder: Encoder -> codebook quantizer -> Generator.

    Args:
        img_size (int): Resolution passed to the encoder and generator.
        nf (int): Base number of feature channels.
        ch_mult (sequence[int]): Channel multiplier for each resolution level.
        quantizer (str): "nearest" builds a VectorQuantizer, "gumbel" a
            GumbelQuantizer. Any other value leaves ``self.quantize`` unset.
            Default: "nearest".
        res_blocks (int): Residual blocks per resolution level. Default: 2.
        attn_resolutions (list[int]): Resolutions that get attention blocks.
            Default: [16].
        codebook_size (int): Number of codebook entries. Default: 1024.
        emb_dim (int): Dimension of each codebook entry. Default: 256.
        beta (float): Commitment cost for the "nearest" quantizer.
            Default: 0.25.
        gumbel_straight_through (bool): Forwarded to GumbelQuantizer.
            Default: False.
        gumbel_kl_weight (float): Forwarded to GumbelQuantizer. Default: 1e-8.
        model_path (str | None): Optional checkpoint whose 'params_ema' entry
            (preferred) or 'params' entry is loaded. Default: None.

    Raises:
        ValueError: If the checkpoint has neither 'params_ema' nor 'params'.
    """

    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )
        if self.quantizer_type == "nearest":
            self.beta = beta  # 0.25
            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
            self.kl_weight = gumbel_kl_weight
            self.quantize = GumbelQuantizer(
                self.codebook_size,
                self.embed_dim,
                self.gumbel_num_hiddens,
                self.straight_through,
                self.kl_weight
            )
        self.generator = Generator(
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )

        if model_path is not None:
            # Deserialize the checkpoint once and reuse it, instead of reading
            # the file a second time to fetch the selected entry.
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(chkpt['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Returns:
            tuple: ``(reconstruction, codebook_loss, quant_stats)``, where the
            last two come straight from the quantizer.
        """
        x = self.encoder(x)
        quant, codebook_loss, quant_stats = self.quantize(x)
        x = self.generator(quant)
        return x, codebook_loss, quant_stats
391
+
392
+
393
+
394
+ # patch based discriminator
395
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
    """Patch based discriminator producing a 1-channel prediction map.

    Args:
        nc (int): Channels of the input image. Default: 3.
        ndf (int): Base number of filters. Default: 64.
        n_layers (int): Number of stride-2 convolutions, counting the first
            one. Default: 4.
        model_path (str | None): Optional checkpoint whose 'params_d' entry
            (preferred) or 'params' entry is loaded. Default: None.

    Raises:
        ValueError: If the checkpoint has neither 'params_d' nor 'params'.
    """

    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        ndf_mult = 1
        ndf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2 ** n, 8)  # the multiplier is capped at 8
            layers += [
                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2 ** n_layers, 8)

        # one more conv stage at stride 1
        layers += [
            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        layers += [
            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            # Deserialize the checkpoint once and reuse it, instead of reading
            # the file a second time to fetch the selected entry.
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(chkpt['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        """Return the prediction map for ``x``."""
        return self.main(x)
post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/__init__.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ from copy import deepcopy
7
+ from functools import partial
8
+ from os import path as osp
9
+
10
+ from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11
+ from basicsr.utils import get_root_logger, scandir
12
+ from basicsr.utils.dist_util import get_dist_info
13
+ from basicsr.utils.registry import DATASET_REGISTRY
14
+
15
__all__ = ['build_dataset', 'build_dataloader']

# Automatically scan and import dataset modules for the registry: every file
# under the data folder whose name ends with '_dataset.py' is imported.
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [
    osp.splitext(osp.basename(path))[0]
    for path in scandir(data_folder)
    if path.endswith('_dataset.py')
]
# import all the dataset modules
_dataset_modules = [
    importlib.import_module(f'basicsr.data.{module_name}')
    for module_name in dataset_filenames
]
23
+
24
+
25
def build_dataset(dataset_opt):
    """Build dataset from options.

    The options are deep-copied before being handed to the dataset class, so
    the caller's dict is not shared with the dataset.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name.
            type (str): Dataset type, looked up in DATASET_REGISTRY.

    Returns:
        The dataset instance created by the registered class.
    """
    dataset_opt = deepcopy(dataset_opt)
    dataset_cls = DATASET_REGISTRY.get(dataset_opt['type'])
    dataset = dataset_cls(dataset_opt)
    get_root_logger().info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
    return dataset
38
+
39
+
40
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Build dataloader.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train', 'val' or 'test'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
            pin_memory (bool, optional): Defaults to False.
            prefetch_mode (str, optional): 'cpu' selects PrefetchDataLoader.
            num_prefetch_queue (int, optional): Queue size for 'cpu' prefetch.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None

    Raises:
        ValueError: If the phase is not 'train', 'val' or 'test'.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()

    if phase not in ('train', 'val', 'test'):
        raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")

    if phase == 'train':
        batch_size = dataset_opt['batch_size_per_gpu']
        num_workers = dataset_opt['num_worker_per_gpu']
        if not dist:  # non-distributed training: one loader feeds all GPUs
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = batch_size * multiplier
            num_workers = num_workers * multiplier

        # per-worker seeding is only installed when a seed was given
        if seed is not None:
            init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
        else:
            init_fn = None

        dataloader_args = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=sampler is None,  # a sampler takes over the ordering
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True,
            worker_init_fn=init_fn)
    else:  # validation / test
        dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode != 'cpu':
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)

    # CPUPrefetcher
    num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
    logger = get_root_logger()
    logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
    return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
94
+
95
+
96
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed numpy's and Python's global RNGs for one dataloader worker.

    The worker seed is ``num_workers * rank + worker_id + seed``.
    """
    worker_seed = seed + rank * num_workers + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)
post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.56 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/__pycache__/data_sampler.cpython-310.pyc ADDED
Binary file (2.16 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc ADDED
Binary file (4.36 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/data_sampler.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.utils.data.sampler import Sampler
4
+
5
+
6
class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler.
    The index space can be enlarged by ``ratio`` for iteration-based
    training, which saves time otherwise spent restarting the dataloader
    after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        enlarged_len = len(self.dataset) * ratio
        # samples per replica, rounded up so every replica gets the same count
        self.num_samples = math.ceil(enlarged_len / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # the permutation is seeded by the epoch, so it is deterministic
        rng = torch.Generator()
        rng.manual_seed(self.epoch)
        shuffled = torch.randperm(self.total_size, generator=rng).tolist()

        # take this rank's strided share, then fold indices back into the
        # real dataset range
        n_items = len(self.dataset)
        share = shuffled[self.rank:self.total_size:self.num_replicas]
        picked = [idx % n_items for idx in share]
        assert len(picked) == self.num_samples

        return iter(picked)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the next shuffle."""
        self.epoch = epoch
post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/data_util.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from os import path as osp
5
+ from torch.nn import functional as F
6
+
7
+ from basicsr.data.transforms import mod_crop
8
+ from basicsr.utils import img2tensor, scandir
9
+
10
+
11
def read_img_seq(path, require_mod_crop=False, scale=1):
    """Read a sequence of images from a list of paths or a folder.

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
    """
    # a folder is scanned and its entries sorted; a list is used as given
    img_paths = path if isinstance(path, list) else sorted(list(scandir(path, full_path=True)))

    frames = []
    for img_path in img_paths:
        frames.append(cv2.imread(img_path).astype(np.float32) / 255.)
    if require_mod_crop:
        frames = [mod_crop(frame, scale) for frame in frames]

    tensors = img2tensor(frames, bgr2rgb=True, float32=True)
    return torch.stack(tensors, dim=0)
33
+
34
+
35
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Build the index window of `num_frames` frames centred on `crt_idx`.

    Indices that fall outside ``[0, max_frame_num - 1]`` are replaced
    according to the padding mode.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Reading num_frames frames. Must be odd.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    last_idx = max_frame_num - 1  # start from 0
    half = num_frames // 2

    def _before_start(i):
        # replacement for an index left of the first frame
        if padding == 'replicate':
            return 0
        if padding == 'reflection':
            return -i
        if padding == 'reflection_circle':
            return crt_idx + half - i
        return num_frames + i

    def _after_end(i):
        # replacement for an index right of the last frame
        if padding == 'replicate':
            return last_idx
        if padding == 'reflection':
            return last_idx * 2 - i
        if padding == 'reflection_circle':
            return (crt_idx - half) - (i - last_idx)
        return i - num_frames

    window = []
    for i in range(crt_idx - half, crt_idx + half + 1):
        if i < 0:
            window.append(_before_start(i))
        elif i > last_idx:
            window.append(_after_end(i))
        else:
            window.append(i)
    return window
85
+
86
+
87
def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    lq.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt records one image per line:
    1)image name (with extension),
    2)image shape,
    3)compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    The text before the first '.' of each line is used as the lmdb key, and
    the same key is used for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[dict]: One dict per lmdb key (sorted), mapping
            '<input_key>_path' and '<gt_key>_path' to that key.

    Raises:
        ValueError: If either folder does not end with '.lmdb', or the two
            meta_info files list different key sets.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    both_lmdb = input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')
    if not both_lmdb:
        raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
                         f'formats. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')

    def _read_lmdb_keys(folder):
        with open(osp.join(folder, 'meta_info.txt')) as fin:
            return [line.split('.')[0] for line in fin]

    # ensure that the two meta_info files are the same
    input_lmdb_keys = _read_lmdb_keys(input_folder)
    gt_lmdb_keys = _read_lmdb_keys(gt_folder)
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')

    return [{
        f'{input_key}_path': lmdb_key,
        f'{gt_key}_path': lmdb_key
    } for lmdb_key in sorted(input_lmdb_keys)]
144
+
145
+
146
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image name and,
    optionally, the image shape (usually for gt), separated by a white space.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per listed image, mapping '<input_key>_path'
            and '<gt_key>_path' to the joined file paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    with open(meta_info_file, 'r') as fin:
        # Strip first: a line holding only a file name would otherwise keep
        # its trailing newline and corrupt the extension. Blank lines carry
        # no image and are skipped.
        gt_names = [line.strip().split(' ')[0] for line in fin if line.strip()]

    paths = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_name)
        })
    return paths
188
+
189
+
190
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Every entry of the gt folder is matched with the input-folder entry
    whose name is ``filename_tmpl.format(<gt basename>)`` plus the gt
    extension.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per gt entry, mapping '<input_key>_path' and
            '<gt_key>_path' to the joined file paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    # Build the lookup set once so each membership check below is O(1)
    # instead of a linear scan of the input list.
    input_names = set(input_paths)
    paths = []
    for gt_path in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        assert input_name in input_names, (f'{input_name} is not in ' f'{input_key}_paths.')
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_path)
        })
    return paths
224
+
225
+
226
def paths_from_folder(folder):
    """List the entries of a folder as paths joined onto that folder.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """
    return [osp.join(folder, entry) for entry in scandir(folder)]
239
+
240
+
241
def paths_from_lmdb(folder):
    """Generate paths (lmdb keys) from an lmdb folder.

    Reads ``meta_info.txt`` inside the folder and returns, for each line,
    the text before its first '.'.

    Args:
        folder (str): Folder path. Must end with '.lmdb'.

    Returns:
        list[str]: Returned path list.

    Raises:
        ValueError: If `folder` does not end with '.lmdb'.
    """
    if not folder.endswith('.lmdb'):
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths
255
+
256
+
257
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The Gaussian kernel.
    """
    # Import from the public scipy.ndimage namespace; the
    # scipy.ndimage.filters submodule is deprecated.
    from scipy.ndimage import gaussian_filter
    kernel = np.zeros((kernel_size, kernel_size))
    # set element at the middle to one, a dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter
    return gaussian_filter(kernel, sigma)
273
+
274
+
275
def duf_downsample(x, kernel_size=13, scale=4):
    """Downsampling with Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
            A 4-dim input gets a leading dimension added and removed again.
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = x.ndim == 4
    if squeeze_flag:
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()

    # fold batch, time and channel into one axis so each plane is filtered
    # independently, then reflect-pad by the same amount on all four sides
    pad = kernel_size // 2 + scale * 2
    planes = F.pad(x.view(-1, 1, h, w), (pad, pad, pad, pad), 'reflect')

    kernel = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    kernel = torch.from_numpy(kernel).type_as(planes).unsqueeze(0).unsqueeze(0)
    planes = F.conv2d(planes, kernel, stride=scale)
    planes = planes[:, :, 2:-2, 2:-2]

    out = planes.view(b, t, c, planes.size(2), planes.size(3))
    return out.squeeze(0) if squeeze_flag else out
post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/prefetch_dataloader.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import queue as Queue
2
+ import threading
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    A daemon thread, started on construction, drains the wrapped generator
    into a bounded queue; iterating this object pops items from that queue.
    ``None`` is used as the end-of-data marker.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        super().__init__()
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        enqueue = self.queue.put
        for item in self.generator:
            enqueue(item)
        # end-of-data marker
        enqueue(None)

    def __next__(self):
        item = self.queue.get()
        if item is not None:
            return item
        raise StopIteration

    def __iter__(self):
        return self
38
+
39
+
40
class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Iterating it wraps the normal DataLoader iterator in a
    `PrefetchGenerator`.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        base_iter = super().__iter__()
        return PrefetchGenerator(base_iter, self.num_prefetch_queue)
61
+
62
+
63
class CPUPrefetcher():
    """CPU prefetcher.

    Thin wrapper that exposes `next()` / `reset()` on top of an iterable
    loader.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        """Return the next item, or None once the loader is exhausted."""
        return next(self.loader, None)

    def reset(self):
        """Restart iteration from the original loader."""
        self.loader = iter(self.ori_loader)
82
+
83
+
84
class CUDAPrefetcher():
    """CUDA prefetcher.

    Keeps one batch preloaded: while the caller works on the batch returned
    by `next()`, the following batch is already fetched and its tensors are
    being moved to the target device on a separate CUDA stream.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options. `opt['num_gpu']` is read to choose the device.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        # NOTE(review): the stream is created unconditionally, even when
        # opt['num_gpu'] == 0 selects the CPU device below - confirm that
        # CPU-only use of this class is not expected.
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        # fetch the first batch right away so next() has something to return
        self.preload()

    def preload(self):
        """Fetch the next batch into `self.batch` and move its tensors.

        `self.batch` is set to None when the loader is exhausted.
        """
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                # non-tensor values are left untouched
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        """Return the preloaded batch (or None) and start loading the next."""
        # make the current stream wait for the copies issued on self.stream
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        """Restart iteration from the original loader and preload a batch."""
        self.loader = iter(self.ori_loader)
        self.preload()
post_process/inswapper/CodeFormer/CodeFormer/basicsr/data/transforms.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import random
3
+
4
+
5
def mod_crop(img, scale):
    """Mod crop images, used during testing.

    A copy of the image is trimmed at the bottom and right so that its
    height and width are multiples of `scale`.

    Args:
        img (ndarray): Input image with 2 or 3 dimensions.
        scale (int): Scale factor.

    Returns:
        ndarray: Result image.

    Raises:
        ValueError: If `img` is neither 2- nor 3-dimensional.
    """
    img = img.copy()
    if img.ndim not in (2, 3):
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    height, width = img.shape[:2]
    return img[:height - height % scale, :width - width % scale, ...]
23
+
24
+
25
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
    """Paired random crop.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth. Only used in the error message.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.

    Raises:
        ValueError: If the GT size is not `scale` times the LQ size, or the
            LQ image is smaller than the LQ patch size.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    h_lq, w_lq, _ = img_lqs[0].shape
    h_gt, w_gt, _ = img_gts[0].shape
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        # The two f-strings are concatenated into ONE message; separating
        # them with a comma would hand ValueError two arguments.
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
78
+
79
+
80
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    We use vertical flip and transpose for rotation implementation.
    All the images in the list use the same augmentation: the three random
    decisions (hflip, vflip, rot90) are drawn once per call, each with
    probability 0.5 when enabled.

    Note that the flips pass the input array to `cv2.flip` as its own
    destination, so the caller's arrays are modified.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Enables both the vertical flip and the
            transpose. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Only takes effect when `flows` is None. Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
            results only have one element, just return ndarray.
            With `flows` given, returns `(imgs, flows)`; otherwise returns
            `(imgs, (hflip, vflip, rot90))` when `return_status` is set and
            `imgs` alone when it is not.

    """
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            cv2.flip(img, 1, img)
        if vflip:  # vertical
            cv2.flip(img, 0, img)
        if rot90:
            # swap the first two axes
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if hflip:  # horizontal
            cv2.flip(flow, 1, flow)
            # negate channel 0 after the horizontal flip
            flow[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(flow, 0, flow)
            # negate channel 1 after the vertical flip
            flow[:, :, 1] *= -1
        if rot90:
            flow = flow.transpose(1, 0, 2)
            # swap the two flow channels to follow the transposed axes
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows
    else:
        if return_status:
            return imgs, (hflip, vflip, rot90)
        else:
            return imgs
145
+
146
+
147
def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate image, keeping the output the same width and height.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.

    Returns:
        ndarray: The rotated image.
    """
    height, width = img.shape[:2]
    pivot = (width // 2, height // 2) if center is None else center
    affine = cv2.getRotationMatrix2D(pivot, angle, scale)
    return cv2.warpAffine(img, affine, (width, height))
post_process/inswapper/CodeFormer/CodeFormer/basicsr/losses/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+
3
+ from basicsr.utils import get_root_logger
4
+ from basicsr.utils.registry import LOSS_REGISTRY
5
+ from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
6
+ gradient_penalty_loss, r1_penalty)
7
+
8
+ __all__ = [
9
+ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
10
+ 'r1_penalty', 'g_path_regularize'
11
+ ]
12
+
13
+
14
def build_loss(opt):
    """Build loss from options.

    The options are deep-copied, so the caller's dict is left untouched.

    Args:
        opt (dict): Configuration. It must contain:
            type (str): Name under which the loss class is registered in
                LOSS_REGISTRY. All remaining entries are passed to the loss
                constructor as keyword arguments.

    Returns:
        The constructed loss instance.
    """
    loss_kwargs = deepcopy(opt)
    loss_cls = LOSS_REGISTRY.get(loss_kwargs.pop('type'))
    loss = loss_cls(**loss_kwargs)
    get_root_logger().info(f'Loss [{loss.__class__.__name__}] is created.')
    return loss
post_process/inswapper/CodeFormer/CodeFormer/basicsr/losses/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.05 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/losses/__pycache__/loss_util.cpython-310.pyc ADDED
Binary file (2.71 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/losses/__pycache__/losses.cpython-310.pyc ADDED
Binary file (14.6 kB). View file
 
post_process/inswapper/CodeFormer/CodeFormer/basicsr/losses/loss_util.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from torch.nn import functional as F
3
+
4
+
5
def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    # none: 0, elementwise_mean:1, sum: 2
    mode = F._Reduction.get_enum(reduction)
    if mode == 0:
        return loss
    return loss.mean() if mode == 1 else loss.sum()
23
+
24
+
25
def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Must have the same number of
            dimensions as `loss`, and size 1 or `loss.size(1)` along
            dimension 1. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    # without weights this is a plain reduction
    if weight is None:
        return reduce_loss(loss, reduction)

    assert weight.dim() == loss.dim()
    assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
    loss = loss * weight

    if reduction == 'sum':
        return reduce_loss(loss, reduction)
    if reduction == 'mean':
        # mean over the weighted region: a weight with size 1 on dim 1 is
        # counted once per element of loss along that dim
        if weight.size(1) > 1:
            normalizer = weight.sum()
        else:
            normalizer = weight.sum() * loss.size(1)
        return loss.sum() / normalizer
    return loss
55
+
56
+
57
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([[0, 2, 3]])
    >>> target = torch.Tensor([[1, 1, 1]])
    >>> weight = torch.Tensor([[1, 0, 1]])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([[1., 1., 2.]])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # element-wise loss first, then weighting and reduction
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction)

    return wrapper
post_process/inswapper/CodeFormer/CodeFormer/basicsr/losses/losses.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import lpips
3
+ import torch
4
+ from torch import autograd as autograd
5
+ from torch import nn as nn
6
+ from torch.nn import functional as F
7
+
8
+ from basicsr.archs.vgg_arch import VGGFeatureExtractor
9
+ from basicsr.utils.registry import LOSS_REGISTRY
10
+ from .loss_util import weighted_loss
11
+
12
+ _reduction_modes = ['none', 'mean', 'sum']
13
+
14
+
15
@weighted_loss
def l1_loss(pred, target):
    """Element-wise absolute error; reduction is left to the decorator."""
    elementwise = F.l1_loss(pred, target, reduction='none')
    return elementwise
18
+
19
+
20
@weighted_loss
def mse_loss(pred, target):
    """Element-wise squared error; reduction is left to the decorator."""
    elementwise = F.mse_loss(pred, target, reduction='none')
    return elementwise
23
+
24
+
25
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    """Element-wise Charbonnier loss, sqrt((pred - target)^2 + eps)."""
    diff = pred - target
    return torch.sqrt(diff**2 + eps)
28
+
29
+
30
@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If `reduction` is not one of the supported choices.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        unscaled = l1_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * unscaled
57
+
58
+
59
@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If `reduction` is not one of the supported choices.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        unscaled = mse_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * unscaled
86
+
87
+
88
@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
    """Charbonnier loss (one variant of Robust L1Loss, a differentiable
    variant of L1Loss).

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
        Super-Resolution".

    Args:
        loss_weight (float): Weight applied to the loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.

    Raises:
        ValueError: If `reduction` is not one of the supported choices.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super().__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        unscaled = charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
        return self.loss_weight * unscaled
122
+
123
+
124
@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
    """Weighted total-variation (TV) loss.

    The loss is the sum of two L1 terms computed by ``L1Loss.forward``: one
    between vertically adjacent rows of ``pred`` and one between horizontally
    adjacent columns. ``L1Loss`` is constructed with its default reduction.

    Args:
        loss_weight (float): Loss weight. Default: 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)

    def forward(self, pred, weight=None):
        """Compute the TV loss of ``pred``.

        Args:
            pred (Tensor): Tensor indexed with four dimensions; differences
                are taken along dims 2 and 3.
            weight (Tensor, optional): Element-wise weights indexed the same
                way as ``pred``. It is cropped to match each difference term.
                Default: None (unweighted).

        Returns:
            Tensor: Sum of the horizontal and vertical L1 difference terms.
        """
        # ``weight`` is optional, so it must not be sliced when it is None
        # (the previous code raised TypeError on the default call).
        if weight is None:
            y_weight = None
            x_weight = None
        else:
            y_weight = weight[:, :, :-1, :]
            x_weight = weight[:, :, :, :-1]

        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        return x_diff + y_diff
142
+
143
+
144
@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will be multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will be multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. One of
            'l1' | 'l2' | 'mse' | 'fro'. Default: 'l1'.

    Raises:
        NotImplementedError: If ``criterion`` is not a supported name.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type in ('l2', 'mse'):
            # torch.nn has no ``L2loss`` class (the old 'l2' branch raised
            # AttributeError); the L2 criterion is the mean squared error.
            self.criterion = torch.nn.MSELoss(reduction='mean')
        elif self.criterion_type == 'fro':
            # The Frobenius norm is computed inline in ``forward``.
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            tuple: ``(percep_loss, style_loss)``. Each entry is a Tensor, or
            None when the corresponding weight is not greater than 0.
        """
        # extract vgg features; the ground truth is detached so no gradient
        # flows through it
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features:
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss on the Gram matrices of the same features
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features:
                gram_x = self._gram_mat(x_features[k])
                gram_gt = self._gram_mat(gt_features[k])
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(gram_x - gram_gt, p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(gram_x, gram_gt) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix of shape (n, c, c), normalized by
            ``c * h * w``.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
254
+
255
+
256
@LOSS_REGISTRY.register()
class LPIPSLoss(nn.Module):
    """Loss built on ``lpips.LPIPS(net="vgg", spatial=False)`` in eval mode.

    Args:
        loss_weight (float): Scalar multiplied onto the mean LPIPS output.
            Default: 1.0.
        use_input_norm (bool): If True, subtract the registered ``mean``
            buffer and divide by the ``std`` buffer before calling the
            LPIPS network. Default: True.
        range_norm (bool): If True, map inputs through ``(x + 1) / 2``
            first (i.e. [-1, 1] to [0, 1]). Default: False.
    """

    def __init__(self,
                 loss_weight=1.0,
                 use_input_norm=True,
                 range_norm=False,):
        super().__init__()
        lpips_net = lpips.LPIPS(net="vgg", spatial=False)
        self.perceptual = lpips_net.eval()
        self.loss_weight = loss_weight
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        if self.use_input_norm:
            # both statistics are for images with range [0, 1]
            channel_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
            channel_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
            self.register_buffer('mean', channel_mean)
            self.register_buffer('std', channel_std)

    def forward(self, pred, target):
        """Return ``loss_weight`` times the mean LPIPS output.

        Args:
            pred (Tensor): Predicted images.
            target (Tensor): Reference images.

        Returns:
            Tensor: Scalar loss.
        """
        if self.range_norm:
            pred, target = (pred + 1) / 2, (target + 1) / 2
        if self.use_input_norm:
            pred, target = (pred - self.mean) / self.std, (target - self.mean) / self.std
        # The LPIPS network is called with the target first, then the prediction.
        distance = self.perceptual(target.contiguous(), pred.contiguous())
        return self.loss_weight * distance.mean()
283
+
284
+
285
@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus',
            'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always 1.0
            for discriminators.

    Raises:
        NotImplementedError: If ``gan_type`` is not one of the supported types.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # Zero-argument factories, so that only the selected criterion is built.
        factories = {
            'vanilla': nn.BCEWithLogitsLoss,
            'lsgan': nn.MSELoss,
            'wgan': lambda: self._wgan_loss,
            'wgan_softplus': lambda: self._wgan_softplus_loss,
            'hinge': nn.ReLU,
        }
        if self.gan_type not in factories:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
        self.loss = factories[self.gan_type]()

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: Negated mean of ``input`` for a real target, plain mean
            otherwise.
        """
        mean_score = input.mean()
        if target:
            return -mean_score
        return mean_score

    def _wgan_softplus_loss(self, input, target):
        """wgan loss with soft plus. softplus is a smooth approximation to the
        ReLU function.

        In StyleGAN2, it is called:
            Logistic loss for discriminator;
            Non-saturating loss for generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: Mean softplus of ``-input`` for a real target, of
            ``input`` otherwise.
        """
        signed = -input if target else input
        return F.softplus(signed).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
                return Tensor.
        """
        if self.gan_type in ['wgan', 'wgan_softplus']:
            # The wgan criteria take the boolean directly.
            return target_is_real
        if target_is_real:
            target_val = self.real_label_val
        else:
            target_val = self.fake_label_val
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        if self.gan_type != 'hinge':
            # other gan types
            loss = self.loss(input, self.get_target_label(input, target_is_real))
        elif is_disc:
            # for discriminators in hinge-gan
            input = -input if target_is_real else input
            loss = self.loss(1 + input).mean()
        else:
            # for generators in hinge-gan (target_is_real is not consulted)
            loss = -input.mean()

        # loss_weight is always 1.0 for discriminators
        if is_disc:
            return loss
        return loss * self.loss_weight
388
+
389
+
390
def r1_penalty(real_pred, real_img):
    """R1 regularization for discriminator.

    The core idea is to penalize the gradient on real data alone: when the
    generator distribution produces the true data distribution and the
    discriminator is equal to 0 on the data manifold, the gradient penalty
    ensures that the discriminator cannot create a non-zero gradient
    orthogonal to the data manifold without suffering a loss in the GAN game.

    Ref:
        Eq. 9 in Which training methods for GANs do actually converge.

    Args:
        real_pred (Tensor): Discriminator output computed from ``real_img``.
        real_img (Tensor): Input that ``real_pred`` was computed from; it
            must be part of the autograd graph.

    Returns:
        Tensor: Squared gradient norm per sample (first dim), averaged over
        the batch.
    """
    # create_graph=True keeps the penalty itself differentiable.
    (grad_real,) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    batch = grad_real.shape[0]
    sq_norm_per_sample = grad_real.pow(2).view(batch, -1).sum(1)
    return sq_norm_per_sample.mean()
405
+
406
+
407
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path-length regularization term for a generator.

    Args:
        fake_img (Tensor): Generated images computed from ``latents``; dims 2
            and 3 are used to scale the random noise.
        latents (Tensor): Latent tensor (at least 3 dims) that ``fake_img``
            depends on in the autograd graph.
        mean_path_length (float | Tensor): Running mean of the path lengths.
        decay (float): Step size of the running-mean update. Default: 0.01.

    Returns:
        tuple: ``(path_penalty, mean of the current path lengths (detached),
        updated running mean (detached))``.
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(height * width)

    # Gradient of the noise-weighted image sum with respect to the latents.
    projected = (fake_img * noise).sum()
    (grad,) = autograd.grad(outputs=projected, inputs=latents, create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    # Move the running mean towards the current batch mean.
    batch_mean = path_lengths.mean()
    path_mean = mean_path_length + decay * (batch_mean - mean_path_length)

    deviation = path_lengths - path_mean
    path_penalty = deviation.pow(2).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
417
+
418
+
419
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate gradient penalty for wgan-gp.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        fake_data (Tensor): Fake input data.
        weight (Tensor): Weight tensor multiplied onto the gradients; the
            penalty is then divided by its mean. Default: None.

    Returns:
        Tensor: A tensor for gradient penalty.
    """

    batch_size = real_data.size(0)
    # One mixing coefficient per sample, moved to real_data's dtype/device.
    # ``.to(real_data)`` replaces ``real_data.new_tensor(<tensor>)``, which
    # emitted a copy-construct UserWarning on every call.
    alpha = torch.rand(batch_size, 1, 1, 1).to(real_data)

    # interpolate between real_data and fake_data
    interpolates = alpha * real_data + (1. - alpha) * fake_data
    # Fresh leaf tensor that requires grad (replaces the deprecated
    # ``autograd.Variable`` wrapper).
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True)[0]

    if weight is not None:
        gradients = gradients * weight

    # The 2-norm is taken over dim 1 only, as in the original implementation.
    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    if weight is not None:
        gradients_penalty = gradients_penalty / torch.mean(weight)

    return gradients_penalty