sky24h commited on
Commit
9cbebfb
·
1 Parent(s): 53f2335

modify to support ZeroGPU

Browse files
Files changed (2) hide show
  1. inference_utils.py +4 -3
  2. spiga_draw.py +42 -66
inference_utils.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
 
2
  import torch
3
  import random
 
4
 
5
  seed = 1024
6
  random.seed(seed)
@@ -12,7 +14,6 @@ torch.backends.cudnn.benchmark = False
12
 
13
  from PIL import Image
14
  from gdown import download_folder
15
- from facelib import FaceDetector
16
  from spiga_draw import spiga_process, spiga_segmentation
17
 
18
  from pipeline_sd15 import StableDiffusionControlNetPipeline
@@ -20,11 +21,11 @@ from diffusers import DDIMScheduler, ControlNetModel
20
  from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel
21
  from detail_encoder.encoder_plus import detail_encoder
22
 
23
- detector = FaceDetector(weight_path="./models/mobilenet0.25_Final.pth")
24
 
25
 
26
  def get_draw(pil_img, size):
27
- spigas = spiga_process(pil_img, detector)
 
28
  if spigas == False:
29
  width, height = pil_img.size
30
  black_image_pil = Image.new("RGB", (width, height), color=(0, 0, 0))
 
1
  import os
2
+ import cv2
3
  import torch
4
  import random
5
+ import numpy as np
6
 
7
  seed = 1024
8
  random.seed(seed)
 
14
 
15
  from PIL import Image
16
  from gdown import download_folder
 
17
  from spiga_draw import spiga_process, spiga_segmentation
18
 
19
  from pipeline_sd15 import StableDiffusionControlNetPipeline
 
21
  from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel
22
  from detail_encoder.encoder_plus import detail_encoder
23
 
 
24
 
25
 
26
  def get_draw(pil_img, size):
27
+ cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
28
+ spigas = spiga_process(cv2_img)
29
  if spigas == False:
30
  width, height = pil_img.size
31
  black_image_pil = Image.new("RGB", (width, height), color=(0, 0, 0))
spiga_draw.py CHANGED
@@ -1,22 +1,26 @@
1
  import os
2
  import cv2
3
  import tqdm
 
4
  import numpy as np
5
  from PIL import Image
6
- from facelib import FaceDetector
7
  from spiga.inference.config import ModelConfig
8
  from spiga.inference.framework import SPIGAFramework
9
 
10
- # SPIGA ckpt downloading always fails, so we load it from the local path instead.
11
  spiga_ckpt = os.path.join(os.path.dirname(__file__), "checkpoints/spiga_300wpublic.pt")
12
  if not os.path.exists(spiga_ckpt):
13
  from gdown import download
 
14
  spiga_file_id = "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC"
15
  download(id=spiga_file_id, output=spiga_ckpt)
16
  spiga_config = ModelConfig("300wpublic")
17
  spiga_config.load_model_url = False
18
  spiga_config.model_weights_path = os.path.dirname(spiga_ckpt)
19
  processor = SPIGAFramework(spiga_config)
 
 
20
 
21
  def center_crop(image, size):
22
  width, height = image.size
@@ -27,6 +31,7 @@ def center_crop(image, size):
27
  cropped_image = image.crop((left, top, right, bottom))
28
  return cropped_image
29
 
 
30
  def resize(image, size):
31
  width, height = image.size
32
  if width > height:
@@ -40,33 +45,32 @@ def resize(image, size):
40
  resized_image = image.resize((new_width, new_height))
41
  return resized_image
42
 
 
43
  def preprocess(example, name, path):
44
  image = resize(example, 512)
45
  # 调用中心剪裁函数
46
  cropped_image = center_crop(image, 512)
47
  # 保存剪裁后的图像
48
- cropped_image.save(path+name)
49
  return cropped_image
50
 
 
51
  # We obtain the bbox from the existing landmarks in the dataset.
52
  # We could use `dlib`, but this should be faster.
53
  # Note that the `landmarks` are stored as strings.
54
 
55
- def get_landmarks(image, detector):
56
- image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
57
- faces, boxes, scores, landmarks = detector.detect_align(image) # 一定要用align啊
58
- boxes = boxes.cpu().numpy()
59
- box_ls = []
60
- for box in boxes:
61
- x, y, x1, y1 = box
62
- box = x, y, x1 - x, y1 - y
63
- box_ls.append(box)
64
- if len(box_ls) == 0:
65
- return []
66
  else:
67
- features = processor.inference(image, box_ls)
68
- landmarks = np.array(features['landmarks'])
69
- return landmarks
 
 
 
70
 
71
 
72
  def parse_landmarks(landmarks):
@@ -88,29 +92,28 @@ def bbox_from_landmarks(landmarks_):
88
  height = y_max - y_min
89
 
90
  # Give it a little room; I think it works anyway
91
- x_min -= 5
92
- y_min -= 5
93
- width += 10
94
  height += 10
