Haiyu Wu commited on
Commit
8acfda2
·
1 Parent(s): 42cf75c
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -13,6 +13,8 @@ import torch
13
  from time import time
14
 
15
 
 
 
16
  def clear_image():
17
  return None
18
 
@@ -44,7 +46,6 @@ def sample_nearby_vectors(base_vector, epsilons=[0.3, 0.5, 0.7], percentages=[0.
44
 
45
 
46
  def initialize_models():
47
- device = torch.device('cpu')
48
  pose_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/6DRepNet_300W_LP_AFLW2000.pth", local_dir="./")
49
  id_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/arcface-r100-glint360k.pth", local_dir="./")
50
  quality_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/magface-r100-glint360k.pth", local_dir="./")
@@ -87,7 +88,7 @@ def image_generation(input_image, quality, use_target_pose, pose, dimension, pro
87
  feature = np.random.normal(0, 1.0, (1, 512))
88
  else:
89
  input_image = np.transpose(input_image, (2, 0, 1))
90
- input_image = torch.from_numpy(input_image).unsqueeze(0).float()
91
  input_image.div_(255).sub_(0.5).div_(0.5)
92
  feature = id_model(input_image).clone().detach().cpu().numpy()
93
 
@@ -99,14 +100,14 @@ def image_generation(input_image, quality, use_target_pose, pose, dimension, pro
99
  updated_feature[0][dimension] = feature[0][dimension] + i
100
  updated_feature = updated_feature / np.linalg.norm(updated_feature, 2, 1, True) * norm
101
  features.append(updated_feature)
102
- features = torch.tensor(np.vstack(features)).float()
103
  if quality > 25:
104
  images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality)
105
  else:
106
  _, _, images, *_ = generator(features)
107
  else:
108
  features = torch.repeat_interleave(torch.tensor(feature), 3, dim=0)
109
- features = sample_nearby_vectors(features, [0.7], [1]).float()
110
  if quality > 25 or pose > 20:
111
  images, _ = generator.gen_image(features, quality_model, id_model, pose_model=pose_model,
112
  q_target=quality, pose=pose, class_rep=features)
@@ -154,6 +155,7 @@ def toggle_inputs(use_pose):
154
  ]
155
 
156
 
 
157
  def main():
158
  with gr.Blocks() as demo:
159
  title = r"""
@@ -167,8 +169,7 @@ def main():
167
  1. Upload an image with a cropped face image or directly click <b>Submit</b> button, three images will be shown on the right.
168
  2. You can control the image quality, image pose, and modify the values in the target dimensions to change the output images.
169
  3. The output results will shown three results of dimension modification or pose images.
170
- 4. Since the demo is CPU-based, higher quality and larger pose need longer time to run.
171
- 5. Enjoy! 😊
172
  """
173
 
174
  gr.Markdown(title)
 
13
  from time import time
14
 
15
 
16
+ device = "cuda"
17
+
18
  def clear_image():
19
  return None
20
 
 
46
 
47
 
48
  def initialize_models():
 
49
  pose_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/6DRepNet_300W_LP_AFLW2000.pth", local_dir="./")
50
  id_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/arcface-r100-glint360k.pth", local_dir="./")
51
  quality_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/magface-r100-glint360k.pth", local_dir="./")
 
88
  feature = np.random.normal(0, 1.0, (1, 512))
89
  else:
90
  input_image = np.transpose(input_image, (2, 0, 1))
91
+ input_image = torch.from_numpy(input_image).unsqueeze(0).float().to(device)
92
  input_image.div_(255).sub_(0.5).div_(0.5)
93
  feature = id_model(input_image).clone().detach().cpu().numpy()
94
 
 
100
  updated_feature[0][dimension] = feature[0][dimension] + i
101
  updated_feature = updated_feature / np.linalg.norm(updated_feature, 2, 1, True) * norm
102
  features.append(updated_feature)
103
+ features = torch.tensor(np.vstack(features)).float().to(device)
104
  if quality > 25:
105
  images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality)
106
  else:
107
  _, _, images, *_ = generator(features)
108
  else:
109
  features = torch.repeat_interleave(torch.tensor(feature), 3, dim=0)
110
+ features = sample_nearby_vectors(features, [0.7], [1]).float().to(device)
111
  if quality > 25 or pose > 20:
112
  images, _ = generator.gen_image(features, quality_model, id_model, pose_model=pose_model,
113
  q_target=quality, pose=pose, class_rep=features)
 
155
  ]
156
 
157
 
158
+ # 4. Since the demo is CPU-based, higher quality and larger pose need longer time to run.
159
  def main():
160
  with gr.Blocks() as demo:
161
  title = r"""
 
169
  1. Upload an image with a cropped face image or directly click <b>Submit</b> button, three images will be shown on the right.
170
  2. You can control the image quality, image pose, and modify the values in the target dimensions to change the output images.
171
  3. The output results will shown three results of dimension modification or pose images.
172
+ 4. Enjoy! 😊
 
173
  """
174
 
175
  gr.Markdown(title)