Spaces:
Running
on
Zero
Running
on
Zero
modify to support ZeroGPU
Browse files
- inference_utils.py +4 -3
- spiga_draw.py +42 -66
inference_utils.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import os
|
|
|
2 |
import torch
|
3 |
import random
|
|
|
4 |
|
5 |
seed = 1024
|
6 |
random.seed(seed)
|
@@ -12,7 +14,6 @@ torch.backends.cudnn.benchmark = False
|
|
12 |
|
13 |
from PIL import Image
|
14 |
from gdown import download_folder
|
15 |
-
from facelib import FaceDetector
|
16 |
from spiga_draw import spiga_process, spiga_segmentation
|
17 |
|
18 |
from pipeline_sd15 import StableDiffusionControlNetPipeline
|
@@ -20,11 +21,11 @@ from diffusers import DDIMScheduler, ControlNetModel
|
|
20 |
from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel
|
21 |
from detail_encoder.encoder_plus import detail_encoder
|
22 |
|
23 |
-
detector = FaceDetector(weight_path="./models/mobilenet0.25_Final.pth")
|
24 |
|
25 |
|
26 |
def get_draw(pil_img, size):
|
27 |
-
|
|
|
28 |
if spigas == False:
|
29 |
width, height = pil_img.size
|
30 |
black_image_pil = Image.new("RGB", (width, height), color=(0, 0, 0))
|
|
|
1 |
import os
|
2 |
+
import cv2
|
3 |
import torch
|
4 |
import random
|
5 |
+
import numpy as np
|
6 |
|
7 |
seed = 1024
|
8 |
random.seed(seed)
|
|
|
14 |
|
15 |
from PIL import Image
|
16 |
from gdown import download_folder
|
|
|
17 |
from spiga_draw import spiga_process, spiga_segmentation
|
18 |
|
19 |
from pipeline_sd15 import StableDiffusionControlNetPipeline
|
|
|
21 |
from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel
|
22 |
from detail_encoder.encoder_plus import detail_encoder
|
23 |
|
|
|
24 |
|
25 |
|
26 |
def get_draw(pil_img, size):
|
27 |
+
cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
28 |
+
spigas = spiga_process(cv2_img)
|
29 |
if spigas == False:
|
30 |
width, height = pil_img.size
|
31 |
black_image_pil = Image.new("RGB", (width, height), color=(0, 0, 0))
|
spiga_draw.py
CHANGED
@@ -1,22 +1,26 @@
|
|
1 |
import os
|
2 |
import cv2
|
3 |
import tqdm
|
|
|
4 |
import numpy as np
|
5 |
from PIL import Image
|
6 |
-
from
|
7 |
from spiga.inference.config import ModelConfig
|
8 |
from spiga.inference.framework import SPIGAFramework
|
9 |
|
10 |
-
# SPIGA
|
11 |
spiga_ckpt = os.path.join(os.path.dirname(__file__), "checkpoints/spiga_300wpublic.pt")
|
12 |
if not os.path.exists(spiga_ckpt):
|
13 |
from gdown import download
|
|
|
14 |
spiga_file_id = "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC"
|
15 |
download(id=spiga_file_id, output=spiga_ckpt)
|
16 |
spiga_config = ModelConfig("300wpublic")
|
17 |
spiga_config.load_model_url = False
|
18 |
spiga_config.model_weights_path = os.path.dirname(spiga_ckpt)
|
19 |
processor = SPIGAFramework(spiga_config)
|
|
|
|
|
20 |
|
21 |
def center_crop(image, size):
|
22 |
width, height = image.size
|
@@ -27,6 +31,7 @@ def center_crop(image, size):
|
|
27 |
cropped_image = image.crop((left, top, right, bottom))
|
28 |
return cropped_image
|
29 |
|
|
|
30 |
def resize(image, size):
|
31 |
width, height = image.size
|
32 |
if width > height:
|
@@ -40,33 +45,32 @@ def resize(image, size):
|
|
40 |
resized_image = image.resize((new_width, new_height))
|
41 |
return resized_image
|
42 |
|
|
|
43 |
def preprocess(example, name, path):
|
44 |
image = resize(example, 512)
|
45 |
# Call the center-crop function
|
46 |
cropped_image = center_crop(image, 512)
|
47 |
# Save the cropped image
|
48 |
-
cropped_image.save(path+name)
|
49 |
return cropped_image
|
50 |
|
|
|
51 |
# We obtain the bbox from the existing landmarks in the dataset.
|
52 |
# We could use `dlib`, but this should be faster.
|
53 |
# Note that the `landmarks` are stored as strings.
|
54 |
|
55 |
-
|
56 |
-
|
57 |
-
faces
|
58 |
-
|
59 |
-
|
60 |
-
for box in boxes:
|
61 |
-
x, y, x1, y1 = box
|
62 |
-
box = x, y, x1 - x, y1 - y
|
63 |
-
box_ls.append(box)
|
64 |
-
if len(box_ls) == 0:
|
65 |
-
return []
|
66 |
else:
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
70 |
|
71 |
|
72 |
def parse_landmarks(landmarks):
|
@@ -88,29 +92,28 @@ def bbox_from_landmarks(landmarks_):
|
|
88 |
height = y_max - y_min
|
89 |
|
90 |
# Give it a little room; I think it works anyway
|
91 |
-
x_min
|
92 |
-
y_min
|
93 |
-
width
|
94 |
height += 10
|
95 |
bbox.append((x_min, y_min, width, height))
|
96 |
return bbox
|
97 |
|
98 |
|
99 |
-
def spiga_process(
|
100 |
-
ldms = get_landmarks(
|
101 |
|
102 |
if len(ldms) == 0:
|
103 |
return False
|
104 |
|
105 |
else:
|
106 |
-
image = example
|
107 |
image = np.array(image)
|
108 |
# BGR
|
109 |
-
image
|
110 |
-
bbox
|
111 |
-
features
|
112 |
landmarks = features["landmarks"]
|
113 |
-
spigas
|
114 |
return spigas
|
115 |
|
116 |
|
@@ -128,7 +131,7 @@ from matplotlib.path import Path
|
|
128 |
import PIL
|
129 |
|
130 |
|
131 |
-
def get_patch(landmarks, color=
|
132 |
contour = landmarks
|
133 |
ops = [Path.MOVETO] + [Path.LINETO] * (len(contour) - 1)
|
134 |
facecolor = (0, 0, 0, 0) # Transparent fill color, if open
|
@@ -142,10 +145,11 @@ def get_patch(landmarks, color='lime', closed=False):
|
|
142 |
|
143 |
# Draw to a buffer.
|
144 |
|
|
|
145 |
def conditioning_from_landmarks(landmarks_, size=512):
|
146 |
# Precisely control output image size
|
147 |
dpi = 72
|
148 |
-
fig, ax = plt.subplots(1, figsize=[size / dpi, size / dpi], tight_layout={
|
149 |
fig.set_dpi(dpi)
|
150 |
|
151 |
black = np.zeros((size, size, 3))
|
@@ -153,14 +157,14 @@ def conditioning_from_landmarks(landmarks_, size=512):
|
|
153 |
|
154 |
for landmarks in landmarks_:
|
155 |
face_patch = get_patch(landmarks[0:17])
|
156 |
-
l_eyebrow = get_patch(landmarks[17:22], color=
|
157 |
-
r_eyebrow = get_patch(landmarks[22:27], color=
|
158 |
-
nose_v = get_patch(landmarks[27:31], color=
|
159 |
-
nose_h = get_patch(landmarks[31:36], color=
|
160 |
-
l_eye = get_patch(landmarks[36:42], color=
|
161 |
-
r_eye = get_patch(landmarks[42:48], color=
|
162 |
-
outer_lips = get_patch(landmarks[48:60], color=
|
163 |
-
inner_lips = get_patch(landmarks[60:68], color=
|
164 |
|
165 |
ax.add_patch(face_patch)
|
166 |
ax.add_patch(l_eyebrow)
|
@@ -172,7 +176,7 @@ def conditioning_from_landmarks(landmarks_, size=512):
|
|
172 |
ax.add_patch(outer_lips)
|
173 |
ax.add_patch(inner_lips)
|
174 |
|
175 |
-
plt.axis(
|
176 |
|
177 |
fig.canvas.draw()
|
178 |
buffer, (width, height) = fig.canvas.print_to_buffer()
|
@@ -189,31 +193,3 @@ def spiga_segmentation(spiga, size):
|
|
189 |
landmarks = spiga
|
190 |
spiga_seg = conditioning_from_landmarks(landmarks, size=size)
|
191 |
return spiga_seg
|
192 |
-
|
193 |
-
|
194 |
-
if __name__ == '__main__':
|
195 |
-
# ## Obtain SPIGA features
|
196 |
-
processor = SPIGAFramework(ModelConfig("300wpublic"))
|
197 |
-
detector = FaceDetector(weight_path="/share2/zhangyuxuan/project/train_ip_cn/datasets/make_kps/pretrained_models/mobilenet0.25_Final.pth")
|
198 |
-
|
199 |
-
id_folder = "/share2/zhangyuxuan/project/train_ip_cn/test_img_2/id/"
|
200 |
-
pose_folder = "/share2/zhangyuxuan/project/train_ip_cn/test_img_2/pose/"
|
201 |
-
|
202 |
-
if not os.path.exists(pose_folder):
|
203 |
-
os.makedirs(pose_folder)
|
204 |
-
|
205 |
-
pbar = tqdm.tqdm(os.listdir(id_folder))
|
206 |
-
for name in pbar:
|
207 |
-
face = Image.open(id_folder+name).convert("RGB").resize((512, 512))
|
208 |
-
face.save(id_folder+name)
|
209 |
-
spigas = spiga_process(face, detector)
|
210 |
-
if spigas == False:
|
211 |
-
height = 512
|
212 |
-
width = 512
|
213 |
-
channels = 3
|
214 |
-
black_image = np.zeros((height, width, channels), dtype=np.uint8)
|
215 |
-
black_image_cv2 = cv2.cvtColor(black_image, cv2.COLOR_RGB2BGR)
|
216 |
-
continue
|
217 |
-
else:
|
218 |
-
spigas_faces = spiga_segmentation(spigas)
|
219 |
-
spigas_faces.save(pose_folder + name)
|
|
|
1 |
import os
|
2 |
import cv2
|
3 |
import tqdm
|
4 |
+
import torch
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
7 |
+
from batch_face import RetinaFace
|
8 |
from spiga.inference.config import ModelConfig
|
9 |
from spiga.inference.framework import SPIGAFramework
|
10 |
|
11 |
+
# The SPIGA checkpoint download often fails, so we downloaded it manually and will load it from a local path instead.
|
12 |
spiga_ckpt = os.path.join(os.path.dirname(__file__), "checkpoints/spiga_300wpublic.pt")
|
13 |
if not os.path.exists(spiga_ckpt):
|
14 |
from gdown import download
|
15 |
+
|
16 |
spiga_file_id = "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC"
|
17 |
download(id=spiga_file_id, output=spiga_ckpt)
|
18 |
spiga_config = ModelConfig("300wpublic")
|
19 |
spiga_config.load_model_url = False
|
20 |
spiga_config.model_weights_path = os.path.dirname(spiga_ckpt)
|
21 |
processor = SPIGAFramework(spiga_config)
|
22 |
+
face_detector = RetinaFace(gpu_id=0) if torch.cuda.is_available() else RetinaFace(gpu_id=-1)
|
23 |
+
|
24 |
|
25 |
def center_crop(image, size):
|
26 |
width, height = image.size
|
|
|
31 |
cropped_image = image.crop((left, top, right, bottom))
|
32 |
return cropped_image
|
33 |
|
34 |
+
|
35 |
def resize(image, size):
|
36 |
width, height = image.size
|
37 |
if width > height:
|
|
|
45 |
resized_image = image.resize((new_width, new_height))
|
46 |
return resized_image
|
47 |
|
48 |
+
|
49 |
def preprocess(example, name, path):
|
50 |
image = resize(example, 512)
|
51 |
# Call the center-crop function
|
52 |
cropped_image = center_crop(image, 512)
|
53 |
# Save the cropped image
|
54 |
+
cropped_image.save(path + name)
|
55 |
return cropped_image
|
56 |
|
57 |
+
|
58 |
# We obtain the bbox from the existing landmarks in the dataset.
|
59 |
# We could use `dlib`, but this should be faster.
|
60 |
# Note that the `landmarks` are stored as strings.
|
61 |
|
62 |
+
|
63 |
+
def get_landmarks(frame_cv2):
|
64 |
+
faces = face_detector(frame_cv2, cv=True)
|
65 |
+
if len(faces) == 0:
|
66 |
+
raise ValueError("Face is not detected")
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
else:
|
68 |
+
coord = faces[0][0]
|
69 |
+
x, y, x1, y1 = coord
|
70 |
+
box = x, y, x1 - x, y1 - y
|
71 |
+
features = processor.inference(frame_cv2, [box])
|
72 |
+
landmarks = np.array(features["landmarks"])
|
73 |
+
return landmarks
|
74 |
|
75 |
|
76 |
def parse_landmarks(landmarks):
|
|
|
92 |
height = y_max - y_min
|
93 |
|
94 |
# Give it a little room; I think it works anyway
|
95 |
+
x_min -= 5
|
96 |
+
y_min -= 5
|
97 |
+
width += 10
|
98 |
height += 10
|
99 |
bbox.append((x_min, y_min, width, height))
|
100 |
return bbox
|
101 |
|
102 |
|
103 |
+
def spiga_process(image):
|
104 |
+
ldms = get_landmarks(image)
|
105 |
|
106 |
if len(ldms) == 0:
|
107 |
return False
|
108 |
|
109 |
else:
|
|
|
110 |
image = np.array(image)
|
111 |
# BGR
|
112 |
+
image = image[:, :, ::-1]
|
113 |
+
bbox = bbox_from_landmarks(ldms)
|
114 |
+
features = processor.inference(image, [*bbox])
|
115 |
landmarks = features["landmarks"]
|
116 |
+
spigas = landmarks
|
117 |
return spigas
|
118 |
|
119 |
|
|
|
131 |
import PIL
|
132 |
|
133 |
|
134 |
+
def get_patch(landmarks, color="lime", closed=False):
|
135 |
contour = landmarks
|
136 |
ops = [Path.MOVETO] + [Path.LINETO] * (len(contour) - 1)
|
137 |
facecolor = (0, 0, 0, 0) # Transparent fill color, if open
|
|
|
145 |
|
146 |
# Draw to a buffer.
|
147 |
|
148 |
+
|
149 |
def conditioning_from_landmarks(landmarks_, size=512):
|
150 |
# Precisely control output image size
|
151 |
dpi = 72
|
152 |
+
fig, ax = plt.subplots(1, figsize=[size / dpi, size / dpi], tight_layout={"pad": 0})
|
153 |
fig.set_dpi(dpi)
|
154 |
|
155 |
black = np.zeros((size, size, 3))
|
|
|
157 |
|
158 |
for landmarks in landmarks_:
|
159 |
face_patch = get_patch(landmarks[0:17])
|
160 |
+
l_eyebrow = get_patch(landmarks[17:22], color="yellow")
|
161 |
+
r_eyebrow = get_patch(landmarks[22:27], color="yellow")
|
162 |
+
nose_v = get_patch(landmarks[27:31], color="orange")
|
163 |
+
nose_h = get_patch(landmarks[31:36], color="orange")
|
164 |
+
l_eye = get_patch(landmarks[36:42], color="magenta", closed=True)
|
165 |
+
r_eye = get_patch(landmarks[42:48], color="magenta", closed=True)
|
166 |
+
outer_lips = get_patch(landmarks[48:60], color="cyan", closed=True)
|
167 |
+
inner_lips = get_patch(landmarks[60:68], color="blue", closed=True)
|
168 |
|
169 |
ax.add_patch(face_patch)
|
170 |
ax.add_patch(l_eyebrow)
|
|
|
176 |
ax.add_patch(outer_lips)
|
177 |
ax.add_patch(inner_lips)
|
178 |
|
179 |
+
plt.axis("off")
|
180 |
|
181 |
fig.canvas.draw()
|
182 |
buffer, (width, height) = fig.canvas.print_to_buffer()
|
|
|
193 |
landmarks = spiga
|
194 |
spiga_seg = conditioning_from_landmarks(landmarks, size=size)
|
195 |
return spiga_seg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|