95
  bbox.append((x_min, y_min, width, height))
96
  return bbox
97
 
98
 
99
- def spiga_process(example, detector):
100
- ldms = get_landmarks(example, detector)
101
 
102
  if len(ldms) == 0:
103
  return False
104
 
105
  else:
106
- image = example
107
  image = np.array(image)
108
  # BGR
109
- image = image[:, :, ::-1]
110
- bbox = bbox_from_landmarks(ldms)
111
- features = processor.inference(image, [*bbox])
112
  landmarks = features["landmarks"]
113
- spigas = landmarks
114
  return spigas
115
 
116
 
@@ -128,7 +131,7 @@ from matplotlib.path import Path
128
  import PIL
129
 
130
 
131
- def get_patch(landmarks, color='lime', closed=False):
132
  contour = landmarks
133
  ops = [Path.MOVETO] + [Path.LINETO] * (len(contour) - 1)
134
  facecolor = (0, 0, 0, 0) # Transparent fill color, if open
@@ -142,10 +145,11 @@ def get_patch(landmarks, color='lime', closed=False):
142
 
143
  # Draw to a buffer.
144
 
 
145
  def conditioning_from_landmarks(landmarks_, size=512):
146
  # Precisely control output image size
147
  dpi = 72
148
- fig, ax = plt.subplots(1, figsize=[size / dpi, size / dpi], tight_layout={'pad': 0})
149
  fig.set_dpi(dpi)
150
 
151
  black = np.zeros((size, size, 3))
@@ -153,14 +157,14 @@ def conditioning_from_landmarks(landmarks_, size=512):
153
 
154
  for landmarks in landmarks_:
155
  face_patch = get_patch(landmarks[0:17])
156
- l_eyebrow = get_patch(landmarks[17:22], color='yellow')
157
- r_eyebrow = get_patch(landmarks[22:27], color='yellow')
158
- nose_v = get_patch(landmarks[27:31], color='orange')
159
- nose_h = get_patch(landmarks[31:36], color='orange')
160
- l_eye = get_patch(landmarks[36:42], color='magenta', closed=True)
161
- r_eye = get_patch(landmarks[42:48], color='magenta', closed=True)
162
- outer_lips = get_patch(landmarks[48:60], color='cyan', closed=True)
163
- inner_lips = get_patch(landmarks[60:68], color='blue', closed=True)
164
 
165
  ax.add_patch(face_patch)
166
  ax.add_patch(l_eyebrow)
@@ -172,7 +176,7 @@ def conditioning_from_landmarks(landmarks_, size=512):
172
  ax.add_patch(outer_lips)
173
  ax.add_patch(inner_lips)
174
 
175
- plt.axis('off')
176
 
177
  fig.canvas.draw()
178
  buffer, (width, height) = fig.canvas.print_to_buffer()
@@ -189,31 +193,3 @@ def spiga_segmentation(spiga, size):
189
  landmarks = spiga
190
  spiga_seg = conditioning_from_landmarks(landmarks, size=size)
191
  return spiga_seg
192
-
193
-
194
- if __name__ == '__main__':
195
- # ## Obtain SPIGA features
196
- processor = SPIGAFramework(ModelConfig("300wpublic"))
197
- detector = FaceDetector(weight_path="/share2/zhangyuxuan/project/train_ip_cn/datasets/make_kps/pretrained_models/mobilenet0.25_Final.pth")
198
-
199
- id_folder = "/share2/zhangyuxuan/project/train_ip_cn/test_img_2/id/"
200
- pose_folder = "/share2/zhangyuxuan/project/train_ip_cn/test_img_2/pose/"
201
-
202
- if not os.path.exists(pose_folder):
203
- os.makedirs(pose_folder)
204
-
205
- pbar = tqdm.tqdm(os.listdir(id_folder))
206
- for name in pbar:
207
- face = Image.open(id_folder+name).convert("RGB").resize((512, 512))
208
- face.save(id_folder+name)
209
- spigas = spiga_process(face, detector)
210
- if spigas == False:
211
- height = 512
212
- width = 512
213
- channels = 3
214
- black_image = np.zeros((height, width, channels), dtype=np.uint8)
215
- black_image_cv2 = cv2.cvtColor(black_image, cv2.COLOR_RGB2BGR)
216
- continue
217
- else:
218
- spigas_faces = spiga_segmentation(spigas)
219
- spigas_faces.save(pose_folder + name)
 
1
  import os
2
  import cv2
3
  import tqdm
4
+ import torch
5
  import numpy as np
6
  from PIL import Image
7
+ from batch_face import RetinaFace
8
  from spiga.inference.config import ModelConfig
9
  from spiga.inference.framework import SPIGAFramework
10
 
11
+ # The SPIGA checkpoint download often fails, so we downloaded it manually and will load it from a local path instead.
12
  spiga_ckpt = os.path.join(os.path.dirname(__file__), "checkpoints/spiga_300wpublic.pt")
