BlackBeenie commited on
Commit
fe3fdf0
·
1 Parent(s): 8dd41a8

feat: Add plugins

Browse files
mypy.ini ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ [mypy]
2
+ check_untyped_defs = True
3
+ disallow_any_generics = True
4
+ disallow_untyped_calls = True
5
+ disallow_untyped_defs = True
6
+ ignore_missing_imports = True
7
+ strict_optional = False
plugin_options/core.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "default_chain": "faceswap",
3
+ "init_on_start": "faceswap,dmdnet,gfpgan,codeformer",
4
+ "is_demo_row_render": false,
5
+ "v": "2.0"
6
+ }
plugin_options/core_video.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "v": "2.0",
3
+ "video_save_codec": "libx264",
4
+ "video_save_crf": 14
5
+ }
plugin_options/plugin_codeformer.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "background_enhance": false,
3
+ "codeformer_fidelity": 0.8,
4
+ "face_upsample": true,
5
+ "skip_if_no_face": true,
6
+ "upscale": 1,
7
+ "v": "3.0"
8
+ }
plugin_options/plugin_dmdnet.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "v": "1.0"
3
+ }
plugin_options/plugin_faceswap.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "max_distance": 0.65,
3
+ "swap_mode": "selected",
4
+ "v": "1.0"
5
+ }
plugin_options/plugin_gfpgan.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "v": "1.4"
3
+ }
plugin_options/plugin_txt2clip.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "v": "1.0"
3
+ }
plugins/codeformer_app_cv2.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified version from codeformer-pip project
3
+
4
+ S-Lab License 1.0
5
+
6
+ Copyright 2022 S-Lab
7
+
8
+ https://github.com/kadirnar/codeformer-pip/blob/main/LICENSE
9
+ """
10
+
11
+ import os
12
+
13
+ import cv2
14
+ import torch
15
+ from codeformer.facelib.detection import init_detection_model
16
+ from codeformer.facelib.parsing import init_parsing_model
17
+ from torchvision.transforms.functional import normalize
18
+
19
+ from codeformer.basicsr.archs.rrdbnet_arch import RRDBNet
20
+ from codeformer.basicsr.utils import img2tensor, imwrite, tensor2img
21
+ from codeformer.basicsr.utils.download_util import load_file_from_url
22
+ from codeformer.basicsr.utils.realesrgan_utils import RealESRGANer
23
+ from codeformer.basicsr.utils.registry import ARCH_REGISTRY
24
+ from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
25
+ from codeformer.facelib.utils.misc import is_gray
26
+ import threading
27
+
28
+ from plugins.codeformer_face_helper_cv2 import FaceRestoreHelperOptimized
29
+
30
+ THREAD_LOCK_FACE_HELPER = threading.Lock()
31
+ THREAD_LOCK_FACE_HELPER_CREATE = threading.Lock()
32
+ THREAD_LOCK_FACE_HELPER_PROCERSSING = threading.Lock()
33
+ THREAD_LOCK_CODEFORMER_NET = threading.Lock()
34
+ THREAD_LOCK_CODEFORMER_NET_CREATE = threading.Lock()
35
+ THREAD_LOCK_BGUPSAMPLER = threading.Lock()
36
+
37
+ pretrain_model_url = {
38
+ "codeformer": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
39
+ "detection": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
40
+ "parsing": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
41
+ "realesrgan": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
42
+ }
43
+
44
+ # download weights
45
+ if not os.path.exists("models/CodeFormer/codeformer.pth"):
46
+ load_file_from_url(
47
+ url=pretrain_model_url["codeformer"], model_dir="models/CodeFormer/", progress=True, file_name=None
48
+ )
49
+ if not os.path.exists("models/CodeFormer/facelib/detection_Resnet50_Final.pth"):
50
+ load_file_from_url(
51
+ url=pretrain_model_url["detection"], model_dir="models/CodeFormer/facelib", progress=True, file_name=None
52
+ )
53
+ if not os.path.exists("models/CodeFormer/facelib/parsing_parsenet.pth"):
54
+ load_file_from_url(
55
+ url=pretrain_model_url["parsing"], model_dir="models/CodeFormer/facelib", progress=True, file_name=None
56
+ )
57
+ if not os.path.exists("models/CodeFormer/realesrgan/RealESRGAN_x2plus.pth"):
58
+ load_file_from_url(
59
+ url=pretrain_model_url["realesrgan"], model_dir="models/CodeFormer/realesrgan", progress=True, file_name=None
60
+ )
61
+
62
+
63
+ def imread(img_path):
64
+ img = cv2.imread(img_path)
65
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
66
+ return img
67
+
68
+
69
+ # set enhancer with RealESRGAN
70
+ def set_realesrgan():
71
+ half = True if torch.cuda.is_available() else False
72
+ model = RRDBNet(
73
+ num_in_ch=3,
74
+ num_out_ch=3,
75
+ num_feat=64,
76
+ num_block=23,
77
+ num_grow_ch=32,
78
+ scale=2,
79
+ )
80
+ upsampler = RealESRGANer(
81
+ scale=2,
82
+ model_path="models/CodeFormer/realesrgan/RealESRGAN_x2plus.pth",
83
+ model=model,
84
+ tile=400,
85
+ tile_pad=40,
86
+ pre_pad=0,
87
+ half=half,
88
+ )
89
+ return upsampler
90
+
91
+
92
+ upsampler = set_realesrgan()
93
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94
+
95
+ codeformers_cache = []
96
+
97
+ def get_codeformer():
98
+ if len(codeformers_cache) > 0:
99
+ with THREAD_LOCK_CODEFORMER_NET:
100
+ if len(codeformers_cache) > 0:
101
+ return codeformers_cache.pop()
102
+
103
+ with THREAD_LOCK_CODEFORMER_NET_CREATE:
104
+ codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
105
+ dim_embd=512,
106
+ codebook_size=1024,
107
+ n_head=8,
108
+ n_layers=9,
109
+ connect_list=["32", "64", "128", "256"],
110
+ ).to(device)
111
+ ckpt_path = "models/CodeFormer/codeformer.pth"
112
+ checkpoint = torch.load(ckpt_path)["params_ema"]
113
+ codeformer_net.load_state_dict(checkpoint)
114
+ codeformer_net.eval()
115
+ return codeformer_net
116
+
117
+
118
+
119
+ def release_codeformer(codeformer):
120
+ with THREAD_LOCK_CODEFORMER_NET:
121
+ codeformers_cache.append(codeformer)
122
+
123
+ #os.makedirs("output", exist_ok=True)
124
+
125
+ # ------- face restore thread cache ----------
126
+
127
+ face_restore_helper_cache = []
128
+
129
+ detection_model = "retinaface_resnet50"
130
+
131
+ inited_face_restore_helper_nn = False
132
+
133
+ import time
134
+
135
+ def get_face_restore_helper(upscale):
136
+ global inited_face_restore_helper_nn
137
+ with THREAD_LOCK_FACE_HELPER:
138
+ face_helper = FaceRestoreHelperOptimized(
139
+ upscale,
140
+ face_size=512,
141
+ crop_ratio=(1, 1),
142
+ det_model=detection_model,
143
+ save_ext="png",
144
+ use_parse=True,
145
+ device=device,
146
+ )
147
+ #return face_helper
148
+
149
+ if inited_face_restore_helper_nn:
150
+ while len(face_restore_helper_cache) == 0:
151
+ time.sleep(0.05)
152
+ face_detector, face_parse = face_restore_helper_cache.pop()
153
+ face_helper.face_detector = face_detector
154
+ face_helper.face_parse = face_parse
155
+ return face_helper
156
+ else:
157
+ inited_face_restore_helper_nn = True
158
+ face_helper.face_detector = init_detection_model(detection_model, half=False, device=face_helper.device)
159
+ face_helper.face_parse = init_parsing_model(model_name="parsenet", device=face_helper.device)
160
+ return face_helper
161
+
162
+ def get_face_restore_helper2(upscale): # still not work well!!!
163
+ face_helper = FaceRestoreHelperOptimized(
164
+ upscale,
165
+ face_size=512,
166
+ crop_ratio=(1, 1),
167
+ det_model=detection_model,
168
+ save_ext="png",
169
+ use_parse=True,
170
+ device=device,
171
+ )
172
+ #return face_helper
173
+
174
+ if len(face_restore_helper_cache) > 0:
175
+ with THREAD_LOCK_FACE_HELPER:
176
+ if len(face_restore_helper_cache) > 0:
177
+ face_detector, face_parse = face_restore_helper_cache.pop()
178
+ face_helper.face_detector = face_detector
179
+ face_helper.face_parse = face_parse
180
+ return face_helper
181
+
182
+ with THREAD_LOCK_FACE_HELPER_CREATE:
183
+ face_helper.face_detector = init_detection_model(detection_model, half=False, device=face_helper.device)
184
+ face_helper.face_parse = init_parsing_model(model_name="parsenet", device=face_helper.device)
185
+ return face_helper
186
+
187
+ def release_face_restore_helper(face_helper):
188
+ #return
189
+ #with THREAD_LOCK_FACE_HELPER:
190
+ face_restore_helper_cache.append((face_helper.face_detector, face_helper.face_parse))
191
+ #pass
192
+
193
+ def inference_app(image, background_enhance, face_upsample, upscale, codeformer_fidelity, skip_if_no_face = False):
194
+ # take the default setting for the demo
195
+ has_aligned = False
196
+ only_center_face = False
197
+ draw_box = False
198
+
199
+ #print("Inp:", image, background_enhance, face_upsample, upscale, codeformer_fidelity)
200
+ if isinstance(image, str):
201
+ img = cv2.imread(str(image), cv2.IMREAD_COLOR)
202
+ else:
203
+ img = image
204
+ #print("\timage size:", img.shape)
205
+
206
+ upscale = int(upscale) # convert type to int
207
+ if upscale > 4: # avoid memory exceeded due to too large upscale
208
+ upscale = 4
209
+ if upscale > 2 and max(img.shape[:2]) > 1000: # avoid memory exceeded due to too large img resolution
210
+ upscale = 2
211
+ if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution
212
+ upscale = 1
213
+ background_enhance = False
214
+ #face_upsample = False
215
+
216
+ face_helper = get_face_restore_helper(upscale)
217
+
218
+ bg_upsampler = upsampler if background_enhance else None
219
+ face_upsampler = upsampler if face_upsample else None
220
+
221
+ if has_aligned:
222
+ # the input faces are already cropped and aligned
223
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
224
+ face_helper.is_gray = is_gray(img, threshold=5)
225
+ if face_helper.is_gray:
226
+ print("\tgrayscale input: True")
227
+ face_helper.cropped_faces = [img]
228
+ else:
229
+ with THREAD_LOCK_FACE_HELPER_PROCERSSING:
230
+ face_helper.read_image(img)
231
+ # get face landmarks for each face
232
+
233
+ num_det_faces = face_helper.get_face_landmarks_5(
234
+ only_center_face=only_center_face, resize=640, eye_dist_threshold=5
235
+ )
236
+ #print(f"\tdetect {num_det_faces} faces")
237
+
238
+ if num_det_faces == 0 and skip_if_no_face:
239
+ release_face_restore_helper(face_helper)
240
+ return img
241
+
242
+ # align and warp each face
243
+ face_helper.align_warp_face()
244
+
245
+
246
+
247
+ # face restoration for each cropped face
248
+ for idx, cropped_face in enumerate(face_helper.cropped_faces):
249
+ # prepare data
250
+ cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
251
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
252
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
253
+
254
+ codeformer_net = get_codeformer()
255
+ try:
256
+ with torch.no_grad():
257
+ output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
258
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
259
+ del output
260
+ except RuntimeError as error:
261
+ print(f"Failed inference for CodeFormer: {error}")
262
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
263
+ release_codeformer(codeformer_net)
264
+
265
+ restored_face = restored_face.astype("uint8")
266
+ face_helper.add_restored_face(restored_face)
267
+
268
+ # paste_back
269
+ if not has_aligned:
270
+ # upsample the background
271
+ if bg_upsampler is not None:
272
+ with THREAD_LOCK_BGUPSAMPLER:
273
+ # Now only support RealESRGAN for upsampling background
274
+ bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
275
+ else:
276
+ bg_img = None
277
+ face_helper.get_inverse_affine(None)
278
+ # paste each restored face to the input image
279
+ if face_upsample and face_upsampler is not None:
280
+ restored_img = face_helper.paste_faces_to_input_image(
281
+ upsample_img=bg_img,
282
+ draw_box=draw_box,
283
+ face_upsampler=face_upsampler,
284
+ )
285
+ else:
286
+ restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box)
287
+
288
+ if image.shape != restored_img.shape:
289
+ h, w, _ = image.shape
290
+ restored_img = cv2.resize(restored_img, (w, h), interpolation=cv2.INTER_LINEAR)
291
+
292
+
293
+ release_face_restore_helper(face_helper)
294
+ # save restored img
295
+ if isinstance(image, str):
296
+ save_path = f"output/out.png"
297
+ imwrite(restored_img, str(save_path))
298
+ return save_path
299
+ else:
300
+ return restored_img
plugins/codeformer_face_helper_cv2.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
2
+
3
+ import numpy as np
4
+ from codeformer.basicsr.utils.misc import get_device
5
+
6
+ class FaceRestoreHelperOptimized(FaceRestoreHelper):
7
+ def __init__(
8
+ self,
9
+ upscale_factor,
10
+ face_size=512,
11
+ crop_ratio=(1, 1),
12
+ det_model="retinaface_resnet50",
13
+ save_ext="png",
14
+ template_3points=False,
15
+ pad_blur=False,
16
+ use_parse=False,
17
+ device=None,
18
+ ):
19
+ self.template_3points = template_3points # improve robustness
20
+ self.upscale_factor = int(upscale_factor)
21
+ # the cropped face ratio based on the square face
22
+ self.crop_ratio = crop_ratio # (h, w)
23
+ assert self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1, "crop ration only supports >=1"
24
+ self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
25
+ self.det_model = det_model
26
+
27
+ if self.det_model == "dlib":
28
+ # standard 5 landmarks for FFHQ faces with 1024 x 1024
29
+ self.face_template = np.array(
30
+ [
31
+ [686.77227723, 488.62376238],
32
+ [586.77227723, 493.59405941],
33
+ [337.91089109, 488.38613861],
34
+ [437.95049505, 493.51485149],
35
+ [513.58415842, 678.5049505],
36
+ ]
37
+ )
38
+ self.face_template = self.face_template / (1024 // face_size)
39
+ elif self.template_3points:
40
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
41
+ else:
42
+ # standard 5 landmarks for FFHQ faces with 512 x 512
43
+ # facexlib
44
+ self.face_template = np.array(
45
+ [
46
+ [192.98138, 239.94708],
47
+ [318.90277, 240.1936],
48
+ [256.63416, 314.01935],
49
+ [201.26117, 371.41043],
50
+ [313.08905, 371.15118],
51
+ ]
52
+ )
53
+
54
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
55
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
56
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
57
+
58
+ self.face_template = self.face_template * (face_size / 512.0)
59
+ if self.crop_ratio[0] > 1:
60
+ self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
61
+ if self.crop_ratio[1] > 1:
62
+ self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
63
+ self.save_ext = save_ext
64
+ self.pad_blur = pad_blur
65
+ if self.pad_blur is True:
66
+ self.template_3points = False
67
+
68
+ self.all_landmarks_5 = []
69
+ self.det_faces = []
70
+ self.affine_matrices = []
71
+ self.inverse_affine_matrices = []
72
+ self.cropped_faces = []
73
+ self.restored_faces = []
74
+ self.pad_input_imgs = []
75
+
76
+ if device is None:
77
+ # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
78
+ self.device = get_device()
79
+ else:
80
+ self.device = device
81
+
82
+ # init face detection model
83
+ # if self.det_model == "dlib":
84
+ # self.face_detector, self.shape_predictor_5 = self.init_dlib(
85
+ # dlib_model_url["face_detector"], dlib_model_url["shape_predictor_5"]
86
+ # )
87
+ # else:
88
+ # self.face_detector = init_detection_model(det_model, half=False, device=self.device)
89
+
90
+ # init face parsing model
91
+ self.use_parse = use_parse
92
+ #self.face_parse = init_parsing_model(model_name="parsenet", device=self.device)
93
+
94
+ # MUST set face_detector and face_parse!!!
plugins/core.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core plugin
2
+ # author: Vladislav Janvarev
3
+
4
+ from chain_img_processor import ChainImgProcessor
5
+
6
+ # start function
7
+ def start(core:ChainImgProcessor):
8
+ manifest = {
9
+ "name": "Core plugin",
10
+ "version": "2.0",
11
+
12
+ "default_options": {
13
+ "default_chain": "faceswap", # default chain to run
14
+ "init_on_start": "faceswap,txt2clip,gfpgan,codeformer", # init these processors on start
15
+ "is_demo_row_render": False,
16
+ },
17
+
18
+ }
19
+ return manifest
20
+
21
+ def start_with_options(core:ChainImgProcessor, manifest:dict):
22
+ options = manifest["options"]
23
+
24
+ core.default_chain = options["default_chain"]
25
+ core.init_on_start = options["init_on_start"]
26
+
27
+ core.is_demo_row_render= options["is_demo_row_render"]
28
+
29
+ return manifest
plugins/core_video.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core plugin
2
+ # author: Vladislav Janvarev
3
+
4
+ from chain_img_processor import ChainImgProcessor, ChainVideoProcessor
5
+
6
+ # start function
7
+ def start(core:ChainImgProcessor):
8
+ manifest = {
9
+ "name": "Core video plugin",
10
+ "version": "2.0",
11
+
12
+ "default_options": {
13
+ "video_save_codec": "libx264", # default codec to save
14
+ "video_save_crf": 14, # default crf to save
15
+ },
16
+
17
+ }
18
+ return manifest
19
+
20
+ def start_with_options(core:ChainVideoProcessor, manifest:dict):
21
+ options = manifest["options"]
22
+
23
+ core.video_save_codec = options["video_save_codec"]
24
+ core.video_save_crf = options["video_save_crf"]
25
+
26
+ return manifest
plugins/plugin_codeformer.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Codeformer enchance plugin
2
+ # author: Vladislav Janvarev
3
+
4
+ # CountFloyd 20230717, extended to blend original/destination images
5
+
6
+ from chain_img_processor import ChainImgProcessor, ChainImgPlugin
7
+ import os
8
+ from PIL import Image
9
+ from numpy import asarray
10
+
11
+ modname = os.path.basename(__file__)[:-3] # calculating modname
12
+
13
+ # start function
14
+ def start(core:ChainImgProcessor):
15
+ manifest = { # plugin settings
16
+ "name": "Codeformer", # name
17
+ "version": "3.0", # version
18
+
19
+ "default_options": {
20
+ "background_enhance": True, #
21
+ "face_upsample": True, #
22
+ "upscale": 2, #
23
+ "codeformer_fidelity": 0.8,
24
+ "skip_if_no_face":False,
25
+
26
+ },
27
+
28
+ "img_processor": {
29
+ "codeformer": PluginCodeformer # 1 function - init, 2 - process
30
+ }
31
+ }
32
+ return manifest
33
+
34
+ def start_with_options(core:ChainImgProcessor, manifest:dict):
35
+ pass
36
+
37
+ class PluginCodeformer(ChainImgPlugin):
38
+ def init_plugin(self):
39
+ import plugins.codeformer_app_cv2
40
+ pass
41
+
42
+ def process(self, img, params:dict):
43
+ import copy
44
+
45
+ # params can be used to transfer some img info to next processors
46
+ from plugins.codeformer_app_cv2 import inference_app
47
+ options = self.core.plugin_options(modname)
48
+
49
+ if "face_detected" in params:
50
+ if not params["face_detected"]:
51
+ return img
52
+
53
+ # don't touch original
54
+ temp_frame = copy.copy(img)
55
+ if "processed_faces" in params:
56
+ for face in params["processed_faces"]:
57
+ start_x, start_y, end_x, end_y = map(int, face['bbox'])
58
+ padding_x = int((end_x - start_x) * 0.5)
59
+ padding_y = int((end_y - start_y) * 0.5)
60
+ start_x = max(0, start_x - padding_x)
61
+ start_y = max(0, start_y - padding_y)
62
+ end_x = max(0, end_x + padding_x)
63
+ end_y = max(0, end_y + padding_y)
64
+ temp_face = temp_frame[start_y:end_y, start_x:end_x]
65
+ if temp_face.size:
66
+ temp_face = inference_app(temp_face, options.get("background_enhance"), options.get("face_upsample"),
67
+ options.get("upscale"), options.get("codeformer_fidelity"),
68
+ options.get("skip_if_no_face"))
69
+ temp_frame[start_y:end_y, start_x:end_x] = temp_face
70
+ else:
71
+ temp_frame = inference_app(temp_frame, options.get("background_enhance"), options.get("face_upsample"),
72
+ options.get("upscale"), options.get("codeformer_fidelity"),
73
+ options.get("skip_if_no_face"))
74
+
75
+
76
+
77
+ if not "blend_ratio" in params:
78
+ return temp_frame
79
+
80
+
81
+ temp_frame = Image.blend(Image.fromarray(img), Image.fromarray(temp_frame), params["blend_ratio"])
82
+ return asarray(temp_frame)
83
+
plugins/plugin_dmdnet.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from chain_img_processor import ChainImgProcessor, ChainImgPlugin
2
+ import os
3
+ from PIL import Image
4
+ from numpy import asarray
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import scipy.io as sio
10
+ import numpy as np
11
+ import torch.nn.utils.spectral_norm as SpectralNorm
12
+ from torchvision.ops import roi_align
13
+
14
+ from math import sqrt
15
+ import os
16
+
17
+ import cv2
18
+ import os
19
+ from torchvision.transforms.functional import normalize
20
+ import copy
21
+ import threading
22
+
23
+ modname = os.path.basename(__file__)[:-3] # calculating modname
24
+
25
+ oDMDNet = None
26
+ device = None
27
+
28
+ THREAD_LOCK_DMDNET = threading.Lock()
29
+
30
+
31
+
32
+ # start function
33
+ def start(core:ChainImgProcessor):
34
+ manifest = { # plugin settings
35
+ "name": "DMDNet", # name
36
+ "version": "1.0", # version
37
+
38
+ "default_options": {},
39
+ "img_processor": {
40
+ "dmdnet": DMDNETPlugin
41
+ }
42
+ }
43
+ return manifest
44
+
45
+ def start_with_options(core:ChainImgProcessor, manifest:dict):
46
+ pass
47
+
48
+
49
+ class DMDNETPlugin(ChainImgPlugin):
50
+
51
+ # https://stackoverflow.com/a/67174339
52
+ def landmarks106_to_68(self, pt106):
53
+ map106to68=[1,10,12,14,16,3,5,7,0,23,21,19,32,30,28,26,17,
54
+ 43,48,49,51,50,
55
+ 102,103,104,105,101,
56
+ 72,73,74,86,78,79,80,85,84,
57
+ 35,41,42,39,37,36,
58
+ 89,95,96,93,91,90,
59
+ 52,64,63,71,67,68,61,58,59,53,56,55,65,66,62,70,69,57,60,54
60
+ ]
61
+
62
+ pt68 = []
63
+ for i in range(68):
64
+ index = map106to68[i]
65
+ pt68.append(pt106[index])
66
+ return pt68
67
+
68
+ def init_plugin(self):
69
+ global create
70
+
71
+ if oDMDNet == None:
72
+ create(self.device)
73
+
74
+
75
+ def process(self, frame, params:dict):
76
+ if "face_detected" in params:
77
+ if not params["face_detected"]:
78
+ return frame
79
+
80
+ temp_frame = copy.copy(frame)
81
+ if "processed_faces" in params:
82
+ for face in params["processed_faces"]:
83
+ start_x, start_y, end_x, end_y = map(int, face['bbox'])
84
+ # padding_x = int((end_x - start_x) * 0.5)
85
+ # padding_y = int((end_y - start_y) * 0.5)
86
+ padding_x = 0
87
+ padding_y = 0
88
+
89
+ start_x = max(0, start_x - padding_x)
90
+ start_y = max(0, start_y - padding_y)
91
+ end_x = max(0, end_x + padding_x)
92
+ end_y = max(0, end_y + padding_y)
93
+ temp_face = temp_frame[start_y:end_y, start_x:end_x]
94
+ if temp_face.size:
95
+ temp_face = self.enhance_face(temp_face, face)
96
+ temp_face = cv2.resize(temp_face, (end_x - start_x,end_y - start_y), interpolation = cv2.INTER_LANCZOS4)
97
+ temp_frame[start_y:end_y, start_x:end_x] = temp_face
98
+
99
+ temp_frame = Image.blend(Image.fromarray(frame), Image.fromarray(temp_frame), params["blend_ratio"])
100
+ return asarray(temp_frame)
101
+
102
+
103
+ def enhance_face(self, clip, face):
104
+ global device
105
+
106
+ lm106 = face.landmark_2d_106
107
+ lq_landmarks = asarray(self.landmarks106_to_68(lm106))
108
+ lq = read_img_tensor(clip, False)
109
+
110
+ LQLocs = get_component_location(lq_landmarks)
111
+ # generic
112
+ SpMem256Para, SpMem128Para, SpMem64Para = None, None, None
113
+
114
+ with torch.no_grad():
115
+ with THREAD_LOCK_DMDNET:
116
+ try:
117
+ GenericResult, SpecificResult = oDMDNet(lq = lq.to(device), loc = LQLocs.unsqueeze(0), sp_256 = SpMem256Para, sp_128 = SpMem128Para, sp_64 = SpMem64Para)
118
+ except Exception as e:
119
+ print(f'Error {e} there may be something wrong with the detected component locations.')
120
+ return clip
121
+ save_generic = GenericResult * 0.5 + 0.5
122
+ save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
123
+ save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
124
+
125
+ check_lq = lq * 0.5 + 0.5
126
+ check_lq = check_lq.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
127
+ check_lq = np.clip(check_lq.float().cpu().numpy(), 0, 1) * 255.0
128
+ enhanced_img = np.hstack((check_lq, save_generic))
129
+ temp_frame = save_generic.astype("uint8")
130
+ # temp_frame = save_generic.astype("uint8")
131
+ return temp_frame
132
+
133
+
134
+ def create(devicename):
135
+ global device, oDMDNet
136
+
137
+ test = "cuda" if torch.cuda.is_available() else "cpu"
138
+ device = torch.device(devicename)
139
+ oDMDNet = DMDNet().to(device)
140
+ weights = torch.load('./models/DMDNet.pth')
141
+ oDMDNet.load_state_dict(weights, strict=True)
142
+
143
+ oDMDNet.eval()
144
+ num_params = 0
145
+ for param in oDMDNet.parameters():
146
+ num_params += param.numel()
147
+
148
+ # print('{:>8s} : {}'.format('Using device', device))
149
+ # print('{:>8s} : {:.2f}M'.format('Model params', num_params/1e6))
150
+
151
+
152
+
153
+ def read_img_tensor(Img=None, return_landmark=True): #rgb -1~1
154
+ # Img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # BGR or G
155
+ if Img.ndim == 2:
156
+ Img = cv2.cvtColor(Img, cv2.COLOR_GRAY2RGB) # GGG
157
+ else:
158
+ Img = cv2.cvtColor(Img, cv2.COLOR_BGR2RGB) # RGB
159
+
160
+ if Img.shape[0] < 512 or Img.shape[1] < 512:
161
+ Img = cv2.resize(Img, (512,512), interpolation = cv2.INTER_AREA)
162
+ # ImgForLands = Img.copy()
163
+
164
+ Img = Img.transpose((2, 0, 1))/255.0
165
+ Img = torch.from_numpy(Img).float()
166
+ normalize(Img, [0.5,0.5,0.5], [0.5,0.5,0.5], inplace=True)
167
+ ImgTensor = Img.unsqueeze(0)
168
+ return ImgTensor
169
+
170
+
171
+ def get_component_location(Landmarks, re_read=False):
172
+ if re_read:
173
+ ReadLandmark = []
174
+ with open(Landmarks,'r') as f:
175
+ for line in f:
176
+ tmp = [float(i) for i in line.split(' ') if i != '\n']
177
+ ReadLandmark.append(tmp)
178
+ ReadLandmark = np.array(ReadLandmark) #
179
+ Landmarks = np.reshape(ReadLandmark, [-1, 2]) # 68*2
180
+ Map_LE_B = list(np.hstack((range(17,22), range(36,42))))
181
+ Map_RE_B = list(np.hstack((range(22,27), range(42,48))))
182
+ Map_LE = list(range(36,42))
183
+ Map_RE = list(range(42,48))
184
+ Map_NO = list(range(29,36))
185
+ Map_MO = list(range(48,68))
186
+
187
+ Landmarks[Landmarks>504]=504
188
+ Landmarks[Landmarks<8]=8
189
+
190
+ #left eye
191
+ Mean_LE = np.mean(Landmarks[Map_LE],0)
192
+ L_LE1 = Mean_LE[1] - np.min(Landmarks[Map_LE_B,1])
193
+ L_LE1 = L_LE1 * 1.3
194
+ L_LE2 = L_LE1 / 1.9
195
+ L_LE_xy = L_LE1 + L_LE2
196
+ L_LE_lt = [L_LE_xy/2, L_LE1]
197
+ L_LE_rb = [L_LE_xy/2, L_LE2]
198
+ Location_LE = np.hstack((Mean_LE - L_LE_lt + 1, Mean_LE + L_LE_rb)).astype(int)
199
+
200
+ #right eye
201
+ Mean_RE = np.mean(Landmarks[Map_RE],0)
202
+ L_RE1 = Mean_RE[1] - np.min(Landmarks[Map_RE_B,1])
203
+ L_RE1 = L_RE1 * 1.3
204
+ L_RE2 = L_RE1 / 1.9
205
+ L_RE_xy = L_RE1 + L_RE2
206
+ L_RE_lt = [L_RE_xy/2, L_RE1]
207
+ L_RE_rb = [L_RE_xy/2, L_RE2]
208
+ Location_RE = np.hstack((Mean_RE - L_RE_lt + 1, Mean_RE + L_RE_rb)).astype(int)
209
+
210
+ #nose
211
+ Mean_NO = np.mean(Landmarks[Map_NO],0)
212
+ L_NO1 =( np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]])) * 1.25
213
+ L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1
214
+ L_NO_xy = L_NO1 * 2
215
+ L_NO_lt = [L_NO_xy/2, L_NO_xy - L_NO2]
216
+ L_NO_rb = [L_NO_xy/2, L_NO2]
217
+ Location_NO = np.hstack((Mean_NO - L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int)
218
+
219
+ #mouth
220
+ Mean_MO = np.mean(Landmarks[Map_MO],0)
221
+ L_MO = np.max((np.max(np.max(Landmarks[Map_MO],0) - np.min(Landmarks[Map_MO],0))/2,16)) * 1.1
222
+ MO_O = Mean_MO - L_MO + 1
223
+ MO_T = Mean_MO + L_MO
224
+ MO_T[MO_T>510]=510
225
+ Location_MO = np.hstack((MO_O, MO_T)).astype(int)
226
+ return torch.cat([torch.FloatTensor(Location_LE).unsqueeze(0), torch.FloatTensor(Location_RE).unsqueeze(0), torch.FloatTensor(Location_NO).unsqueeze(0), torch.FloatTensor(Location_MO).unsqueeze(0)], dim=0)
227
+
228
+
229
+
230
+
231
+ def calc_mean_std_4D(feat, eps=1e-5):
232
+ # eps is a small value added to the variance to avoid divide-by-zero.
233
+ size = feat.size()
234
+ assert (len(size) == 4)
235
+ N, C = size[:2]
236
+ feat_var = feat.view(N, C, -1).var(dim=2) + eps
237
+ feat_std = feat_var.sqrt().view(N, C, 1, 1)
238
+ feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
239
+ return feat_mean, feat_std
240
+
241
+ def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degradate feature
242
+ size = content_feat.size()
243
+ style_mean, style_std = calc_mean_std_4D(style_feat)
244
+ content_mean, content_std = calc_mean_std_4D(content_feat)
245
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
246
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
247
+
248
+
249
+ def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True):
250
+ return nn.Sequential(
251
+ SpectralNorm(conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
252
+ nn.LeakyReLU(0.2),
253
+ SpectralNorm(conv_layer(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
254
+ )
255
+
256
+
257
+ class MSDilateBlock(nn.Module):
258
+ def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
259
+ super(MSDilateBlock, self).__init__()
260
+ self.conv1 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[0], bias=bias)
261
+ self.conv2 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[1], bias=bias)
262
+ self.conv3 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[2], bias=bias)
263
+ self.conv4 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[3], bias=bias)
264
+ self.convi = SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias))
265
+ def forward(self, x):
266
+ conv1 = self.conv1(x)
267
+ conv2 = self.conv2(x)
268
+ conv3 = self.conv3(x)
269
+ conv4 = self.conv4(x)
270
+ cat = torch.cat([conv1, conv2, conv3, conv4], 1)
271
+ out = self.convi(cat) + x
272
+ return out
273
+
274
+
275
+ class AdaptiveInstanceNorm(nn.Module):
276
+ def __init__(self, in_channel):
277
+ super().__init__()
278
+ self.norm = nn.InstanceNorm2d(in_channel)
279
+
280
+ def forward(self, input, style):
281
+ style_mean, style_std = calc_mean_std_4D(style)
282
+ out = self.norm(input)
283
+ size = input.size()
284
+ out = style_std.expand(size) * out + style_mean.expand(size)
285
+ return out
286
+
287
+ class NoiseInjection(nn.Module):
288
+ def __init__(self, channel):
289
+ super().__init__()
290
+ self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
291
+ def forward(self, image, noise):
292
+ if noise is None:
293
+ b, c, h, w = image.shape
294
+ noise = image.new_empty(b, 1, h, w).normal_()
295
+ return image + self.weight * noise
296
+
297
+ class StyledUpBlock(nn.Module):
298
+ def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False, noise_inject=False):
299
+ super().__init__()
300
+
301
+ self.noise_inject = noise_inject
302
+ if upsample:
303
+ self.conv1 = nn.Sequential(
304
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
305
+ SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
306
+ nn.LeakyReLU(0.2),
307
+ )
308
+ else:
309
+ self.conv1 = nn.Sequential(
310
+ SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
311
+ nn.LeakyReLU(0.2),
312
+ SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
313
+ )
314
+ self.convup = nn.Sequential(
315
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
316
+ SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
317
+ nn.LeakyReLU(0.2),
318
+ SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
319
+ )
320
+ if self.noise_inject:
321
+ self.noise1 = NoiseInjection(out_channel)
322
+
323
+ self.lrelu1 = nn.LeakyReLU(0.2)
324
+
325
+ self.ScaleModel1 = nn.Sequential(
326
+ SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
327
+ nn.LeakyReLU(0.2),
328
+ SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))
329
+ )
330
+ self.ShiftModel1 = nn.Sequential(
331
+ SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
332
+ nn.LeakyReLU(0.2),
333
+ SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
334
+ )
335
+
336
+ def forward(self, input, style):
337
+ out = self.conv1(input)
338
+ out = self.lrelu1(out)
339
+ Shift1 = self.ShiftModel1(style)
340
+ Scale1 = self.ScaleModel1(style)
341
+ out = out * Scale1 + Shift1
342
+ if self.noise_inject:
343
+ out = self.noise1(out, noise=None)
344
+ outup = self.convup(out)
345
+ return outup
346
+
347
+
348
+ ####################################################################
349
+ ###############Face Dictionary Generator
350
+ ####################################################################
351
+ def AttentionBlock(in_channel):
352
+ return nn.Sequential(
353
+ SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
354
+ nn.LeakyReLU(0.2),
355
+ SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
356
+ )
357
+
358
+ class DilateResBlock(nn.Module):
359
+ def __init__(self, dim, dilation=[5,3] ):
360
+ super(DilateResBlock, self).__init__()
361
+ self.Res = nn.Sequential(
362
+ SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[0], dilation[0])),
363
+ nn.LeakyReLU(0.2),
364
+ SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[1], dilation[1])),
365
+ )
366
+ def forward(self, x):
367
+ out = x + self.Res(x)
368
+ return out
369
+
370
+
371
+ class KeyValue(nn.Module):
372
+ def __init__(self, indim, keydim, valdim):
373
+ super(KeyValue, self).__init__()
374
+ self.Key = nn.Sequential(
375
+ SpectralNorm(nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)),
376
+ nn.LeakyReLU(0.2),
377
+ SpectralNorm(nn.Conv2d(keydim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)),
378
+ )
379
+ self.Value = nn.Sequential(
380
+ SpectralNorm(nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)),
381
+ nn.LeakyReLU(0.2),
382
+ SpectralNorm(nn.Conv2d(valdim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)),
383
+ )
384
+ def forward(self, x):
385
+ return self.Key(x), self.Value(x)
386
+
387
+ class MaskAttention(nn.Module):
388
+ def __init__(self, indim):
389
+ super(MaskAttention, self).__init__()
390
+ self.conv1 = nn.Sequential(
391
+ SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
392
+ nn.LeakyReLU(0.2),
393
+ SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
394
+ )
395
+ self.conv2 = nn.Sequential(
396
+ SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
397
+ nn.LeakyReLU(0.2),
398
+ SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
399
+ )
400
+ self.conv3 = nn.Sequential(
401
+ SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
402
+ nn.LeakyReLU(0.2),
403
+ SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
404
+ )
405
+ self.convCat = nn.Sequential(
406
+ SpectralNorm(nn.Conv2d(indim//3 * 3, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
407
+ nn.LeakyReLU(0.2),
408
+ SpectralNorm(nn.Conv2d(indim, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
409
+ )
410
+ def forward(self, x, y, z):
411
+ c1 = self.conv1(x)
412
+ c2 = self.conv2(y)
413
+ c3 = self.conv3(z)
414
+ return self.convCat(torch.cat([c1,c2,c3], dim=1))
415
+
416
+ class Query(nn.Module):
417
+ def __init__(self, indim, quedim):
418
+ super(Query, self).__init__()
419
+ self.Query = nn.Sequential(
420
+ SpectralNorm(nn.Conv2d(indim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)),
421
+ nn.LeakyReLU(0.2),
422
+ SpectralNorm(nn.Conv2d(quedim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)),
423
+ )
424
+ def forward(self, x):
425
+ return self.Query(x)
426
+
427
+ def roi_align_self(input, location, target_size):
428
+ return torch.cat([F.interpolate(input[i:i+1,:,location[i,1]:location[i,3],location[i,0]:location[i,2]],(target_size,target_size),mode='bilinear',align_corners=False) for i in range(input.size(0))],0)
429
+
430
+ class FeatureExtractor(nn.Module):
431
+ def __init__(self, ngf = 64, key_scale = 4):#
432
+ super().__init__()
433
+
434
+ self.key_scale = 4
435
+ self.part_sizes = np.array([80,80,50,110]) #
436
+ self.feature_sizes = np.array([256,128,64]) #
437
+
438
+ self.conv1 = nn.Sequential(
439
+ SpectralNorm(nn.Conv2d(3, ngf, 3, 2, 1)),
440
+ nn.LeakyReLU(0.2),
441
+ SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
442
+ )
443
+ self.conv2 = nn.Sequential(
444
+ SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
445
+ nn.LeakyReLU(0.2),
446
+ SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1))
447
+ )
448
+ self.res1 = DilateResBlock(ngf, [5,3])
449
+ self.res2 = DilateResBlock(ngf, [5,3])
450
+
451
+
452
+ self.conv3 = nn.Sequential(
453
+ SpectralNorm(nn.Conv2d(ngf, ngf*2, 3, 2, 1)),
454
+ nn.LeakyReLU(0.2),
455
+ SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
456
+ )
457
+ self.conv4 = nn.Sequential(
458
+ SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
459
+ nn.LeakyReLU(0.2),
460
+ SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1))
461
+ )
462
+ self.res3 = DilateResBlock(ngf*2, [3,1])
463
+ self.res4 = DilateResBlock(ngf*2, [3,1])
464
+
465
+ self.conv5 = nn.Sequential(
466
+ SpectralNorm(nn.Conv2d(ngf*2, ngf*4, 3, 2, 1)),
467
+ nn.LeakyReLU(0.2),
468
+ SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
469
+ )
470
+ self.conv6 = nn.Sequential(
471
+ SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
472
+ nn.LeakyReLU(0.2),
473
+ SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1))
474
+ )
475
+ self.res5 = DilateResBlock(ngf*4, [1,1])
476
+ self.res6 = DilateResBlock(ngf*4, [1,1])
477
+
478
+ self.LE_256_Q = Query(ngf, ngf // self.key_scale)
479
+ self.RE_256_Q = Query(ngf, ngf // self.key_scale)
480
+ self.MO_256_Q = Query(ngf, ngf // self.key_scale)
481
+ self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
482
+ self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
483
+ self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
484
+ self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
485
+ self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
486
+ self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
487
+
488
+
489
+ def forward(self, img, locs):
490
+ le_location = locs[:,0,:].int().cpu().numpy()
491
+ re_location = locs[:,1,:].int().cpu().numpy()
492
+ no_location = locs[:,2,:].int().cpu().numpy()
493
+ mo_location = locs[:,3,:].int().cpu().numpy()
494
+
495
+
496
+ f1_0 = self.conv1(img)
497
+ f1_1 = self.res1(f1_0)
498
+ f2_0 = self.conv2(f1_1)
499
+ f2_1 = self.res2(f2_0)
500
+
501
+ f3_0 = self.conv3(f2_1)
502
+ f3_1 = self.res3(f3_0)
503
+ f4_0 = self.conv4(f3_1)
504
+ f4_1 = self.res4(f4_0)
505
+
506
+ f5_0 = self.conv5(f4_1)
507
+ f5_1 = self.res5(f5_0)
508
+ f6_0 = self.conv6(f5_1)
509
+ f6_1 = self.res6(f6_0)
510
+
511
+
512
+ ####ROI Align
513
+ le_part_256 = roi_align_self(f2_1.clone(), le_location//2, self.part_sizes[0]//2)
514
+ re_part_256 = roi_align_self(f2_1.clone(), re_location//2, self.part_sizes[1]//2)
515
+ mo_part_256 = roi_align_self(f2_1.clone(), mo_location//2, self.part_sizes[3]//2)
516
+
517
+ le_part_128 = roi_align_self(f4_1.clone(), le_location//4, self.part_sizes[0]//4)
518
+ re_part_128 = roi_align_self(f4_1.clone(), re_location//4, self.part_sizes[1]//4)
519
+ mo_part_128 = roi_align_self(f4_1.clone(), mo_location//4, self.part_sizes[3]//4)
520
+
521
+ le_part_64 = roi_align_self(f6_1.clone(), le_location//8, self.part_sizes[0]//8)
522
+ re_part_64 = roi_align_self(f6_1.clone(), re_location//8, self.part_sizes[1]//8)
523
+ mo_part_64 = roi_align_self(f6_1.clone(), mo_location//8, self.part_sizes[3]//8)
524
+
525
+
526
+ le_256_q = self.LE_256_Q(le_part_256)
527
+ re_256_q = self.RE_256_Q(re_part_256)
528
+ mo_256_q = self.MO_256_Q(mo_part_256)
529
+
530
+ le_128_q = self.LE_128_Q(le_part_128)
531
+ re_128_q = self.RE_128_Q(re_part_128)
532
+ mo_128_q = self.MO_128_Q(mo_part_128)
533
+
534
+ le_64_q = self.LE_64_Q(le_part_64)
535
+ re_64_q = self.RE_64_Q(re_part_64)
536
+ mo_64_q = self.MO_64_Q(mo_part_64)
537
+
538
+ return {'f256': f2_1, 'f128': f4_1, 'f64': f6_1,\
539
+ 'le256': le_part_256, 're256': re_part_256, 'mo256': mo_part_256, \
540
+ 'le128': le_part_128, 're128': re_part_128, 'mo128': mo_part_128, \
541
+ 'le64': le_part_64, 're64': re_part_64, 'mo64': mo_part_64, \
542
+ 'le_256_q': le_256_q, 're_256_q': re_256_q, 'mo_256_q': mo_256_q,\
543
+ 'le_128_q': le_128_q, 're_128_q': re_128_q, 'mo_128_q': mo_128_q,\
544
+ 'le_64_q': le_64_q, 're_64_q': re_64_q, 'mo_64_q': mo_64_q}
545
+
546
+
547
+ class DMDNet(nn.Module):
548
+ def __init__(self, ngf = 64, banks_num = 128):
549
+ super().__init__()
550
+ self.part_sizes = np.array([80,80,50,110]) # size for 512
551
+ self.feature_sizes = np.array([256,128,64]) # size for 512
552
+
553
+ self.banks_num = banks_num
554
+ self.key_scale = 4
555
+
556
+ self.E_lq = FeatureExtractor(key_scale = self.key_scale)
557
+ self.E_hq = FeatureExtractor(key_scale = self.key_scale)
558
+
559
+ self.LE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
560
+ self.RE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
561
+ self.MO_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
562
+
563
+ self.LE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
564
+ self.RE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
565
+ self.MO_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
566
+
567
+ self.LE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
568
+ self.RE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
569
+ self.MO_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
570
+
571
+
572
+ self.LE_256_Attention = AttentionBlock(64)
573
+ self.RE_256_Attention = AttentionBlock(64)
574
+ self.MO_256_Attention = AttentionBlock(64)
575
+
576
+ self.LE_128_Attention = AttentionBlock(128)
577
+ self.RE_128_Attention = AttentionBlock(128)
578
+ self.MO_128_Attention = AttentionBlock(128)
579
+
580
+ self.LE_64_Attention = AttentionBlock(256)
581
+ self.RE_64_Attention = AttentionBlock(256)
582
+ self.MO_64_Attention = AttentionBlock(256)
583
+
584
+ self.LE_256_Mask = MaskAttention(64)
585
+ self.RE_256_Mask = MaskAttention(64)
586
+ self.MO_256_Mask = MaskAttention(64)
587
+
588
+ self.LE_128_Mask = MaskAttention(128)
589
+ self.RE_128_Mask = MaskAttention(128)
590
+ self.MO_128_Mask = MaskAttention(128)
591
+
592
+ self.LE_64_Mask = MaskAttention(256)
593
+ self.RE_64_Mask = MaskAttention(256)
594
+ self.MO_64_Mask = MaskAttention(256)
595
+
596
+ self.MSDilate = MSDilateBlock(ngf*4, dilation = [4,3,2,1])
597
+
598
+ self.up1 = StyledUpBlock(ngf*4, ngf*2, noise_inject=False) #
599
+ self.up2 = StyledUpBlock(ngf*2, ngf, noise_inject=False) #
600
+ self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False) #
601
+ self.up4 = nn.Sequential(
602
+ SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
603
+ nn.LeakyReLU(0.2),
604
+ UpResBlock(ngf),
605
+ UpResBlock(ngf),
606
+ SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)),
607
+ nn.Tanh()
608
+ )
609
+
610
+ # define generic memory, revise register_buffer to register_parameter for backward update
611
+ self.register_buffer('le_256_mem_key', torch.randn(128,16,40,40))
612
+ self.register_buffer('re_256_mem_key', torch.randn(128,16,40,40))
613
+ self.register_buffer('mo_256_mem_key', torch.randn(128,16,55,55))
614
+ self.register_buffer('le_256_mem_value', torch.randn(128,64,40,40))
615
+ self.register_buffer('re_256_mem_value', torch.randn(128,64,40,40))
616
+ self.register_buffer('mo_256_mem_value', torch.randn(128,64,55,55))
617
+
618
+
619
+ self.register_buffer('le_128_mem_key', torch.randn(128,32,20,20))
620
+ self.register_buffer('re_128_mem_key', torch.randn(128,32,20,20))
621
+ self.register_buffer('mo_128_mem_key', torch.randn(128,32,27,27))
622
+ self.register_buffer('le_128_mem_value', torch.randn(128,128,20,20))
623
+ self.register_buffer('re_128_mem_value', torch.randn(128,128,20,20))
624
+ self.register_buffer('mo_128_mem_value', torch.randn(128,128,27,27))
625
+
626
+ self.register_buffer('le_64_mem_key', torch.randn(128,64,10,10))
627
+ self.register_buffer('re_64_mem_key', torch.randn(128,64,10,10))
628
+ self.register_buffer('mo_64_mem_key', torch.randn(128,64,13,13))
629
+ self.register_buffer('le_64_mem_value', torch.randn(128,256,10,10))
630
+ self.register_buffer('re_64_mem_value', torch.randn(128,256,10,10))
631
+ self.register_buffer('mo_64_mem_value', torch.randn(128,256,13,13))
632
+
633
+
634
+ def readMem(self, k, v, q):
635
+ sim = F.conv2d(q, k)
636
+ score = F.softmax(sim/sqrt(sim.size(1)), dim=1) #B * S * 1 * 1 6*128
637
+ sb,sn,sw,sh = score.size()
638
+ s_m = score.view(sb, -1).unsqueeze(1)#2*1*M
639
+ vb,vn,vw,vh = v.size()
640
+ v_in = v.view(vb, -1).repeat(sb,1,1)#2*M*(c*w*h)
641
+ mem_out = torch.bmm(s_m, v_in).squeeze(1).view(sb, vn, vw,vh)
642
+ max_inds = torch.argmax(score, dim=1).squeeze()
643
+ return mem_out, max_inds
644
+
645
+
646
+ def memorize(self, img, locs):
647
+ fs = self.E_hq(img, locs)
648
+ LE256_key, LE256_value = self.LE_256_KV(fs['le256'])
649
+ RE256_key, RE256_value = self.RE_256_KV(fs['re256'])
650
+ MO256_key, MO256_value = self.MO_256_KV(fs['mo256'])
651
+
652
+ LE128_key, LE128_value = self.LE_128_KV(fs['le128'])
653
+ RE128_key, RE128_value = self.RE_128_KV(fs['re128'])
654
+ MO128_key, MO128_value = self.MO_128_KV(fs['mo128'])
655
+
656
+ LE64_key, LE64_value = self.LE_64_KV(fs['le64'])
657
+ RE64_key, RE64_value = self.RE_64_KV(fs['re64'])
658
+ MO64_key, MO64_value = self.MO_64_KV(fs['mo64'])
659
+
660
+ Mem256 = {'LE256Key': LE256_key, 'LE256Value': LE256_value, 'RE256Key': RE256_key, 'RE256Value': RE256_value,'MO256Key': MO256_key, 'MO256Value': MO256_value}
661
+ Mem128 = {'LE128Key': LE128_key, 'LE128Value': LE128_value, 'RE128Key': RE128_key, 'RE128Value': RE128_value,'MO128Key': MO128_key, 'MO128Value': MO128_value}
662
+ Mem64 = {'LE64Key': LE64_key, 'LE64Value': LE64_value, 'RE64Key': RE64_key, 'RE64Value': RE64_value,'MO64Key': MO64_key, 'MO64Value': MO64_value}
663
+
664
+ FS256 = {'LE256F':fs['le256'], 'RE256F':fs['re256'], 'MO256F':fs['mo256']}
665
+ FS128 = {'LE128F':fs['le128'], 'RE128F':fs['re128'], 'MO128F':fs['mo128']}
666
+ FS64 = {'LE64F':fs['le64'], 'RE64F':fs['re64'], 'MO64F':fs['mo64']}
667
+
668
+ return Mem256, Mem128, Mem64
669
+
670
+ def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None):
671
+ le_256_q = fs_in['le_256_q']
672
+ re_256_q = fs_in['re_256_q']
673
+ mo_256_q = fs_in['mo_256_q']
674
+
675
+ le_128_q = fs_in['le_128_q']
676
+ re_128_q = fs_in['re_128_q']
677
+ mo_128_q = fs_in['mo_128_q']
678
+
679
+ le_64_q = fs_in['le_64_q']
680
+ re_64_q = fs_in['re_64_q']
681
+ mo_64_q = fs_in['mo_64_q']
682
+
683
+
684
+ ####for 256
685
+ le_256_mem_g, le_256_inds = self.readMem(self.le_256_mem_key, self.le_256_mem_value, le_256_q)
686
+ re_256_mem_g, re_256_inds = self.readMem(self.re_256_mem_key, self.re_256_mem_value, re_256_q)
687
+ mo_256_mem_g, mo_256_inds = self.readMem(self.mo_256_mem_key, self.mo_256_mem_value, mo_256_q)
688
+
689
+ le_128_mem_g, le_128_inds = self.readMem(self.le_128_mem_key, self.le_128_mem_value, le_128_q)
690
+ re_128_mem_g, re_128_inds = self.readMem(self.re_128_mem_key, self.re_128_mem_value, re_128_q)
691
+ mo_128_mem_g, mo_128_inds = self.readMem(self.mo_128_mem_key, self.mo_128_mem_value, mo_128_q)
692
+
693
+ le_64_mem_g, le_64_inds = self.readMem(self.le_64_mem_key, self.le_64_mem_value, le_64_q)
694
+ re_64_mem_g, re_64_inds = self.readMem(self.re_64_mem_key, self.re_64_mem_value, re_64_q)
695
+ mo_64_mem_g, mo_64_inds = self.readMem(self.mo_64_mem_key, self.mo_64_mem_value, mo_64_q)
696
+
697
+ if sp_256 is not None and sp_128 is not None and sp_64 is not None:
698
+ le_256_mem_s, _ = self.readMem(sp_256['LE256Key'], sp_256['LE256Value'], le_256_q)
699
+ re_256_mem_s, _ = self.readMem(sp_256['RE256Key'], sp_256['RE256Value'], re_256_q)
700
+ mo_256_mem_s, _ = self.readMem(sp_256['MO256Key'], sp_256['MO256Value'], mo_256_q)
701
+ le_256_mask = self.LE_256_Mask(fs_in['le256'],le_256_mem_s,le_256_mem_g)
702
+ le_256_mem = le_256_mask*le_256_mem_s + (1-le_256_mask)*le_256_mem_g
703
+ re_256_mask = self.RE_256_Mask(fs_in['re256'],re_256_mem_s,re_256_mem_g)
704
+ re_256_mem = re_256_mask*re_256_mem_s + (1-re_256_mask)*re_256_mem_g
705
+ mo_256_mask = self.MO_256_Mask(fs_in['mo256'],mo_256_mem_s,mo_256_mem_g)
706
+ mo_256_mem = mo_256_mask*mo_256_mem_s + (1-mo_256_mask)*mo_256_mem_g
707
+
708
+ le_128_mem_s, _ = self.readMem(sp_128['LE128Key'], sp_128['LE128Value'], le_128_q)
709
+ re_128_mem_s, _ = self.readMem(sp_128['RE128Key'], sp_128['RE128Value'], re_128_q)
710
+ mo_128_mem_s, _ = self.readMem(sp_128['MO128Key'], sp_128['MO128Value'], mo_128_q)
711
+ le_128_mask = self.LE_128_Mask(fs_in['le128'],le_128_mem_s,le_128_mem_g)
712
+ le_128_mem = le_128_mask*le_128_mem_s + (1-le_128_mask)*le_128_mem_g
713
+ re_128_mask = self.RE_128_Mask(fs_in['re128'],re_128_mem_s,re_128_mem_g)
714
+ re_128_mem = re_128_mask*re_128_mem_s + (1-re_128_mask)*re_128_mem_g
715
+ mo_128_mask = self.MO_128_Mask(fs_in['mo128'],mo_128_mem_s,mo_128_mem_g)
716
+ mo_128_mem = mo_128_mask*mo_128_mem_s + (1-mo_128_mask)*mo_128_mem_g
717
+
718
+ le_64_mem_s, _ = self.readMem(sp_64['LE64Key'], sp_64['LE64Value'], le_64_q)
719
+ re_64_mem_s, _ = self.readMem(sp_64['RE64Key'], sp_64['RE64Value'], re_64_q)
720
+ mo_64_mem_s, _ = self.readMem(sp_64['MO64Key'], sp_64['MO64Value'], mo_64_q)
721
+ le_64_mask = self.LE_64_Mask(fs_in['le64'],le_64_mem_s,le_64_mem_g)
722
+ le_64_mem = le_64_mask*le_64_mem_s + (1-le_64_mask)*le_64_mem_g
723
+ re_64_mask = self.RE_64_Mask(fs_in['re64'],re_64_mem_s,re_64_mem_g)
724
+ re_64_mem = re_64_mask*re_64_mem_s + (1-re_64_mask)*re_64_mem_g
725
+ mo_64_mask = self.MO_64_Mask(fs_in['mo64'],mo_64_mem_s,mo_64_mem_g)
726
+ mo_64_mem = mo_64_mask*mo_64_mem_s + (1-mo_64_mask)*mo_64_mem_g
727
+ else:
728
+ le_256_mem = le_256_mem_g
729
+ re_256_mem = re_256_mem_g
730
+ mo_256_mem = mo_256_mem_g
731
+ le_128_mem = le_128_mem_g
732
+ re_128_mem = re_128_mem_g
733
+ mo_128_mem = mo_128_mem_g
734
+ le_64_mem = le_64_mem_g
735
+ re_64_mem = re_64_mem_g
736
+ mo_64_mem = mo_64_mem_g
737
+
738
+ le_256_mem_norm = adaptive_instance_normalization_4D(le_256_mem, fs_in['le256'])
739
+ re_256_mem_norm = adaptive_instance_normalization_4D(re_256_mem, fs_in['re256'])
740
+ mo_256_mem_norm = adaptive_instance_normalization_4D(mo_256_mem, fs_in['mo256'])
741
+
742
+ ####for 128
743
+ le_128_mem_norm = adaptive_instance_normalization_4D(le_128_mem, fs_in['le128'])
744
+ re_128_mem_norm = adaptive_instance_normalization_4D(re_128_mem, fs_in['re128'])
745
+ mo_128_mem_norm = adaptive_instance_normalization_4D(mo_128_mem, fs_in['mo128'])
746
+
747
+ ####for 64
748
+ le_64_mem_norm = adaptive_instance_normalization_4D(le_64_mem, fs_in['le64'])
749
+ re_64_mem_norm = adaptive_instance_normalization_4D(re_64_mem, fs_in['re64'])
750
+ mo_64_mem_norm = adaptive_instance_normalization_4D(mo_64_mem, fs_in['mo64'])
751
+
752
+
753
+ EnMem256 = {'LE256Norm': le_256_mem_norm, 'RE256Norm': re_256_mem_norm, 'MO256Norm': mo_256_mem_norm}
754
+ EnMem128 = {'LE128Norm': le_128_mem_norm, 'RE128Norm': re_128_mem_norm, 'MO128Norm': mo_128_mem_norm}
755
+ EnMem64 = {'LE64Norm': le_64_mem_norm, 'RE64Norm': re_64_mem_norm, 'MO64Norm': mo_64_mem_norm}
756
+ Ind256 = {'LE': le_256_inds, 'RE': re_256_inds, 'MO': mo_256_inds}
757
+ Ind128 = {'LE': le_128_inds, 'RE': re_128_inds, 'MO': mo_128_inds}
758
+ Ind64 = {'LE': le_64_inds, 'RE': re_64_inds, 'MO': mo_64_inds}
759
+ return EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64
760
+
761
+ def reconstruct(self, fs_in, locs, memstar):
762
+ le_256_mem_norm, re_256_mem_norm, mo_256_mem_norm = memstar[0]['LE256Norm'], memstar[0]['RE256Norm'], memstar[0]['MO256Norm']
763
+ le_128_mem_norm, re_128_mem_norm, mo_128_mem_norm = memstar[1]['LE128Norm'], memstar[1]['RE128Norm'], memstar[1]['MO128Norm']
764
+ le_64_mem_norm, re_64_mem_norm, mo_64_mem_norm = memstar[2]['LE64Norm'], memstar[2]['RE64Norm'], memstar[2]['MO64Norm']
765
+
766
+ le_256_final = self.LE_256_Attention(le_256_mem_norm - fs_in['le256']) * le_256_mem_norm + fs_in['le256']
767
+ re_256_final = self.RE_256_Attention(re_256_mem_norm - fs_in['re256']) * re_256_mem_norm + fs_in['re256']
768
+ mo_256_final = self.MO_256_Attention(mo_256_mem_norm - fs_in['mo256']) * mo_256_mem_norm + fs_in['mo256']
769
+
770
+ le_128_final = self.LE_128_Attention(le_128_mem_norm - fs_in['le128']) * le_128_mem_norm + fs_in['le128']
771
+ re_128_final = self.RE_128_Attention(re_128_mem_norm - fs_in['re128']) * re_128_mem_norm + fs_in['re128']
772
+ mo_128_final = self.MO_128_Attention(mo_128_mem_norm - fs_in['mo128']) * mo_128_mem_norm + fs_in['mo128']
773
+
774
+ le_64_final = self.LE_64_Attention(le_64_mem_norm - fs_in['le64']) * le_64_mem_norm + fs_in['le64']
775
+ re_64_final = self.RE_64_Attention(re_64_mem_norm - fs_in['re64']) * re_64_mem_norm + fs_in['re64']
776
+ mo_64_final = self.MO_64_Attention(mo_64_mem_norm - fs_in['mo64']) * mo_64_mem_norm + fs_in['mo64']
777
+
778
+
779
+ le_location = locs[:,0,:]
780
+ re_location = locs[:,1,:]
781
+ mo_location = locs[:,3,:]
782
+ le_location = le_location.cpu().int().numpy()
783
+ re_location = re_location.cpu().int().numpy()
784
+ mo_location = mo_location.cpu().int().numpy()
785
+
786
+ up_in_256 = fs_in['f256'].clone()# * 0
787
+ up_in_128 = fs_in['f128'].clone()# * 0
788
+ up_in_64 = fs_in['f64'].clone()# * 0
789
+
790
+ for i in range(fs_in['f256'].size(0)):
791
+ up_in_256[i:i+1,:,le_location[i,1]//2:le_location[i,3]//2,le_location[i,0]//2:le_location[i,2]//2] = F.interpolate(le_256_final[i:i+1,:,:,:].clone(), (le_location[i,3]//2-le_location[i,1]//2,le_location[i,2]//2-le_location[i,0]//2),mode='bilinear',align_corners=False)
792
+ up_in_256[i:i+1,:,re_location[i,1]//2:re_location[i,3]//2,re_location[i,0]//2:re_location[i,2]//2] = F.interpolate(re_256_final[i:i+1,:,:,:].clone(), (re_location[i,3]//2-re_location[i,1]//2,re_location[i,2]//2-re_location[i,0]//2),mode='bilinear',align_corners=False)
793
+ up_in_256[i:i+1,:,mo_location[i,1]//2:mo_location[i,3]//2,mo_location[i,0]//2:mo_location[i,2]//2] = F.interpolate(mo_256_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//2-mo_location[i,1]//2,mo_location[i,2]//2-mo_location[i,0]//2),mode='bilinear',align_corners=False)
794
+
795
+ up_in_128[i:i+1,:,le_location[i,1]//4:le_location[i,3]//4,le_location[i,0]//4:le_location[i,2]//4] = F.interpolate(le_128_final[i:i+1,:,:,:].clone(), (le_location[i,3]//4-le_location[i,1]//4,le_location[i,2]//4-le_location[i,0]//4),mode='bilinear',align_corners=False)
796
+ up_in_128[i:i+1,:,re_location[i,1]//4:re_location[i,3]//4,re_location[i,0]//4:re_location[i,2]//4] = F.interpolate(re_128_final[i:i+1,:,:,:].clone(), (re_location[i,3]//4-re_location[i,1]//4,re_location[i,2]//4-re_location[i,0]//4),mode='bilinear',align_corners=False)
797
+ up_in_128[i:i+1,:,mo_location[i,1]//4:mo_location[i,3]//4,mo_location[i,0]//4:mo_location[i,2]//4] = F.interpolate(mo_128_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//4-mo_location[i,1]//4,mo_location[i,2]//4-mo_location[i,0]//4),mode='bilinear',align_corners=False)
798
+
799
+ up_in_64[i:i+1,:,le_location[i,1]//8:le_location[i,3]//8,le_location[i,0]//8:le_location[i,2]//8] = F.interpolate(le_64_final[i:i+1,:,:,:].clone(), (le_location[i,3]//8-le_location[i,1]//8,le_location[i,2]//8-le_location[i,0]//8),mode='bilinear',align_corners=False)
800
+ up_in_64[i:i+1,:,re_location[i,1]//8:re_location[i,3]//8,re_location[i,0]//8:re_location[i,2]//8] = F.interpolate(re_64_final[i:i+1,:,:,:].clone(), (re_location[i,3]//8-re_location[i,1]//8,re_location[i,2]//8-re_location[i,0]//8),mode='bilinear',align_corners=False)
801
+ up_in_64[i:i+1,:,mo_location[i,1]//8:mo_location[i,3]//8,mo_location[i,0]//8:mo_location[i,2]//8] = F.interpolate(mo_64_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//8-mo_location[i,1]//8,mo_location[i,2]//8-mo_location[i,0]//8),mode='bilinear',align_corners=False)
802
+
803
+ ms_in_64 = self.MSDilate(fs_in['f64'].clone())
804
+ fea_up1 = self.up1(ms_in_64, up_in_64)
805
+ fea_up2 = self.up2(fea_up1, up_in_128) #
806
+ fea_up3 = self.up3(fea_up2, up_in_256) #
807
+ output = self.up4(fea_up3) #
808
+ return output
809
+
810
+ def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None):
811
+ return self.memorize(sp_imgs, sp_locs)
812
+
813
+ def forward(self, lq=None, loc=None, sp_256 = None, sp_128 = None, sp_64 = None):
814
+ fs_in = self.E_lq(lq, loc) # low quality images
815
+ GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer(fs_in)
816
+ GeOut = self.reconstruct(fs_in, loc, memstar = [GeMemNorm256, GeMemNorm128, GeMemNorm64])
817
+ if sp_256 is not None and sp_128 is not None and sp_64 is not None:
818
+ GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(fs_in, sp_256, sp_128, sp_64)
819
+ GSOut = self.reconstruct(fs_in, loc, memstar = [GSMemNorm256, GSMemNorm128, GSMemNorm64])
820
+ else:
821
+ GSOut = None
822
+ return GeOut, GSOut
823
+
824
+ class UpResBlock(nn.Module):
825
+ def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d):
826
+ super(UpResBlock, self).__init__()
827
+ self.Model = nn.Sequential(
828
+ SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
829
+ nn.LeakyReLU(0.2),
830
+ SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
831
+ )
832
+ def forward(self, x):
833
+ out = x + self.Model(x)
834
+ return out
835
+
plugins/plugin_faceswap.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from chain_img_processor import ChainImgProcessor, ChainImgPlugin
2
+ from roop.face_helper import get_one_face, get_many_faces, swap_face
3
+ import os
4
+ from roop.utilities import compute_cosine_distance
5
+
6
+ modname = os.path.basename(__file__)[:-3] # calculating modname
7
+
8
+ # start function
9
+ def start(core:ChainImgProcessor):
10
+ manifest = { # plugin settings
11
+ "name": "Faceswap", # name
12
+ "version": "1.0", # version
13
+
14
+ "default_options": {
15
+ "swap_mode": "selected",
16
+ "max_distance": 0.65, # max distance to detect face similarity
17
+ },
18
+ "img_processor": {
19
+ "faceswap": Faceswap
20
+ }
21
+ }
22
+ return manifest
23
+
24
+ def start_with_options(core:ChainImgProcessor, manifest:dict):
25
+ pass
26
+
27
+
28
+ class Faceswap(ChainImgPlugin):
29
+
30
+ def init_plugin(self):
31
+ pass
32
+
33
+
34
+ def process(self, frame, params:dict):
35
+ if not "input_face_datas" in params or len(params["input_face_datas"]) < 1:
36
+ params["face_detected"] = False
37
+ return frame
38
+
39
+ temp_frame = frame
40
+ params["face_detected"] = True
41
+ params["processed_faces"] = []
42
+
43
+ if params["swap_mode"] == "first":
44
+ face = get_one_face(frame)
45
+ if face is None:
46
+ params["face_detected"] = False
47
+ return frame
48
+ params["processed_faces"].append(face)
49
+ frame = swap_face(params["input_face_datas"][0], face, frame)
50
+ return frame
51
+
52
+ else:
53
+ faces = get_many_faces(frame)
54
+ if(len(faces) < 1):
55
+ params["face_detected"] = False
56
+ return frame
57
+
58
+ dist_threshold = params["face_distance_threshold"]
59
+
60
+ if params["swap_mode"] == "all":
61
+ for sf in params["input_face_datas"]:
62
+ for face in faces:
63
+ params["processed_faces"].append(face)
64
+ temp_frame = swap_face(sf, face, temp_frame)
65
+ return temp_frame
66
+
67
+ elif params["swap_mode"] == "selected":
68
+ for i,tf in enumerate(params["target_face_datas"]):
69
+ for face in faces:
70
+ if compute_cosine_distance(tf.embedding, face.embedding) <= dist_threshold:
71
+ temp_frame = swap_face(params["input_face_datas"][i], face, temp_frame)
72
+ params["processed_faces"].append(face)
73
+ break
74
+
75
+ elif params["swap_mode"] == "all_female" or params["swap_mode"] == "all_male":
76
+ gender = 'F' if params["swap_mode"] == "all_female" else 'M'
77
+ face_found = False
78
+ for face in faces:
79
+ if face.sex == gender:
80
+ face_found = True
81
+ if face_found:
82
+ params["processed_faces"].append(face)
83
+ temp_frame = swap_face(params["input_face_datas"][0], face, temp_frame)
84
+ face_found = False
85
+
86
+ return temp_frame
plugins/plugin_gfpgan.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from chain_img_processor import ChainImgProcessor, ChainImgPlugin
2
+ import os
3
+ import gfpgan
4
+ import threading
5
+ from PIL import Image
6
+ from numpy import asarray
7
+ import cv2
8
+
9
+ from roop.utilities import resolve_relative_path, conditional_download
10
+ modname = os.path.basename(__file__)[:-3] # calculating modname
11
+
12
+ model_gfpgan = None
13
+ THREAD_LOCK_GFPGAN = threading.Lock()
14
+
15
+
16
+ # start function
17
+ def start(core:ChainImgProcessor):
18
+ manifest = { # plugin settings
19
+ "name": "GFPGAN", # name
20
+ "version": "1.4", # version
21
+
22
+ "default_options": {},
23
+ "img_processor": {
24
+ "gfpgan": GFPGAN
25
+ }
26
+ }
27
+ return manifest
28
+
29
+ def start_with_options(core:ChainImgProcessor, manifest:dict):
30
+ pass
31
+
32
+
33
+ class GFPGAN(ChainImgPlugin):
34
+
35
+ def init_plugin(self):
36
+ global model_gfpgan
37
+
38
+ if model_gfpgan is None:
39
+ model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
40
+ model_gfpgan = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=self.device) # type: ignore[attr-defined]
41
+
42
+
43
+
44
+ def process(self, frame, params:dict):
45
+ import copy
46
+
47
+ global model_gfpgan
48
+
49
+ if model_gfpgan is None:
50
+ return frame
51
+
52
+ if "face_detected" in params:
53
+ if not params["face_detected"]:
54
+ return frame
55
+ # don't touch original
56
+ temp_frame = copy.copy(frame)
57
+ if "processed_faces" in params:
58
+ for face in params["processed_faces"]:
59
+ start_x, start_y, end_x, end_y = map(int, face['bbox'])
60
+ padding_x = int((end_x - start_x) * 0.5)
61
+ padding_y = int((end_y - start_y) * 0.5)
62
+ start_x = max(0, start_x - padding_x)
63
+ start_y = max(0, start_y - padding_y)
64
+ end_x = max(0, end_x + padding_x)
65
+ end_y = max(0, end_y + padding_y)
66
+ temp_face = temp_frame[start_y:end_y, start_x:end_x]
67
+ if temp_face.size:
68
+ with THREAD_LOCK_GFPGAN:
69
+ _, _, temp_face = model_gfpgan.enhance(
70
+ temp_face,
71
+ paste_back=True
72
+ )
73
+ temp_frame[start_y:end_y, start_x:end_x] = temp_face
74
+ else:
75
+ with THREAD_LOCK_GFPGAN:
76
+ _, _, temp_frame = model_gfpgan.enhance(
77
+ temp_frame,
78
+ paste_back=True
79
+ )
80
+
81
+ if not "blend_ratio" in params:
82
+ return temp_frame
83
+
84
+ temp_frame = Image.blend(Image.fromarray(frame), Image.fromarray(temp_frame), params["blend_ratio"])
85
+ return asarray(temp_frame)
plugins/plugin_txt2clip.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import threading
6
+ from chain_img_processor import ChainImgProcessor, ChainImgPlugin
7
+ from torchvision import transforms
8
+ from clip.clipseg import CLIPDensePredT
9
+ from numpy import asarray
10
+
11
+
12
+ THREAD_LOCK_CLIP = threading.Lock()
13
+
14
+ modname = os.path.basename(__file__)[:-3] # calculating modname
15
+
16
+ model_clip = None
17
+
18
+
19
+
20
+
21
+ # start function
22
+ def start(core:ChainImgProcessor):
23
+ manifest = { # plugin settings
24
+ "name": "Text2Clip", # name
25
+ "version": "1.0", # version
26
+
27
+ "default_options": {
28
+ },
29
+ "img_processor": {
30
+ "txt2clip": Text2Clip
31
+ }
32
+ }
33
+ return manifest
34
+
35
+ def start_with_options(core:ChainImgProcessor, manifest:dict):
36
+ pass
37
+
38
+
39
+
40
+ class Text2Clip(ChainImgPlugin):
41
+
42
+ def load_clip_model(self):
43
+ global model_clip
44
+
45
+ if model_clip is None:
46
+ device = torch.device(super().device)
47
+ model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
48
+ model_clip.eval();
49
+ model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False)
50
+ model_clip.to(device)
51
+
52
+
53
+ def init_plugin(self):
54
+ self.load_clip_model()
55
+
56
+ def process(self, frame, params:dict):
57
+ if "face_detected" in params:
58
+ if not params["face_detected"]:
59
+ return frame
60
+
61
+ return self.mask_original(params["original_frame"], frame, params["clip_prompt"])
62
+
63
+
64
+ def mask_original(self, img1, img2, keywords):
65
+ global model_clip
66
+
67
+ source_image_small = cv2.resize(img1, (256,256))
68
+
69
+ img_mask = np.full((source_image_small.shape[0],source_image_small.shape[1]), 0, dtype=np.float32)
70
+ mask_border = 1
71
+ l = 0
72
+ t = 0
73
+ r = 1
74
+ b = 1
75
+
76
+ mask_blur = 5
77
+ clip_blur = 5
78
+
79
+ img_mask = cv2.rectangle(img_mask, (mask_border+int(l), mask_border+int(t)),
80
+ (256 - mask_border-int(r), 256-mask_border-int(b)), (255, 255, 255), -1)
81
+ img_mask = cv2.GaussianBlur(img_mask, (mask_blur*2+1,mask_blur*2+1), 0)
82
+ img_mask /= 255
83
+
84
+
85
+ input_image = source_image_small
86
+
87
+ transform = transforms.Compose([
88
+ transforms.ToTensor(),
89
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
90
+ transforms.Resize((256, 256)),
91
+ ])
92
+ img = transform(input_image).unsqueeze(0)
93
+
94
+ thresh = 0.5
95
+ prompts = keywords.split(',')
96
+ with THREAD_LOCK_CLIP:
97
+ with torch.no_grad():
98
+ preds = model_clip(img.repeat(len(prompts),1,1,1), prompts)[0]
99
+ clip_mask = torch.sigmoid(preds[0][0])
100
+ for i in range(len(prompts)-1):
101
+ clip_mask += torch.sigmoid(preds[i+1][0])
102
+
103
+ clip_mask = clip_mask.data.cpu().numpy()
104
+ np.clip(clip_mask, 0, 1)
105
+
106
+ clip_mask[clip_mask>thresh] = 1.0
107
+ clip_mask[clip_mask<=thresh] = 0.0
108
+ kernel = np.ones((5, 5), np.float32)
109
+ clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
110
+ clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur*2+1,clip_blur*2+1), 0)
111
+
112
+ img_mask *= clip_mask
113
+ img_mask[img_mask<0.0] = 0.0
114
+
115
+ img_mask = cv2.resize(img_mask, (img2.shape[1], img2.shape[0]))
116
+ img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
117
+
118
+ target = img2.astype(np.float32)
119
+ result = (1-img_mask) * target
120
+ result += img_mask * img1.astype(np.float32)
121
+ return np.uint8(result)
122
+
roop-unleashed.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1abfc4e80fb1e9e8eb3381f1d46193051b683ea452595a189bb5d647dfe7b6b
3
+ size 5953