BlackBeenie committed
Commit fe3fdf0
Parent(s): 8dd41a8

feat: Add plugins

Files changed:
- mypy.ini +7 -0
- plugin_options/core.json +6 -0
- plugin_options/core_video.json +5 -0
- plugin_options/plugin_codeformer.json +8 -0
- plugin_options/plugin_dmdnet.json +3 -0
- plugin_options/plugin_faceswap.json +5 -0
- plugin_options/plugin_gfpgan.json +3 -0
- plugin_options/plugin_txt2clip.json +3 -0
- plugins/codeformer_app_cv2.py +300 -0
- plugins/codeformer_face_helper_cv2.py +94 -0
- plugins/core.py +29 -0
- plugins/core_video.py +26 -0
- plugins/plugin_codeformer.py +83 -0
- plugins/plugin_dmdnet.py +835 -0
- plugins/plugin_faceswap.py +86 -0
- plugins/plugin_gfpgan.py +85 -0
- plugins/plugin_txt2clip.py +122 -0
- roop-unleashed.ipynb +3 -0
mypy.ini
ADDED
@@ -0,0 +1,7 @@
+[mypy]
+check_untyped_defs = True
+disallow_any_generics = True
+disallow_untyped_calls = True
+disallow_untyped_defs = True
+ignore_missing_imports = True
+strict_optional = False
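For readers unfamiliar with these flags, the short sketch below (not part of the commit) illustrates what this configuration enforces: unannotated definitions and calls into them become errors, while strict_optional = False keeps None assignments permissive.

# Illustration only (not in the commit): behaviour under the mypy settings above.
def untyped(x):               # error: disallow_untyped_defs requires annotations
    return x * 2

def typed(x: int) -> int:     # ok: fully annotated
    return untyped(x)         # error: disallow_untyped_calls rejects calls to untyped code

maybe: int = None             # tolerated only because strict_optional = False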
plugin_options/core.json
ADDED
@@ -0,0 +1,6 @@
+{
+    "default_chain": "faceswap",
+    "init_on_start": "faceswap,dmdnet,gfpgan,codeformer",
+    "is_demo_row_render": false,
+    "v": "2.0"
+}
plugin_options/core_video.json
ADDED
@@ -0,0 +1,5 @@
+{
+    "v": "2.0",
+    "video_save_codec": "libx264",
+    "video_save_crf": 14
+}
plugin_options/plugin_codeformer.json
ADDED
@@ -0,0 +1,8 @@
+{
+    "background_enhance": false,
+    "codeformer_fidelity": 0.8,
+    "face_upsample": true,
+    "skip_if_no_face": true,
+    "upscale": 1,
+    "v": "3.0"
+}
plugin_options/plugin_dmdnet.json
ADDED
@@ -0,0 +1,3 @@
+{
+    "v": "1.0"
+}
plugin_options/plugin_faceswap.json
ADDED
@@ -0,0 +1,5 @@
+{
+    "max_distance": 0.65,
+    "swap_mode": "selected",
+    "v": "1.0"
+}
plugin_options/plugin_gfpgan.json
ADDED
@@ -0,0 +1,3 @@
+{
+    "v": "1.4"
+}
plugin_options/plugin_txt2clip.json
ADDED
@@ -0,0 +1,3 @@
+{
+    "v": "1.0"
+}
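Each plugin_options file mirrors the corresponding plugin's default_options plus a "v" version stamp. A minimal sketch of how such a file could be merged over a manifest's defaults is shown below; the helper name load_options and the merge order are assumptions for illustration, since the actual loader lives in chain_img_processor, which is not part of this commit.

# Hypothetical helper (not in the commit): merge a plugin_options/*.json file
# over the "default_options" a plugin declares in its manifest.
import json
import os

def load_options(plugin_name: str, default_options: dict) -> dict:
    path = os.path.join("plugin_options", f"{plugin_name}.json")
    options = dict(default_options)           # start from the plugin defaults
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            options.update(json.load(f))      # values from the file win over defaults
    return options

# e.g. load_options("core", {"default_chain": "faceswap"}) picks up plugin_options/core.json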
plugins/codeformer_app_cv2.py
ADDED
@@ -0,0 +1,300 @@
+"""
+Modified version from codeformer-pip project
+
+S-Lab License 1.0
+
+Copyright 2022 S-Lab
+
+https://github.com/kadirnar/codeformer-pip/blob/main/LICENSE
+"""
+
+import os
+
+import cv2
+import torch
+from codeformer.facelib.detection import init_detection_model
+from codeformer.facelib.parsing import init_parsing_model
+from torchvision.transforms.functional import normalize
+
+from codeformer.basicsr.archs.rrdbnet_arch import RRDBNet
+from codeformer.basicsr.utils import img2tensor, imwrite, tensor2img
+from codeformer.basicsr.utils.download_util import load_file_from_url
+from codeformer.basicsr.utils.realesrgan_utils import RealESRGANer
+from codeformer.basicsr.utils.registry import ARCH_REGISTRY
+from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
+from codeformer.facelib.utils.misc import is_gray
+import threading
+
+from plugins.codeformer_face_helper_cv2 import FaceRestoreHelperOptimized
+
+THREAD_LOCK_FACE_HELPER = threading.Lock()
+THREAD_LOCK_FACE_HELPER_CREATE = threading.Lock()
+THREAD_LOCK_FACE_HELPER_PROCERSSING = threading.Lock()
+THREAD_LOCK_CODEFORMER_NET = threading.Lock()
+THREAD_LOCK_CODEFORMER_NET_CREATE = threading.Lock()
+THREAD_LOCK_BGUPSAMPLER = threading.Lock()
+
+pretrain_model_url = {
+    "codeformer": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
+    "detection": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
+    "parsing": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
+    "realesrgan": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
+}
+
+# download weights
+if not os.path.exists("models/CodeFormer/codeformer.pth"):
+    load_file_from_url(
+        url=pretrain_model_url["codeformer"], model_dir="models/CodeFormer/", progress=True, file_name=None
+    )
+if not os.path.exists("models/CodeFormer/facelib/detection_Resnet50_Final.pth"):
+    load_file_from_url(
+        url=pretrain_model_url["detection"], model_dir="models/CodeFormer/facelib", progress=True, file_name=None
+    )
+if not os.path.exists("models/CodeFormer/facelib/parsing_parsenet.pth"):
+    load_file_from_url(
+        url=pretrain_model_url["parsing"], model_dir="models/CodeFormer/facelib", progress=True, file_name=None
+    )
+if not os.path.exists("models/CodeFormer/realesrgan/RealESRGAN_x2plus.pth"):
+    load_file_from_url(
+        url=pretrain_model_url["realesrgan"], model_dir="models/CodeFormer/realesrgan", progress=True, file_name=None
+    )
+
+
+def imread(img_path):
+    img = cv2.imread(img_path)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return img
+
+
+# set enhancer with RealESRGAN
+def set_realesrgan():
+    half = True if torch.cuda.is_available() else False
+    model = RRDBNet(
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=64,
+        num_block=23,
+        num_grow_ch=32,
+        scale=2,
+    )
+    upsampler = RealESRGANer(
+        scale=2,
+        model_path="models/CodeFormer/realesrgan/RealESRGAN_x2plus.pth",
+        model=model,
+        tile=400,
+        tile_pad=40,
+        pre_pad=0,
+        half=half,
+    )
+    return upsampler
+
+
+upsampler = set_realesrgan()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+codeformers_cache = []
+
+def get_codeformer():
+    if len(codeformers_cache) > 0:
+        with THREAD_LOCK_CODEFORMER_NET:
+            if len(codeformers_cache) > 0:
+                return codeformers_cache.pop()
+
+    with THREAD_LOCK_CODEFORMER_NET_CREATE:
+        codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
+            dim_embd=512,
+            codebook_size=1024,
+            n_head=8,
+            n_layers=9,
+            connect_list=["32", "64", "128", "256"],
+        ).to(device)
+        ckpt_path = "models/CodeFormer/codeformer.pth"
+        checkpoint = torch.load(ckpt_path)["params_ema"]
+        codeformer_net.load_state_dict(checkpoint)
+        codeformer_net.eval()
+        return codeformer_net
+
+
+
+def release_codeformer(codeformer):
+    with THREAD_LOCK_CODEFORMER_NET:
+        codeformers_cache.append(codeformer)
+
+#os.makedirs("output", exist_ok=True)
+
+# ------- face restore thread cache ----------
+
+face_restore_helper_cache = []
+
+detection_model = "retinaface_resnet50"
+
+inited_face_restore_helper_nn = False
+
+import time
+
+def get_face_restore_helper(upscale):
+    global inited_face_restore_helper_nn
+    with THREAD_LOCK_FACE_HELPER:
+        face_helper = FaceRestoreHelperOptimized(
+            upscale,
+            face_size=512,
+            crop_ratio=(1, 1),
+            det_model=detection_model,
+            save_ext="png",
+            use_parse=True,
+            device=device,
+        )
+        #return face_helper
+
+        if inited_face_restore_helper_nn:
+            while len(face_restore_helper_cache) == 0:
+                time.sleep(0.05)
+            face_detector, face_parse = face_restore_helper_cache.pop()
+            face_helper.face_detector = face_detector
+            face_helper.face_parse = face_parse
+            return face_helper
+        else:
+            inited_face_restore_helper_nn = True
+            face_helper.face_detector = init_detection_model(detection_model, half=False, device=face_helper.device)
+            face_helper.face_parse = init_parsing_model(model_name="parsenet", device=face_helper.device)
+            return face_helper
+
+def get_face_restore_helper2(upscale): # still not work well!!!
+    face_helper = FaceRestoreHelperOptimized(
+        upscale,
+        face_size=512,
+        crop_ratio=(1, 1),
+        det_model=detection_model,
+        save_ext="png",
+        use_parse=True,
+        device=device,
+    )
+    #return face_helper
+
+    if len(face_restore_helper_cache) > 0:
+        with THREAD_LOCK_FACE_HELPER:
+            if len(face_restore_helper_cache) > 0:
+                face_detector, face_parse = face_restore_helper_cache.pop()
+                face_helper.face_detector = face_detector
+                face_helper.face_parse = face_parse
+                return face_helper
+
+    with THREAD_LOCK_FACE_HELPER_CREATE:
+        face_helper.face_detector = init_detection_model(detection_model, half=False, device=face_helper.device)
+        face_helper.face_parse = init_parsing_model(model_name="parsenet", device=face_helper.device)
+        return face_helper
+
+def release_face_restore_helper(face_helper):
+    #return
+    #with THREAD_LOCK_FACE_HELPER:
+    face_restore_helper_cache.append((face_helper.face_detector, face_helper.face_parse))
+    #pass
+
+def inference_app(image, background_enhance, face_upsample, upscale, codeformer_fidelity, skip_if_no_face = False):
+    # take the default setting for the demo
+    has_aligned = False
+    only_center_face = False
+    draw_box = False
+
+    #print("Inp:", image, background_enhance, face_upsample, upscale, codeformer_fidelity)
+    if isinstance(image, str):
+        img = cv2.imread(str(image), cv2.IMREAD_COLOR)
+    else:
+        img = image
+    #print("\timage size:", img.shape)
+
+    upscale = int(upscale) # convert type to int
+    if upscale > 4: # avoid memory exceeded due to too large upscale
+        upscale = 4
+    if upscale > 2 and max(img.shape[:2]) > 1000: # avoid memory exceeded due to too large img resolution
+        upscale = 2
+    if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution
+        upscale = 1
+        background_enhance = False
+        #face_upsample = False
+
+    face_helper = get_face_restore_helper(upscale)
+
+    bg_upsampler = upsampler if background_enhance else None
+    face_upsampler = upsampler if face_upsample else None
+
+    if has_aligned:
+        # the input faces are already cropped and aligned
+        img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+        face_helper.is_gray = is_gray(img, threshold=5)
+        if face_helper.is_gray:
+            print("\tgrayscale input: True")
+        face_helper.cropped_faces = [img]
+    else:
+        with THREAD_LOCK_FACE_HELPER_PROCERSSING:
+            face_helper.read_image(img)
+            # get face landmarks for each face
+
+            num_det_faces = face_helper.get_face_landmarks_5(
+                only_center_face=only_center_face, resize=640, eye_dist_threshold=5
+            )
+        #print(f"\tdetect {num_det_faces} faces")
+
+        if num_det_faces == 0 and skip_if_no_face:
+            release_face_restore_helper(face_helper)
+            return img
+
+        # align and warp each face
+        face_helper.align_warp_face()
+
+
+
+    # face restoration for each cropped face
+    for idx, cropped_face in enumerate(face_helper.cropped_faces):
+        # prepare data
+        cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
+        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+        cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+
+        codeformer_net = get_codeformer()
+        try:
+            with torch.no_grad():
+                output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
+                restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+            del output
+        except RuntimeError as error:
+            print(f"Failed inference for CodeFormer: {error}")
+            restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+        release_codeformer(codeformer_net)
+
+        restored_face = restored_face.astype("uint8")
+        face_helper.add_restored_face(restored_face)
+
+    # paste_back
+    if not has_aligned:
+        # upsample the background
+        if bg_upsampler is not None:
+            with THREAD_LOCK_BGUPSAMPLER:
+                # Now only support RealESRGAN for upsampling background
+                bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
+        else:
+            bg_img = None
+        face_helper.get_inverse_affine(None)
+        # paste each restored face to the input image
+        if face_upsample and face_upsampler is not None:
+            restored_img = face_helper.paste_faces_to_input_image(
+                upsample_img=bg_img,
+                draw_box=draw_box,
+                face_upsampler=face_upsampler,
+            )
+        else:
+            restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box)
+
+    if image.shape != restored_img.shape:
+        h, w, _ = image.shape
+        restored_img = cv2.resize(restored_img, (w, h), interpolation=cv2.INTER_LINEAR)
+
+
+    release_face_restore_helper(face_helper)
+    # save restored img
+    if isinstance(image, str):
+        save_path = f"output/out.png"
+        imwrite(restored_img, str(save_path))
+        return save_path
+    else:
+        return restored_img
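A minimal usage sketch for the inference_app function defined above (not part of the commit; it assumes the CodeFormer weights are present under models/CodeFormer and that a file test.jpg exists):

# Usage sketch: enhance one BGR frame in memory with inference_app.
import cv2
from plugins.codeformer_app_cv2 import inference_app

frame = cv2.imread("test.jpg", cv2.IMREAD_COLOR)           # BGR ndarray
restored = inference_app(
    frame,
    background_enhance=False,    # skip the RealESRGAN background pass
    face_upsample=True,
    upscale=1,
    codeformer_fidelity=0.8,
    skip_if_no_face=True,        # return the input unchanged if no face is detected
)
cv2.imwrite("restored.jpg", restored)                       # ndarray in, ndarray out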
plugins/codeformer_face_helper_cv2.py
ADDED
@@ -0,0 +1,94 @@
+from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
+
+import numpy as np
+from codeformer.basicsr.utils.misc import get_device
+
+class FaceRestoreHelperOptimized(FaceRestoreHelper):
+    def __init__(
+        self,
+        upscale_factor,
+        face_size=512,
+        crop_ratio=(1, 1),
+        det_model="retinaface_resnet50",
+        save_ext="png",
+        template_3points=False,
+        pad_blur=False,
+        use_parse=False,
+        device=None,
+    ):
+        self.template_3points = template_3points # improve robustness
+        self.upscale_factor = int(upscale_factor)
+        # the cropped face ratio based on the square face
+        self.crop_ratio = crop_ratio # (h, w)
+        assert self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1, "crop ration only supports >=1"
+        self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+        self.det_model = det_model
+
+        if self.det_model == "dlib":
+            # standard 5 landmarks for FFHQ faces with 1024 x 1024
+            self.face_template = np.array(
+                [
+                    [686.77227723, 488.62376238],
+                    [586.77227723, 493.59405941],
+                    [337.91089109, 488.38613861],
+                    [437.95049505, 493.51485149],
+                    [513.58415842, 678.5049505],
+                ]
+            )
+            self.face_template = self.face_template / (1024 // face_size)
+        elif self.template_3points:
+            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+        else:
+            # standard 5 landmarks for FFHQ faces with 512 x 512
+            # facexlib
+            self.face_template = np.array(
+                [
+                    [192.98138, 239.94708],
+                    [318.90277, 240.1936],
+                    [256.63416, 314.01935],
+                    [201.26117, 371.41043],
+                    [313.08905, 371.15118],
+                ]
+            )
+
+            # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
+            # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+            # [198.22603, 372.82502], [313.91018, 372.75659]])
+
+        self.face_template = self.face_template * (face_size / 512.0)
+        if self.crop_ratio[0] > 1:
+            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
+        if self.crop_ratio[1] > 1:
+            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
+        self.save_ext = save_ext
+        self.pad_blur = pad_blur
+        if self.pad_blur is True:
+            self.template_3points = False
+
+        self.all_landmarks_5 = []
+        self.det_faces = []
+        self.affine_matrices = []
+        self.inverse_affine_matrices = []
+        self.cropped_faces = []
+        self.restored_faces = []
+        self.pad_input_imgs = []
+
+        if device is None:
+            # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            self.device = get_device()
+        else:
+            self.device = device
+
+        # init face detection model
+        # if self.det_model == "dlib":
+        #     self.face_detector, self.shape_predictor_5 = self.init_dlib(
+        #         dlib_model_url["face_detector"], dlib_model_url["shape_predictor_5"]
+        #     )
+        # else:
+        #     self.face_detector = init_detection_model(det_model, half=False, device=self.device)
+
+        # init face parsing model
+        self.use_parse = use_parse
+        #self.face_parse = init_parsing_model(model_name="parsenet", device=self.device)
+
+        # MUST set face_detector and face_parse!!!
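As the closing comment stresses, this optimized helper deliberately skips loading the detection and parsing networks in __init__; callers must attach them afterwards, which is exactly what get_face_restore_helper() in plugins/codeformer_app_cv2.py does. A condensed sketch of that wiring (illustration, not part of the commit):

# Build the helper cheaply, then attach shared detector/parser instances,
# mirroring get_face_restore_helper in plugins/codeformer_app_cv2.py.
from codeformer.facelib.detection import init_detection_model
from codeformer.facelib.parsing import init_parsing_model
from plugins.codeformer_face_helper_cv2 import FaceRestoreHelperOptimized

helper = FaceRestoreHelperOptimized(
    upscale_factor=1, face_size=512, crop_ratio=(1, 1),
    det_model="retinaface_resnet50", save_ext="png", use_parse=True,
)
helper.face_detector = init_detection_model("retinaface_resnet50", half=False, device=helper.device)
helper.face_parse = init_parsing_model(model_name="parsenet", device=helper.device)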
plugins/core.py
ADDED
@@ -0,0 +1,29 @@
+# Core plugin
+# author: Vladislav Janvarev
+
+from chain_img_processor import ChainImgProcessor
+
+# start function
+def start(core:ChainImgProcessor):
+    manifest = {
+        "name": "Core plugin",
+        "version": "2.0",
+
+        "default_options": {
+            "default_chain": "faceswap", # default chain to run
+            "init_on_start": "faceswap,txt2clip,gfpgan,codeformer", # init these processors on start
+            "is_demo_row_render": False,
+        },
+
+    }
+    return manifest
+
+def start_with_options(core:ChainImgProcessor, manifest:dict):
+    options = manifest["options"]
+
+    core.default_chain = options["default_chain"]
+    core.init_on_start = options["init_on_start"]
+
+    core.is_demo_row_render = options["is_demo_row_render"]
+
+    return manifest
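Both core plugins follow the same two-phase handshake: start() returns a manifest carrying default_options, the framework merges user options into manifest["options"], and start_with_options() copies them onto the processor. A hypothetical driver sketch is shown below; the real loader lives in chain_img_processor, which is not part of this commit, so names and merge order are assumptions.

# Hypothetical driver sketch; the actual plugin loader is in chain_img_processor.
import plugins.core as core_plugin

def load_plugin(core, plugin_module, user_options: dict):
    manifest = plugin_module.start(core)                      # phase 1: collect defaults
    merged = dict(manifest.get("default_options", {}))
    merged.update(user_options)                               # user values override defaults
    manifest["options"] = merged
    return plugin_module.start_with_options(core, manifest)   # phase 2: apply options

# e.g. load_plugin(processor, core_plugin, {"default_chain": "faceswap"})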
plugins/core_video.py
ADDED
@@ -0,0 +1,26 @@
+# Core plugin
+# author: Vladislav Janvarev
+
+from chain_img_processor import ChainImgProcessor, ChainVideoProcessor
+
+# start function
+def start(core:ChainImgProcessor):
+    manifest = {
+        "name": "Core video plugin",
+        "version": "2.0",
+
+        "default_options": {
+            "video_save_codec": "libx264", # default codec to save
+            "video_save_crf": 14, # default crf to save
+        },
+
+    }
+    return manifest
+
+def start_with_options(core:ChainVideoProcessor, manifest:dict):
+    options = manifest["options"]
+
+    core.video_save_codec = options["video_save_codec"]
+    core.video_save_crf = options["video_save_crf"]
+
+    return manifest
plugins/plugin_codeformer.py
ADDED
@@ -0,0 +1,83 @@
+# Codeformer enchance plugin
+# author: Vladislav Janvarev
+
+# CountFloyd 20230717, extended to blend original/destination images
+
+from chain_img_processor import ChainImgProcessor, ChainImgPlugin
+import os
+from PIL import Image
+from numpy import asarray
+
+modname = os.path.basename(__file__)[:-3] # calculating modname
+
+# start function
+def start(core:ChainImgProcessor):
+    manifest = { # plugin settings
+        "name": "Codeformer", # name
+        "version": "3.0", # version
+
+        "default_options": {
+            "background_enhance": True, #
+            "face_upsample": True, #
+            "upscale": 2, #
+            "codeformer_fidelity": 0.8,
+            "skip_if_no_face":False,
+
+        },
+
+        "img_processor": {
+            "codeformer": PluginCodeformer # 1 function - init, 2 - process
+        }
+    }
+    return manifest
+
+def start_with_options(core:ChainImgProcessor, manifest:dict):
+    pass
+
+class PluginCodeformer(ChainImgPlugin):
+    def init_plugin(self):
+        import plugins.codeformer_app_cv2
+        pass
+
+    def process(self, img, params:dict):
+        import copy
+
+        # params can be used to transfer some img info to next processors
+        from plugins.codeformer_app_cv2 import inference_app
+        options = self.core.plugin_options(modname)
+
+        if "face_detected" in params:
+            if not params["face_detected"]:
+                return img
+
+        # don't touch original
+        temp_frame = copy.copy(img)
+        if "processed_faces" in params:
+            for face in params["processed_faces"]:
+                start_x, start_y, end_x, end_y = map(int, face['bbox'])
+                padding_x = int((end_x - start_x) * 0.5)
+                padding_y = int((end_y - start_y) * 0.5)
+                start_x = max(0, start_x - padding_x)
+                start_y = max(0, start_y - padding_y)
+                end_x = max(0, end_x + padding_x)
+                end_y = max(0, end_y + padding_y)
+                temp_face = temp_frame[start_y:end_y, start_x:end_x]
+                if temp_face.size:
+                    temp_face = inference_app(temp_face, options.get("background_enhance"), options.get("face_upsample"),
+                                              options.get("upscale"), options.get("codeformer_fidelity"),
+                                              options.get("skip_if_no_face"))
+                    temp_frame[start_y:end_y, start_x:end_x] = temp_face
+        else:
+            temp_frame = inference_app(temp_frame, options.get("background_enhance"), options.get("face_upsample"),
+                                       options.get("upscale"), options.get("codeformer_fidelity"),
+                                       options.get("skip_if_no_face"))
+
+
+
+        if not "blend_ratio" in params:
+            return temp_frame
+
+
+        temp_frame = Image.blend(Image.fromarray(img), Image.fromarray(temp_frame), params["blend_ratio"])
+        return asarray(temp_frame)
+
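process() only reacts to three optional keys in params; a sketch of the shape it expects is shown below. The upstream producer of these keys (presumably the faceswap processor) is not part of this commit, so the concrete values are illustrative only.

# Sketch of the params dict process() consumes (keys inferred from the code above).
params = {
    "face_detected": True,             # False short-circuits and returns the input frame
    "processed_faces": [               # optional: restrict enhancement to face crops
        {"bbox": (120, 80, 260, 240)}  # (start_x, start_y, end_x, end_y)
    ],
    "blend_ratio": 0.65,               # optional: blend the enhanced frame over the original
}
# result = plugin.process(frame, params)   # 'plugin' being an initialized PluginCodeformer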
plugins/plugin_dmdnet.py
ADDED
@@ -0,0 +1,835 @@
+from chain_img_processor import ChainImgProcessor, ChainImgPlugin
+import os
+from PIL import Image
+from numpy import asarray
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import scipy.io as sio
+import numpy as np
+import torch.nn.utils.spectral_norm as SpectralNorm
+from torchvision.ops import roi_align
+
+from math import sqrt
+import os
+
+import cv2
+import os
+from torchvision.transforms.functional import normalize
+import copy
+import threading
+
+modname = os.path.basename(__file__)[:-3] # calculating modname
+
+oDMDNet = None
+device = None
+
+THREAD_LOCK_DMDNET = threading.Lock()
+
+
+
+# start function
+def start(core:ChainImgProcessor):
+    manifest = { # plugin settings
+        "name": "DMDNet", # name
+        "version": "1.0", # version
+
+        "default_options": {},
+        "img_processor": {
+            "dmdnet": DMDNETPlugin
+        }
+    }
+    return manifest
+
+def start_with_options(core:ChainImgProcessor, manifest:dict):
+    pass
+
+
+class DMDNETPlugin(ChainImgPlugin):
+
+    # https://stackoverflow.com/a/67174339
+    def landmarks106_to_68(self, pt106):
+        map106to68=[1,10,12,14,16,3,5,7,0,23,21,19,32,30,28,26,17,
+                    43,48,49,51,50,
+                    102,103,104,105,101,
+                    72,73,74,86,78,79,80,85,84,
+                    35,41,42,39,37,36,
+                    89,95,96,93,91,90,
+                    52,64,63,71,67,68,61,58,59,53,56,55,65,66,62,70,69,57,60,54
+                    ]
+
+        pt68 = []
+        for i in range(68):
+            index = map106to68[i]
+            pt68.append(pt106[index])
+        return pt68
+
+    def init_plugin(self):
+        global create
+
+        if oDMDNet == None:
+            create(self.device)
+
+
+    def process(self, frame, params:dict):
+        if "face_detected" in params:
+            if not params["face_detected"]:
+                return frame
+
+        temp_frame = copy.copy(frame)
+        if "processed_faces" in params:
+            for face in params["processed_faces"]:
+                start_x, start_y, end_x, end_y = map(int, face['bbox'])
+                # padding_x = int((end_x - start_x) * 0.5)
+                # padding_y = int((end_y - start_y) * 0.5)
+                padding_x = 0
+                padding_y = 0
+
+                start_x = max(0, start_x - padding_x)
+                start_y = max(0, start_y - padding_y)
+                end_x = max(0, end_x + padding_x)
+                end_y = max(0, end_y + padding_y)
+                temp_face = temp_frame[start_y:end_y, start_x:end_x]
+                if temp_face.size:
+                    temp_face = self.enhance_face(temp_face, face)
+                    temp_face = cv2.resize(temp_face, (end_x - start_x,end_y - start_y), interpolation = cv2.INTER_LANCZOS4)
+                    temp_frame[start_y:end_y, start_x:end_x] = temp_face
+
+        temp_frame = Image.blend(Image.fromarray(frame), Image.fromarray(temp_frame), params["blend_ratio"])
+        return asarray(temp_frame)
+
+
+    def enhance_face(self, clip, face):
+        global device
+
+        lm106 = face.landmark_2d_106
+        lq_landmarks = asarray(self.landmarks106_to_68(lm106))
+        lq = read_img_tensor(clip, False)
+
+        LQLocs = get_component_location(lq_landmarks)
+        # generic
+        SpMem256Para, SpMem128Para, SpMem64Para = None, None, None
+
+        with torch.no_grad():
+            with THREAD_LOCK_DMDNET:
+                try:
+                    GenericResult, SpecificResult = oDMDNet(lq = lq.to(device), loc = LQLocs.unsqueeze(0), sp_256 = SpMem256Para, sp_128 = SpMem128Para, sp_64 = SpMem64Para)
+                except Exception as e:
+                    print(f'Error {e} there may be something wrong with the detected component locations.')
+                    return clip
+            save_generic = GenericResult * 0.5 + 0.5
+            save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
+            save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
+
+            check_lq = lq * 0.5 + 0.5
+            check_lq = check_lq.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
+            check_lq = np.clip(check_lq.float().cpu().numpy(), 0, 1) * 255.0
+            enhanced_img = np.hstack((check_lq, save_generic))
+            temp_frame = save_generic.astype("uint8")
+            # temp_frame = save_generic.astype("uint8")
+        return temp_frame
+
+
+def create(devicename):
+    global device, oDMDNet
+
+    test = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device(devicename)
+    oDMDNet = DMDNet().to(device)
+    weights = torch.load('./models/DMDNet.pth')
+    oDMDNet.load_state_dict(weights, strict=True)
+
+    oDMDNet.eval()
+    num_params = 0
+    for param in oDMDNet.parameters():
+        num_params += param.numel()
+
+    # print('{:>8s} : {}'.format('Using device', device))
+    # print('{:>8s} : {:.2f}M'.format('Model params', num_params/1e6))
+
+
+
+def read_img_tensor(Img=None, return_landmark=True): #rgb -1~1
+    # Img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # BGR or G
+    if Img.ndim == 2:
+        Img = cv2.cvtColor(Img, cv2.COLOR_GRAY2RGB) # GGG
+    else:
+        Img = cv2.cvtColor(Img, cv2.COLOR_BGR2RGB) # RGB
+
+    if Img.shape[0] < 512 or Img.shape[1] < 512:
+        Img = cv2.resize(Img, (512,512), interpolation = cv2.INTER_AREA)
+    # ImgForLands = Img.copy()
+
+    Img = Img.transpose((2, 0, 1))/255.0
+    Img = torch.from_numpy(Img).float()
+    normalize(Img, [0.5,0.5,0.5], [0.5,0.5,0.5], inplace=True)
+    ImgTensor = Img.unsqueeze(0)
+    return ImgTensor
+
+
+def get_component_location(Landmarks, re_read=False):
+    if re_read:
+        ReadLandmark = []
+        with open(Landmarks,'r') as f:
+            for line in f:
+                tmp = [float(i) for i in line.split(' ') if i != '\n']
+                ReadLandmark.append(tmp)
+        ReadLandmark = np.array(ReadLandmark) #
+        Landmarks = np.reshape(ReadLandmark, [-1, 2]) # 68*2
+    Map_LE_B = list(np.hstack((range(17,22), range(36,42))))
+    Map_RE_B = list(np.hstack((range(22,27), range(42,48))))
+    Map_LE = list(range(36,42))
+    Map_RE = list(range(42,48))
+    Map_NO = list(range(29,36))
+    Map_MO = list(range(48,68))
+
+    Landmarks[Landmarks>504]=504
+    Landmarks[Landmarks<8]=8
+
+    #left eye
+    Mean_LE = np.mean(Landmarks[Map_LE],0)
+    L_LE1 = Mean_LE[1] - np.min(Landmarks[Map_LE_B,1])
+    L_LE1 = L_LE1 * 1.3
+    L_LE2 = L_LE1 / 1.9
+    L_LE_xy = L_LE1 + L_LE2
+    L_LE_lt = [L_LE_xy/2, L_LE1]
+    L_LE_rb = [L_LE_xy/2, L_LE2]
+    Location_LE = np.hstack((Mean_LE - L_LE_lt + 1, Mean_LE + L_LE_rb)).astype(int)
+
+    #right eye
+    Mean_RE = np.mean(Landmarks[Map_RE],0)
+    L_RE1 = Mean_RE[1] - np.min(Landmarks[Map_RE_B,1])
+    L_RE1 = L_RE1 * 1.3
+    L_RE2 = L_RE1 / 1.9
+    L_RE_xy = L_RE1 + L_RE2
+    L_RE_lt = [L_RE_xy/2, L_RE1]
+    L_RE_rb = [L_RE_xy/2, L_RE2]
+    Location_RE = np.hstack((Mean_RE - L_RE_lt + 1, Mean_RE + L_RE_rb)).astype(int)
+
+    #nose
+    Mean_NO = np.mean(Landmarks[Map_NO],0)
+    L_NO1 =( np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]])) * 1.25
+    L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1
+    L_NO_xy = L_NO1 * 2
+    L_NO_lt = [L_NO_xy/2, L_NO_xy - L_NO2]
+    L_NO_rb = [L_NO_xy/2, L_NO2]
+    Location_NO = np.hstack((Mean_NO - L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int)
+
+    #mouth
+    Mean_MO = np.mean(Landmarks[Map_MO],0)
+    L_MO = np.max((np.max(np.max(Landmarks[Map_MO],0) - np.min(Landmarks[Map_MO],0))/2,16)) * 1.1
+    MO_O = Mean_MO - L_MO + 1
+    MO_T = Mean_MO + L_MO
+    MO_T[MO_T>510]=510
+    Location_MO = np.hstack((MO_O, MO_T)).astype(int)
+    return torch.cat([torch.FloatTensor(Location_LE).unsqueeze(0), torch.FloatTensor(Location_RE).unsqueeze(0), torch.FloatTensor(Location_NO).unsqueeze(0), torch.FloatTensor(Location_MO).unsqueeze(0)], dim=0)
+
+
+
+
+def calc_mean_std_4D(feat, eps=1e-5):
+    # eps is a small value added to the variance to avoid divide-by-zero.
+    size = feat.size()
+    assert (len(size) == 4)
+    N, C = size[:2]
+    feat_var = feat.view(N, C, -1).var(dim=2) + eps
+    feat_std = feat_var.sqrt().view(N, C, 1, 1)
+    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+    return feat_mean, feat_std
+
+def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degradate feature
+    size = content_feat.size()
+    style_mean, style_std = calc_mean_std_4D(style_feat)
+    content_mean, content_std = calc_mean_std_4D(content_feat)
+    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True):
+    return nn.Sequential(
+        SpectralNorm(conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
+        nn.LeakyReLU(0.2),
+        SpectralNorm(conv_layer(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
+    )
+
+
+class MSDilateBlock(nn.Module):
+    def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
+        super(MSDilateBlock, self).__init__()
+        self.conv1 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[0], bias=bias)
+        self.conv2 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[1], bias=bias)
+        self.conv3 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[2], bias=bias)
+        self.conv4 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[3], bias=bias)
+        self.convi = SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias))
+    def forward(self, x):
+        conv1 = self.conv1(x)
+        conv2 = self.conv2(x)
+        conv3 = self.conv3(x)
+        conv4 = self.conv4(x)
+        cat = torch.cat([conv1, conv2, conv3, conv4], 1)
+        out = self.convi(cat) + x
+        return out
+
+
+class AdaptiveInstanceNorm(nn.Module):
+    def __init__(self, in_channel):
+        super().__init__()
+        self.norm = nn.InstanceNorm2d(in_channel)
+
+    def forward(self, input, style):
+        style_mean, style_std = calc_mean_std_4D(style)
+        out = self.norm(input)
+        size = input.size()
+        out = style_std.expand(size) * out + style_mean.expand(size)
+        return out
+
+class NoiseInjection(nn.Module):
+    def __init__(self, channel):
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
+    def forward(self, image, noise):
+        if noise is None:
+            b, c, h, w = image.shape
+            noise = image.new_empty(b, 1, h, w).normal_()
+        return image + self.weight * noise
+
+class StyledUpBlock(nn.Module):
+    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False, noise_inject=False):
+        super().__init__()
+
+        self.noise_inject = noise_inject
+        if upsample:
+            self.conv1 = nn.Sequential(
+                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+                nn.LeakyReLU(0.2),
+            )
+        else:
+            self.conv1 = nn.Sequential(
+                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+                nn.LeakyReLU(0.2),
+                SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+            )
+        self.convup = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+            SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+        )
+        if self.noise_inject:
+            self.noise1 = NoiseInjection(out_channel)
+
+        self.lrelu1 = nn.LeakyReLU(0.2)
+
+        self.ScaleModel1 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))
+        )
+        self.ShiftModel1 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
+        )
+
+    def forward(self, input, style):
+        out = self.conv1(input)
+        out = self.lrelu1(out)
+        Shift1 = self.ShiftModel1(style)
+        Scale1 = self.ScaleModel1(style)
+        out = out * Scale1 + Shift1
+        if self.noise_inject:
+            out = self.noise1(out, noise=None)
+        outup = self.convup(out)
+        return outup
+
+
+####################################################################
+###############Face Dictionary Generator
+####################################################################
+def AttentionBlock(in_channel):
+    return nn.Sequential(
+        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
+        nn.LeakyReLU(0.2),
+        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
+    )
+
+class DilateResBlock(nn.Module):
+    def __init__(self, dim, dilation=[5,3] ):
+        super(DilateResBlock, self).__init__()
+        self.Res = nn.Sequential(
+            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[0], dilation[0])),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[1], dilation[1])),
+        )
+    def forward(self, x):
+        out = x + self.Res(x)
+        return out
+
+
+class KeyValue(nn.Module):
+    def __init__(self, indim, keydim, valdim):
+        super(KeyValue, self).__init__()
+        self.Key = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(keydim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+        self.Value = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(valdim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+    def forward(self, x):
+        return self.Key(x), self.Value(x)
+
+class MaskAttention(nn.Module):
+    def __init__(self, indim):
+        super(MaskAttention, self).__init__()
+        self.conv1 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+        self.conv2 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+        self.conv3 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+        self.convCat = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim//3 * 3, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(indim, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+    def forward(self, x, y, z):
+        c1 = self.conv1(x)
+        c2 = self.conv2(y)
+        c3 = self.conv3(z)
+        return self.convCat(torch.cat([c1,c2,c3], dim=1))
+
+class Query(nn.Module):
+    def __init__(self, indim, quedim):
+        super(Query, self).__init__()
+        self.Query = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(quedim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+    def forward(self, x):
+        return self.Query(x)
+
+def roi_align_self(input, location, target_size):
+    return torch.cat([F.interpolate(input[i:i+1,:,location[i,1]:location[i,3],location[i,0]:location[i,2]],(target_size,target_size),mode='bilinear',align_corners=False) for i in range(input.size(0))],0)
+
+class FeatureExtractor(nn.Module):
+    def __init__(self, ngf = 64, key_scale = 4):#
+        super().__init__()
+
+        self.key_scale = 4
+        self.part_sizes = np.array([80,80,50,110]) #
+        self.feature_sizes = np.array([256,128,64]) #
+
+        self.conv1 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(3, ngf, 3, 2, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
+        )
+        self.conv2 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1))
+        )
+        self.res1 = DilateResBlock(ngf, [5,3])
+        self.res2 = DilateResBlock(ngf, [5,3])
+
+
+        self.conv3 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf, ngf*2, 3, 2, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
+        )
+        self.conv4 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1))
+        )
+        self.res3 = DilateResBlock(ngf*2, [3,1])
+        self.res4 = DilateResBlock(ngf*2, [3,1])
+
+        self.conv5 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf*2, ngf*4, 3, 2, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
+        )
+        self.conv6 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1))
+        )
+        self.res5 = DilateResBlock(ngf*4, [1,1])
+        self.res6 = DilateResBlock(ngf*4, [1,1])
+
+        self.LE_256_Q = Query(ngf, ngf // self.key_scale)
+        self.RE_256_Q = Query(ngf, ngf // self.key_scale)
+        self.MO_256_Q = Query(ngf, ngf // self.key_scale)
+        self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
+        self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
+        self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
+        self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
+        self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
+        self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
+
+
+    def forward(self, img, locs):
+        le_location = locs[:,0,:].int().cpu().numpy()
+        re_location = locs[:,1,:].int().cpu().numpy()
+        no_location = locs[:,2,:].int().cpu().numpy()
+        mo_location = locs[:,3,:].int().cpu().numpy()
+
+
+        f1_0 = self.conv1(img)
+        f1_1 = self.res1(f1_0)
+        f2_0 = self.conv2(f1_1)
+        f2_1 = self.res2(f2_0)
+
+        f3_0 = self.conv3(f2_1)
+        f3_1 = self.res3(f3_0)
+        f4_0 = self.conv4(f3_1)
+        f4_1 = self.res4(f4_0)
+
+        f5_0 = self.conv5(f4_1)
+        f5_1 = self.res5(f5_0)
+        f6_0 = self.conv6(f5_1)
+        f6_1 = self.res6(f6_0)
+
+
+        ####ROI Align
+        le_part_256 = roi_align_self(f2_1.clone(), le_location//2, self.part_sizes[0]//2)
+        re_part_256 = roi_align_self(f2_1.clone(), re_location//2, self.part_sizes[1]//2)
+        mo_part_256 = roi_align_self(f2_1.clone(), mo_location//2, self.part_sizes[3]//2)
+
+        le_part_128 = roi_align_self(f4_1.clone(), le_location//4, self.part_sizes[0]//4)
+        re_part_128 = roi_align_self(f4_1.clone(), re_location//4, self.part_sizes[1]//4)
+        mo_part_128 = roi_align_self(f4_1.clone(), mo_location//4, self.part_sizes[3]//4)
+
+        le_part_64 = roi_align_self(f6_1.clone(), le_location//8, self.part_sizes[0]//8)
+        re_part_64 = roi_align_self(f6_1.clone(), re_location//8, self.part_sizes[1]//8)
+        mo_part_64 = roi_align_self(f6_1.clone(), mo_location//8, self.part_sizes[3]//8)
+
+
+        le_256_q = self.LE_256_Q(le_part_256)
+        re_256_q = self.RE_256_Q(re_part_256)
+        mo_256_q = self.MO_256_Q(mo_part_256)
+
+        le_128_q = self.LE_128_Q(le_part_128)
+        re_128_q = self.RE_128_Q(re_part_128)
+        mo_128_q = self.MO_128_Q(mo_part_128)
+
+        le_64_q = self.LE_64_Q(le_part_64)
+        re_64_q = self.RE_64_Q(re_part_64)
+        mo_64_q = self.MO_64_Q(mo_part_64)
+
+        return {'f256': f2_1, 'f128': f4_1, 'f64': f6_1,\
+            'le256': le_part_256, 're256': re_part_256, 'mo256': mo_part_256, \
+            'le128': le_part_128, 're128': re_part_128, 'mo128': mo_part_128, \
+            'le64': le_part_64, 're64': re_part_64, 'mo64': mo_part_64, \
+            'le_256_q': le_256_q, 're_256_q': re_256_q, 'mo_256_q': mo_256_q,\
+            'le_128_q': le_128_q, 're_128_q': re_128_q, 'mo_128_q': mo_128_q,\
+            'le_64_q': le_64_q, 're_64_q': re_64_q, 'mo_64_q': mo_64_q}
+
+
+class DMDNet(nn.Module):
+    def __init__(self, ngf = 64, banks_num = 128):
+        super().__init__()
+        self.part_sizes = np.array([80,80,50,110]) # size for 512
+        self.feature_sizes = np.array([256,128,64]) # size for 512
+
+        self.banks_num = banks_num
+        self.key_scale = 4
+
+        self.E_lq = FeatureExtractor(key_scale = self.key_scale)
+        self.E_hq = FeatureExtractor(key_scale = self.key_scale)
+
+        self.LE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
+        self.RE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
+        self.MO_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
+
+        self.LE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
+        self.RE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
+        self.MO_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
+
+        self.LE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
+        self.RE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
+        self.MO_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
+
+
+        self.LE_256_Attention = AttentionBlock(64)
+        self.RE_256_Attention = AttentionBlock(64)
+        self.MO_256_Attention = AttentionBlock(64)
+
+        self.LE_128_Attention = AttentionBlock(128)
+        self.RE_128_Attention = AttentionBlock(128)
+        self.MO_128_Attention = AttentionBlock(128)
+
+        self.LE_64_Attention = AttentionBlock(256)
+        self.RE_64_Attention = AttentionBlock(256)
+        self.MO_64_Attention = AttentionBlock(256)
+
+        self.LE_256_Mask = MaskAttention(64)
+        self.RE_256_Mask = MaskAttention(64)
+        self.MO_256_Mask = MaskAttention(64)
+
+        self.LE_128_Mask = MaskAttention(128)
+        self.RE_128_Mask = MaskAttention(128)
+        self.MO_128_Mask = MaskAttention(128)
+
+        self.LE_64_Mask = MaskAttention(256)
+        self.RE_64_Mask = MaskAttention(256)
+        self.MO_64_Mask = MaskAttention(256)
+
+        self.MSDilate = MSDilateBlock(ngf*4, dilation = [4,3,2,1])
+
+        self.up1 = StyledUpBlock(ngf*4, ngf*2, noise_inject=False) #
+        self.up2 = StyledUpBlock(ngf*2, ngf, noise_inject=False) #
+        self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False) #
+        self.up4 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            UpResBlock(ngf),
+            UpResBlock(ngf),
+            SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)),
+            nn.Tanh()
+        )
+
+        # define generic memory, revise register_buffer to register_parameter for backward update
+        self.register_buffer('le_256_mem_key', torch.randn(128,16,40,40))
+        self.register_buffer('re_256_mem_key', torch.randn(128,16,40,40))
+        self.register_buffer('mo_256_mem_key', torch.randn(128,16,55,55))
+        self.register_buffer('le_256_mem_value', torch.randn(128,64,40,40))
+        self.register_buffer('re_256_mem_value', torch.randn(128,64,40,40))
+        self.register_buffer('mo_256_mem_value', torch.randn(128,64,55,55))
+
+
+        self.register_buffer('le_128_mem_key', torch.randn(128,32,20,20))
+        self.register_buffer('re_128_mem_key', torch.randn(128,32,20,20))
+        self.register_buffer('mo_128_mem_key', torch.randn(128,32,27,27))
+        self.register_buffer('le_128_mem_value', torch.randn(128,128,20,20))
+        self.register_buffer('re_128_mem_value', torch.randn(128,128,20,20))
+        self.register_buffer('mo_128_mem_value', torch.randn(128,128,27,27))
+
+        self.register_buffer('le_64_mem_key', torch.randn(128,64,10,10))
+        self.register_buffer('re_64_mem_key', torch.randn(128,64,10,10))
+        self.register_buffer('mo_64_mem_key', torch.randn(128,64,13,13))
+        self.register_buffer('le_64_mem_value', torch.randn(128,256,10,10))
+        self.register_buffer('re_64_mem_value', torch.randn(128,256,10,10))
+        self.register_buffer('mo_64_mem_value', torch.randn(128,256,13,13))
+
+
+    def readMem(self, k, v, q):
+        sim = F.conv2d(q, k)
+        score = F.softmax(sim/sqrt(sim.size(1)), dim=1) #B * S * 1 * 1 6*128
+        sb,sn,sw,sh = score.size()
+        s_m = score.view(sb, -1).unsqueeze(1)#2*1*M
+        vb,vn,vw,vh = v.size()
+        v_in = v.view(vb, -1).repeat(sb,1,1)#2*M*(c*w*h)
+        mem_out = torch.bmm(s_m, v_in).squeeze(1).view(sb, vn, vw,vh)
+        max_inds = torch.argmax(score, dim=1).squeeze()
+        return mem_out, max_inds
+
+
+    def memorize(self, img, locs):
+        fs = self.E_hq(img, locs)
+        LE256_key, LE256_value = self.LE_256_KV(fs['le256'])
+        RE256_key, RE256_value = self.RE_256_KV(fs['re256'])
+        MO256_key, MO256_value = self.MO_256_KV(fs['mo256'])
+
+        LE128_key, LE128_value = self.LE_128_KV(fs['le128'])
+        RE128_key, RE128_value = self.RE_128_KV(fs['re128'])
+        MO128_key, MO128_value = self.MO_128_KV(fs['mo128'])
+
+        LE64_key, LE64_value = self.LE_64_KV(fs['le64'])
+        RE64_key, RE64_value = self.RE_64_KV(fs['re64'])
+        MO64_key, MO64_value = self.MO_64_KV(fs['mo64'])
+
+        Mem256 = {'LE256Key': LE256_key, 'LE256Value': LE256_value, 'RE256Key': RE256_key, 'RE256Value': RE256_value,'MO256Key': MO256_key, 'MO256Value': MO256_value}
+        Mem128 = {'LE128Key': LE128_key, 'LE128Value': LE128_value, 'RE128Key': RE128_key, 'RE128Value': RE128_value,'MO128Key': MO128_key, 'MO128Value': MO128_value}
+        Mem64 = {'LE64Key': LE64_key, 'LE64Value': LE64_value, 'RE64Key': RE64_key, 'RE64Value': RE64_value,'MO64Key': MO64_key, 'MO64Value': MO64_value}
+
+        FS256 = {'LE256F':fs['le256'], 'RE256F':fs['re256'], 'MO256F':fs['mo256']}
+        FS128 = {'LE128F':fs['le128'], 'RE128F':fs['re128'], 'MO128F':fs['mo128']}
+        FS64 = {'LE64F':fs['le64'], 'RE64F':fs['re64'], 'MO64F':fs['mo64']}
+
+        return Mem256, Mem128, Mem64
+
+    def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None):
+        le_256_q = fs_in['le_256_q']
+        re_256_q = fs_in['re_256_q']
+        mo_256_q = fs_in['mo_256_q']
+
+        le_128_q = fs_in['le_128_q']
+        re_128_q = fs_in['re_128_q']
+        mo_128_q = fs_in['mo_128_q']
+
+        le_64_q = fs_in['le_64_q']
+        re_64_q = fs_in['re_64_q']
+        mo_64_q = fs_in['mo_64_q']
+
+
+        ####for 256
+        le_256_mem_g, le_256_inds = self.readMem(self.le_256_mem_key, self.le_256_mem_value, le_256_q)
+        re_256_mem_g, re_256_inds = self.readMem(self.re_256_mem_key, self.re_256_mem_value, re_256_q)
+        mo_256_mem_g, mo_256_inds = self.readMem(self.mo_256_mem_key, self.mo_256_mem_value, mo_256_q)
+
+        le_128_mem_g, le_128_inds = self.readMem(self.le_128_mem_key, self.le_128_mem_value, le_128_q)
+        re_128_mem_g, re_128_inds = self.readMem(self.re_128_mem_key, self.re_128_mem_value, re_128_q)
+        mo_128_mem_g, mo_128_inds = self.readMem(self.mo_128_mem_key, self.mo_128_mem_value, mo_128_q)
+
+        le_64_mem_g, le_64_inds = self.readMem(self.le_64_mem_key, self.le_64_mem_value, le_64_q)
+        re_64_mem_g, re_64_inds = self.readMem(self.re_64_mem_key, self.re_64_mem_value, re_64_q)
+        mo_64_mem_g, mo_64_inds = self.readMem(self.mo_64_mem_key, self.mo_64_mem_value, mo_64_q)
+
+        if sp_256 is not None and sp_128 is not None and sp_64 is not None:
+            le_256_mem_s, _ = self.readMem(sp_256['LE256Key'], sp_256['LE256Value'], le_256_q)
+            re_256_mem_s, _ = self.readMem(sp_256['RE256Key'], sp_256['RE256Value'], re_256_q)
+            mo_256_mem_s, _ = self.readMem(sp_256['MO256Key'], sp_256['MO256Value'], mo_256_q)
+            le_256_mask = self.LE_256_Mask(fs_in['le256'],le_256_mem_s,le_256_mem_g)
+            le_256_mem = le_256_mask*le_256_mem_s + (1-le_256_mask)*le_256_mem_g
+            re_256_mask = self.RE_256_Mask(fs_in['re256'],re_256_mem_s,re_256_mem_g)
+            re_256_mem = re_256_mask*re_256_mem_s + (1-re_256_mask)*re_256_mem_g
+            mo_256_mask = self.MO_256_Mask(fs_in['mo256'],mo_256_mem_s,mo_256_mem_g)
+            mo_256_mem = mo_256_mask*mo_256_mem_s + (1-mo_256_mask)*mo_256_mem_g
+
+            le_128_mem_s, _ = self.readMem(sp_128['LE128Key'], sp_128['LE128Value'], le_128_q)
+            re_128_mem_s, _ = self.readMem(sp_128['RE128Key'], sp_128['RE128Value'], re_128_q)
+            mo_128_mem_s, _ = self.readMem(sp_128['MO128Key'], sp_128['MO128Value'], mo_128_q)
|
711 |
+
le_128_mask = self.LE_128_Mask(fs_in['le128'],le_128_mem_s,le_128_mem_g)
|
712 |
+
le_128_mem = le_128_mask*le_128_mem_s + (1-le_128_mask)*le_128_mem_g
|
713 |
+
re_128_mask = self.RE_128_Mask(fs_in['re128'],re_128_mem_s,re_128_mem_g)
|
714 |
+
re_128_mem = re_128_mask*re_128_mem_s + (1-re_128_mask)*re_128_mem_g
|
715 |
+
mo_128_mask = self.MO_128_Mask(fs_in['mo128'],mo_128_mem_s,mo_128_mem_g)
|
716 |
+
mo_128_mem = mo_128_mask*mo_128_mem_s + (1-mo_128_mask)*mo_128_mem_g
|
717 |
+
|
718 |
+
le_64_mem_s, _ = self.readMem(sp_64['LE64Key'], sp_64['LE64Value'], le_64_q)
|
719 |
+
re_64_mem_s, _ = self.readMem(sp_64['RE64Key'], sp_64['RE64Value'], re_64_q)
|
720 |
+
mo_64_mem_s, _ = self.readMem(sp_64['MO64Key'], sp_64['MO64Value'], mo_64_q)
|
721 |
+
le_64_mask = self.LE_64_Mask(fs_in['le64'],le_64_mem_s,le_64_mem_g)
|
722 |
+
le_64_mem = le_64_mask*le_64_mem_s + (1-le_64_mask)*le_64_mem_g
|
723 |
+
re_64_mask = self.RE_64_Mask(fs_in['re64'],re_64_mem_s,re_64_mem_g)
|
724 |
+
re_64_mem = re_64_mask*re_64_mem_s + (1-re_64_mask)*re_64_mem_g
|
725 |
+
mo_64_mask = self.MO_64_Mask(fs_in['mo64'],mo_64_mem_s,mo_64_mem_g)
|
726 |
+
mo_64_mem = mo_64_mask*mo_64_mem_s + (1-mo_64_mask)*mo_64_mem_g
|
727 |
+
else:
|
728 |
+
le_256_mem = le_256_mem_g
|
729 |
+
re_256_mem = re_256_mem_g
|
730 |
+
mo_256_mem = mo_256_mem_g
|
731 |
+
le_128_mem = le_128_mem_g
|
732 |
+
re_128_mem = re_128_mem_g
|
733 |
+
mo_128_mem = mo_128_mem_g
|
734 |
+
le_64_mem = le_64_mem_g
|
735 |
+
re_64_mem = re_64_mem_g
|
736 |
+
mo_64_mem = mo_64_mem_g
|
737 |
+
|
738 |
+
le_256_mem_norm = adaptive_instance_normalization_4D(le_256_mem, fs_in['le256'])
|
739 |
+
re_256_mem_norm = adaptive_instance_normalization_4D(re_256_mem, fs_in['re256'])
|
740 |
+
mo_256_mem_norm = adaptive_instance_normalization_4D(mo_256_mem, fs_in['mo256'])
|
741 |
+
|
742 |
+
####for 128
|
743 |
+
le_128_mem_norm = adaptive_instance_normalization_4D(le_128_mem, fs_in['le128'])
|
744 |
+
re_128_mem_norm = adaptive_instance_normalization_4D(re_128_mem, fs_in['re128'])
|
745 |
+
mo_128_mem_norm = adaptive_instance_normalization_4D(mo_128_mem, fs_in['mo128'])
|
746 |
+
|
747 |
+
####for 64
|
748 |
+
le_64_mem_norm = adaptive_instance_normalization_4D(le_64_mem, fs_in['le64'])
|
749 |
+
re_64_mem_norm = adaptive_instance_normalization_4D(re_64_mem, fs_in['re64'])
|
750 |
+
mo_64_mem_norm = adaptive_instance_normalization_4D(mo_64_mem, fs_in['mo64'])
|
751 |
+
|
752 |
+
|
753 |
+
EnMem256 = {'LE256Norm': le_256_mem_norm, 'RE256Norm': re_256_mem_norm, 'MO256Norm': mo_256_mem_norm}
|
754 |
+
EnMem128 = {'LE128Norm': le_128_mem_norm, 'RE128Norm': re_128_mem_norm, 'MO128Norm': mo_128_mem_norm}
|
755 |
+
EnMem64 = {'LE64Norm': le_64_mem_norm, 'RE64Norm': re_64_mem_norm, 'MO64Norm': mo_64_mem_norm}
|
756 |
+
Ind256 = {'LE': le_256_inds, 'RE': re_256_inds, 'MO': mo_256_inds}
|
757 |
+
Ind128 = {'LE': le_128_inds, 'RE': re_128_inds, 'MO': mo_128_inds}
|
758 |
+
Ind64 = {'LE': le_64_inds, 'RE': re_64_inds, 'MO': mo_64_inds}
|
759 |
+
return EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64
|
760 |
+
|
761 |
+
def reconstruct(self, fs_in, locs, memstar):
|
762 |
+
le_256_mem_norm, re_256_mem_norm, mo_256_mem_norm = memstar[0]['LE256Norm'], memstar[0]['RE256Norm'], memstar[0]['MO256Norm']
|
763 |
+
le_128_mem_norm, re_128_mem_norm, mo_128_mem_norm = memstar[1]['LE128Norm'], memstar[1]['RE128Norm'], memstar[1]['MO128Norm']
|
764 |
+
le_64_mem_norm, re_64_mem_norm, mo_64_mem_norm = memstar[2]['LE64Norm'], memstar[2]['RE64Norm'], memstar[2]['MO64Norm']
|
765 |
+
|
766 |
+
le_256_final = self.LE_256_Attention(le_256_mem_norm - fs_in['le256']) * le_256_mem_norm + fs_in['le256']
|
767 |
+
re_256_final = self.RE_256_Attention(re_256_mem_norm - fs_in['re256']) * re_256_mem_norm + fs_in['re256']
|
768 |
+
mo_256_final = self.MO_256_Attention(mo_256_mem_norm - fs_in['mo256']) * mo_256_mem_norm + fs_in['mo256']
|
769 |
+
|
770 |
+
le_128_final = self.LE_128_Attention(le_128_mem_norm - fs_in['le128']) * le_128_mem_norm + fs_in['le128']
|
771 |
+
re_128_final = self.RE_128_Attention(re_128_mem_norm - fs_in['re128']) * re_128_mem_norm + fs_in['re128']
|
772 |
+
mo_128_final = self.MO_128_Attention(mo_128_mem_norm - fs_in['mo128']) * mo_128_mem_norm + fs_in['mo128']
|
773 |
+
|
774 |
+
le_64_final = self.LE_64_Attention(le_64_mem_norm - fs_in['le64']) * le_64_mem_norm + fs_in['le64']
|
775 |
+
re_64_final = self.RE_64_Attention(re_64_mem_norm - fs_in['re64']) * re_64_mem_norm + fs_in['re64']
|
776 |
+
mo_64_final = self.MO_64_Attention(mo_64_mem_norm - fs_in['mo64']) * mo_64_mem_norm + fs_in['mo64']
|
777 |
+
|
778 |
+
|
779 |
+
le_location = locs[:,0,:]
|
780 |
+
re_location = locs[:,1,:]
|
781 |
+
mo_location = locs[:,3,:]
|
782 |
+
le_location = le_location.cpu().int().numpy()
|
783 |
+
re_location = re_location.cpu().int().numpy()
|
784 |
+
mo_location = mo_location.cpu().int().numpy()
|
785 |
+
|
786 |
+
up_in_256 = fs_in['f256'].clone()# * 0
|
787 |
+
up_in_128 = fs_in['f128'].clone()# * 0
|
788 |
+
up_in_64 = fs_in['f64'].clone()# * 0
|
789 |
+
|
790 |
+
for i in range(fs_in['f256'].size(0)):
|
791 |
+
up_in_256[i:i+1,:,le_location[i,1]//2:le_location[i,3]//2,le_location[i,0]//2:le_location[i,2]//2] = F.interpolate(le_256_final[i:i+1,:,:,:].clone(), (le_location[i,3]//2-le_location[i,1]//2,le_location[i,2]//2-le_location[i,0]//2),mode='bilinear',align_corners=False)
|
792 |
+
up_in_256[i:i+1,:,re_location[i,1]//2:re_location[i,3]//2,re_location[i,0]//2:re_location[i,2]//2] = F.interpolate(re_256_final[i:i+1,:,:,:].clone(), (re_location[i,3]//2-re_location[i,1]//2,re_location[i,2]//2-re_location[i,0]//2),mode='bilinear',align_corners=False)
|
793 |
+
up_in_256[i:i+1,:,mo_location[i,1]//2:mo_location[i,3]//2,mo_location[i,0]//2:mo_location[i,2]//2] = F.interpolate(mo_256_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//2-mo_location[i,1]//2,mo_location[i,2]//2-mo_location[i,0]//2),mode='bilinear',align_corners=False)
|
794 |
+
|
795 |
+
up_in_128[i:i+1,:,le_location[i,1]//4:le_location[i,3]//4,le_location[i,0]//4:le_location[i,2]//4] = F.interpolate(le_128_final[i:i+1,:,:,:].clone(), (le_location[i,3]//4-le_location[i,1]//4,le_location[i,2]//4-le_location[i,0]//4),mode='bilinear',align_corners=False)
|
796 |
+
up_in_128[i:i+1,:,re_location[i,1]//4:re_location[i,3]//4,re_location[i,0]//4:re_location[i,2]//4] = F.interpolate(re_128_final[i:i+1,:,:,:].clone(), (re_location[i,3]//4-re_location[i,1]//4,re_location[i,2]//4-re_location[i,0]//4),mode='bilinear',align_corners=False)
|
797 |
+
up_in_128[i:i+1,:,mo_location[i,1]//4:mo_location[i,3]//4,mo_location[i,0]//4:mo_location[i,2]//4] = F.interpolate(mo_128_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//4-mo_location[i,1]//4,mo_location[i,2]//4-mo_location[i,0]//4),mode='bilinear',align_corners=False)
|
798 |
+
|
799 |
+
up_in_64[i:i+1,:,le_location[i,1]//8:le_location[i,3]//8,le_location[i,0]//8:le_location[i,2]//8] = F.interpolate(le_64_final[i:i+1,:,:,:].clone(), (le_location[i,3]//8-le_location[i,1]//8,le_location[i,2]//8-le_location[i,0]//8),mode='bilinear',align_corners=False)
|
800 |
+
up_in_64[i:i+1,:,re_location[i,1]//8:re_location[i,3]//8,re_location[i,0]//8:re_location[i,2]//8] = F.interpolate(re_64_final[i:i+1,:,:,:].clone(), (re_location[i,3]//8-re_location[i,1]//8,re_location[i,2]//8-re_location[i,0]//8),mode='bilinear',align_corners=False)
|
801 |
+
up_in_64[i:i+1,:,mo_location[i,1]//8:mo_location[i,3]//8,mo_location[i,0]//8:mo_location[i,2]//8] = F.interpolate(mo_64_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//8-mo_location[i,1]//8,mo_location[i,2]//8-mo_location[i,0]//8),mode='bilinear',align_corners=False)
|
802 |
+
|
803 |
+
ms_in_64 = self.MSDilate(fs_in['f64'].clone())
|
804 |
+
fea_up1 = self.up1(ms_in_64, up_in_64)
|
805 |
+
fea_up2 = self.up2(fea_up1, up_in_128) #
|
806 |
+
fea_up3 = self.up3(fea_up2, up_in_256) #
|
807 |
+
output = self.up4(fea_up3) #
|
808 |
+
return output
|
809 |
+
|
810 |
+
def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None):
|
811 |
+
return self.memorize(sp_imgs, sp_locs)
|
812 |
+
|
813 |
+
def forward(self, lq=None, loc=None, sp_256 = None, sp_128 = None, sp_64 = None):
|
814 |
+
fs_in = self.E_lq(lq, loc) # low quality images
|
815 |
+
GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer(fs_in)
|
816 |
+
GeOut = self.reconstruct(fs_in, loc, memstar = [GeMemNorm256, GeMemNorm128, GeMemNorm64])
|
817 |
+
if sp_256 is not None and sp_128 is not None and sp_64 is not None:
|
818 |
+
GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(fs_in, sp_256, sp_128, sp_64)
|
819 |
+
GSOut = self.reconstruct(fs_in, loc, memstar = [GSMemNorm256, GSMemNorm128, GSMemNorm64])
|
820 |
+
else:
|
821 |
+
GSOut = None
|
822 |
+
return GeOut, GSOut
|
823 |
+
|
824 |
+
class UpResBlock(nn.Module):
|
825 |
+
def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d):
|
826 |
+
super(UpResBlock, self).__init__()
|
827 |
+
self.Model = nn.Sequential(
|
828 |
+
SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
|
829 |
+
nn.LeakyReLU(0.2),
|
830 |
+
SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
|
831 |
+
)
|
832 |
+
def forward(self, x):
|
833 |
+
out = x + self.Model(x)
|
834 |
+
return out
|
835 |
+
|
plugins/plugin_faceswap.py
ADDED
@@ -0,0 +1,86 @@
from chain_img_processor import ChainImgProcessor, ChainImgPlugin
from roop.face_helper import get_one_face, get_many_faces, swap_face
import os
from roop.utilities import compute_cosine_distance

modname = os.path.basename(__file__)[:-3] # calculating modname

# start function
def start(core:ChainImgProcessor):
    manifest = { # plugin settings
        "name": "Faceswap", # name
        "version": "1.0", # version

        "default_options": {
            "swap_mode": "selected",
            "max_distance": 0.65, # max distance to detect face similarity
        },
        "img_processor": {
            "faceswap": Faceswap
        }
    }
    return manifest

def start_with_options(core:ChainImgProcessor, manifest:dict):
    pass


class Faceswap(ChainImgPlugin):

    def init_plugin(self):
        pass


    def process(self, frame, params:dict):
        if not "input_face_datas" in params or len(params["input_face_datas"]) < 1:
            params["face_detected"] = False
            return frame

        temp_frame = frame
        params["face_detected"] = True
        params["processed_faces"] = []

        if params["swap_mode"] == "first":
            face = get_one_face(frame)
            if face is None:
                params["face_detected"] = False
                return frame
            params["processed_faces"].append(face)
            frame = swap_face(params["input_face_datas"][0], face, frame)
            return frame

        else:
            faces = get_many_faces(frame)
            if(len(faces) < 1):
                params["face_detected"] = False
                return frame

            dist_threshold = params["face_distance_threshold"]

            if params["swap_mode"] == "all":
                for sf in params["input_face_datas"]:
                    for face in faces:
                        params["processed_faces"].append(face)
                        temp_frame = swap_face(sf, face, temp_frame)
                return temp_frame

            elif params["swap_mode"] == "selected":
                for i,tf in enumerate(params["target_face_datas"]):
                    for face in faces:
                        if compute_cosine_distance(tf.embedding, face.embedding) <= dist_threshold:
                            temp_frame = swap_face(params["input_face_datas"][i], face, temp_frame)
                            params["processed_faces"].append(face)
                            break

            elif params["swap_mode"] == "all_female" or params["swap_mode"] == "all_male":
                gender = 'F' if params["swap_mode"] == "all_female" else 'M'
                face_found = False
                for face in faces:
                    if face.sex == gender:
                        face_found = True
                    if face_found:
                        params["processed_faces"].append(face)
                        temp_frame = swap_face(params["input_face_datas"][0], face, temp_frame)
                        face_found = False

        return temp_frame
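For reference, the "selected" swap mode above only replaces a detected face when its embedding is close enough to one of the chosen target faces. A rough sketch of that distance test; `cosine_distance` below is a hypothetical stand-in for `compute_cosine_distance` from `roop.utilities`, which is assumed to return one minus the cosine similarity of two embeddings:

import numpy as np

def cosine_distance(a, b):
    # hypothetical stand-in for roop.utilities.compute_cosine_distance:
    # 1 - cosine similarity of two face embeddings
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    return 1.0 - float(np.dot(a, b))

def is_same_person(target_embedding, candidate_embedding, max_distance=0.65):
    # a detected face is treated as a match when its embedding is close enough
    return cosine_distance(target_embedding, candidate_embedding) <= max_distance

# toy embeddings: a near-identical candidate should match, unrelated noise should not
target = np.random.rand(512)
print(is_same_person(target, target + 0.01 * np.random.rand(512)))  # True
print(is_same_person(target, np.random.rand(512) - 0.5))            # very likely False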
plugins/plugin_gfpgan.py
ADDED
@@ -0,0 +1,85 @@
from chain_img_processor import ChainImgProcessor, ChainImgPlugin
import os
import gfpgan
import threading
from PIL import Image
from numpy import asarray
import cv2

from roop.utilities import resolve_relative_path, conditional_download
modname = os.path.basename(__file__)[:-3] # calculating modname

model_gfpgan = None
THREAD_LOCK_GFPGAN = threading.Lock()


# start function
def start(core:ChainImgProcessor):
    manifest = { # plugin settings
        "name": "GFPGAN", # name
        "version": "1.4", # version

        "default_options": {},
        "img_processor": {
            "gfpgan": GFPGAN
        }
    }
    return manifest

def start_with_options(core:ChainImgProcessor, manifest:dict):
    pass


class GFPGAN(ChainImgPlugin):

    def init_plugin(self):
        global model_gfpgan

        if model_gfpgan is None:
            model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
            model_gfpgan = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=self.device) # type: ignore[attr-defined]


    def process(self, frame, params:dict):
        import copy

        global model_gfpgan

        if model_gfpgan is None:
            return frame

        if "face_detected" in params:
            if not params["face_detected"]:
                return frame
        # don't touch original
        temp_frame = copy.copy(frame)
        if "processed_faces" in params:
            for face in params["processed_faces"]:
                start_x, start_y, end_x, end_y = map(int, face['bbox'])
                padding_x = int((end_x - start_x) * 0.5)
                padding_y = int((end_y - start_y) * 0.5)
                start_x = max(0, start_x - padding_x)
                start_y = max(0, start_y - padding_y)
                end_x = max(0, end_x + padding_x)
                end_y = max(0, end_y + padding_y)
                temp_face = temp_frame[start_y:end_y, start_x:end_x]
                if temp_face.size:
                    with THREAD_LOCK_GFPGAN:
                        _, _, temp_face = model_gfpgan.enhance(
                            temp_face,
                            paste_back=True
                        )
                    temp_frame[start_y:end_y, start_x:end_x] = temp_face
        else:
            with THREAD_LOCK_GFPGAN:
                _, _, temp_frame = model_gfpgan.enhance(
                    temp_frame,
                    paste_back=True
                )

        if not "blend_ratio" in params:
            return temp_frame

        temp_frame = Image.blend(Image.fromarray(frame), Image.fromarray(temp_frame), params["blend_ratio"])
        return asarray(temp_frame)
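For reference, when a `blend_ratio` is present in the params, the plugin above mixes the restored frame back into the untouched original with `PIL.Image.blend`, so the enhancement strength can be dialled down. A small illustrative sketch of that final step (the helper name is mine, not part of the plugin):

import numpy as np
from PIL import Image

def blend_frames(original, enhanced, blend_ratio=0.5):
    # blend_ratio = 0.0 keeps the original frame, 1.0 keeps the enhanced frame
    blended = Image.blend(Image.fromarray(original), Image.fromarray(enhanced), blend_ratio)
    return np.asarray(blended)

# toy usage: blend a black frame a quarter of the way towards a white one
frame = np.zeros((4, 4, 3), dtype=np.uint8)
restored = np.full((4, 4, 3), 255, dtype=np.uint8)
print(blend_frames(frame, restored, 0.25)[0, 0])  # roughly [64 64 64]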
plugins/plugin_txt2clip.py
ADDED
@@ -0,0 +1,122 @@
import os
import cv2
import numpy as np
import torch
import threading
from chain_img_processor import ChainImgProcessor, ChainImgPlugin
from torchvision import transforms
from clip.clipseg import CLIPDensePredT
from numpy import asarray


THREAD_LOCK_CLIP = threading.Lock()

modname = os.path.basename(__file__)[:-3] # calculating modname

model_clip = None


# start function
def start(core:ChainImgProcessor):
    manifest = { # plugin settings
        "name": "Text2Clip", # name
        "version": "1.0", # version

        "default_options": {
        },
        "img_processor": {
            "txt2clip": Text2Clip
        }
    }
    return manifest

def start_with_options(core:ChainImgProcessor, manifest:dict):
    pass


class Text2Clip(ChainImgPlugin):

    def load_clip_model(self):
        global model_clip

        if model_clip is None:
            device = torch.device(super().device)
            model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
            model_clip.eval()
            model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False)
            model_clip.to(device)


    def init_plugin(self):
        self.load_clip_model()

    def process(self, frame, params:dict):
        if "face_detected" in params:
            if not params["face_detected"]:
                return frame

        return self.mask_original(params["original_frame"], frame, params["clip_prompt"])


    def mask_original(self, img1, img2, keywords):
        global model_clip

        source_image_small = cv2.resize(img1, (256,256))

        img_mask = np.full((source_image_small.shape[0],source_image_small.shape[1]), 0, dtype=np.float32)
        mask_border = 1
        l = 0
        t = 0
        r = 1
        b = 1

        mask_blur = 5
        clip_blur = 5

        img_mask = cv2.rectangle(img_mask, (mask_border+int(l), mask_border+int(t)),
                                 (256 - mask_border-int(r), 256-mask_border-int(b)), (255, 255, 255), -1)
        img_mask = cv2.GaussianBlur(img_mask, (mask_blur*2+1,mask_blur*2+1), 0)
        img_mask /= 255


        input_image = source_image_small

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((256, 256)),
        ])
        img = transform(input_image).unsqueeze(0)

        thresh = 0.5
        prompts = keywords.split(',')
        with THREAD_LOCK_CLIP:
            with torch.no_grad():
                preds = model_clip(img.repeat(len(prompts),1,1,1), prompts)[0]
        clip_mask = torch.sigmoid(preds[0][0])
        for i in range(len(prompts)-1):
            clip_mask += torch.sigmoid(preds[i+1][0])

        clip_mask = clip_mask.data.cpu().numpy()
        np.clip(clip_mask, 0, 1)

        clip_mask[clip_mask>thresh] = 1.0
        clip_mask[clip_mask<=thresh] = 0.0
        kernel = np.ones((5, 5), np.float32)
        clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
        clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur*2+1,clip_blur*2+1), 0)

        img_mask *= clip_mask
        img_mask[img_mask<0.0] = 0.0

        img_mask = cv2.resize(img_mask, (img2.shape[1], img2.shape[0]))
        img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])

        target = img2.astype(np.float32)
        result = (1-img_mask) * target
        result += img_mask * img1.astype(np.float32)
        return np.uint8(result)
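For reference, `mask_original` above boils down to two steps: merge the per-prompt CLIPSeg probability maps into one hard mask, then paste the original pixels back over the processed frame wherever that mask is set. A simplified sketch of those two steps (helper names are mine; the real plugin additionally works on a 256x256 downscaled frame with a border mask, as shown above):

import cv2
import numpy as np

def combine_prompt_masks(prompt_masks, thresh=0.5, blur=5):
    # prompt_masks: list of float arrays in [0, 1], one sigmoid map per comma-separated prompt
    mask = np.clip(np.sum(prompt_masks, axis=0), 0, 1)
    mask = np.where(mask > thresh, 1.0, 0.0).astype(np.float32)  # hard threshold
    mask = cv2.dilate(mask, np.ones((5, 5), np.float32), iterations=1)
    return cv2.GaussianBlur(mask, (blur * 2 + 1, blur * 2 + 1), 0)

def paste_original(original, processed, mask):
    # keep the masked region from the original frame, the rest from the processed frame
    mask = cv2.resize(mask, (processed.shape[1], processed.shape[0]))[..., None]
    result = (1 - mask) * processed.astype(np.float32) + mask * original.astype(np.float32)
    return np.uint8(result)

# toy usage: two prompt maps for a 256x256 frame
maps = [np.random.rand(256, 256).astype(np.float32) for _ in range(2)]
frame_a = np.zeros((256, 256, 3), dtype=np.uint8)
frame_b = np.full((256, 256, 3), 255, dtype=np.uint8)
out = paste_original(frame_a, frame_b, combine_prompt_masks(maps))
print(out.shape)  # (256, 256, 3)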
roop-unleashed.ipynb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1abfc4e80fb1e9e8eb3381f1d46193051b683ea452595a189bb5d647dfe7b6b
size 5953