13
  if not os.path.exists(spiga_ckpt):
14
  from gdown import download
15
+
16
  spiga_file_id = "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC"
17
  download(id=spiga_file_id, output=spiga_ckpt)
18
  spiga_config = ModelConfig("300wpublic")
19
  spiga_config.load_model_url = False
20
  spiga_config.model_weights_path = os.path.dirname(spiga_ckpt)
21
  processor = SPIGAFramework(spiga_config)
22
+ face_detector = RetinaFace(gpu_id=0) if torch.cuda.is_available() else RetinaFace(gpu_id=-1)
23
+
24
 
25
  def center_crop(image, size):
26
  width, height = image.size
 
31
  cropped_image = image.crop((left, top, right, bottom))
32
  return cropped_image
33
 
34
+
35
  def resize(image, size):
36
  width, height = image.size
37
  if width > height:
 
45
  resized_image = image.resize((new_width, new_height))
46
  return resized_image
47
 
48
+
49
  def preprocess(example, name, path):
50
  image = resize(example, 512)
51
  # 调用中心剪裁函数
52
  cropped_image = center_crop(image, 512)
53
  # 保存剪裁后的图像
54
+ cropped_image.save(path + name)
55
  return cropped_image
56
 
57
+
58
  # We obtain the bbox from the existing landmarks in the dataset.
59
  # We could use `dlib`, but this should be faster.
60
  # Note that the `landmarks` are stored as strings.
61
 
62
+
63
+ def get_landmarks(frame_cv2):
64
+ faces = face_detector(frame_cv2, cv=True)
65
+ if len(faces) == 0:
66
+ raise ValueError("Face is not detected")
 
 
 
 
 
 
67
  else:
68
+ coord = faces[0][0]
69
+ x, y, x1, y1 = coord
70
+ box = x, y, x1 - x, y1 - y
71
+ features = processor.inference(frame_cv2, [box])
72
+ landmarks = np.array(features["landmarks"])
73
+ return landmarks
74
 
75
 
76
  def parse_landmarks(landmarks):
 
92
  height = y_max - y_min
93
 
94
  # Give it a little room; I think it works anyway
95
+ x_min -= 5
96
+ y_min -= 5
97
+ width += 10
98
  height += 10
99
  bbox.append((x_min, y_min, width, height))
100
  return bbox
101
 
102
 
103
+ def spiga_process(image):
104
+ ldms = get_landmarks(image)
105
 
106
  if len(ldms) == 0:
107
  return False
108
 
109
  else:
 
110
  image = np.array(image)
111
  # BGR
112
+ image = image[:, :, ::-1]
113
+ bbox = bbox_from_landmarks(ldms)
114
+ features = processor.inference(image, [*bbox])
115
  landmarks = features["landmarks"]
116
+ spigas = landmarks
117
  return spigas
118
 
119
 
 
131
  import PIL
132
 
133
 
134
+ def get_patch(landmarks, color="lime", closed=False):
135
  contour = landmarks
136
  ops = [Path.MOVETO] + [Path.LINETO] * (len(contour) - 1)
137
  facecolor = (0, 0, 0, 0) # Transparent fill color, if open
 
145
 
146
  # Draw to a buffer.
147
 
148
+
149
  def conditioning_from_landmarks(landmarks_, size=512):
150
  # Precisely control output image size
151
  dpi = 72
152
+ fig, ax = plt.subplots(1, figsize=[size / dpi, size / dpi], tight_layout={"pad": 0})
153
  fig.set_dpi(dpi)
154
 
155
  black = np.zeros((size, size, 3))
 
157
 
158
  for landmarks in landmarks_:
159
  face_patch = get_patch(landmarks[0:17])
160
+ l_eyebrow = get_patch(landmarks[17:22], color="yellow")
161
+ r_eyebrow = get_patch(landmarks[22:27], color="yellow")
162
+ nose_v = get_patch(landmarks[27:31], color="orange")
163
+ nose_h = get_patch(landmarks[31:36], color="orange")
164
+ l_eye = get_patch(landmarks[36:42], color="magenta", closed=True)
165
+ r_eye = get_patch(landmarks[42:48], color="magenta", closed=True)
166
+ outer_lips = get_patch(landmarks[48:60], color="cyan", closed=True)
167
+ inner_lips = get_patch(landmarks[60:68], color="blue", closed=True)
168
 
169
  ax.add_patch(face_patch)
170
  ax.add_patch(l_eyebrow)
 
176
  ax.add_patch(outer_lips)
177
  ax.add_patch(inner_lips)
178
 
179
+ plt.axis("off")
180
 
181
  fig.canvas.draw()
182
  buffer, (width, height) = fig.canvas.print_to_buffer()
 
193
  landmarks = spiga
194
  spiga_seg = conditioning_from_landmarks(landmarks, size=size)
195
  return spiga_seg