Haiyu Wu
commited on
Commit
·
8acfda2
1
Parent(s):
42cf75c
update
Browse files
app.py
CHANGED
@@ -13,6 +13,8 @@ import torch
|
|
13 |
from time import time
|
14 |
|
15 |
|
|
|
|
|
16 |
def clear_image():
|
17 |
return None
|
18 |
|
@@ -44,7 +46,6 @@ def sample_nearby_vectors(base_vector, epsilons=[0.3, 0.5, 0.7], percentages=[0.
|
|
44 |
|
45 |
|
46 |
def initialize_models():
|
47 |
-
device = torch.device('cpu')
|
48 |
pose_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/6DRepNet_300W_LP_AFLW2000.pth", local_dir="./")
|
49 |
id_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/arcface-r100-glint360k.pth", local_dir="./")
|
50 |
quality_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/magface-r100-glint360k.pth", local_dir="./")
|
@@ -87,7 +88,7 @@ def image_generation(input_image, quality, use_target_pose, pose, dimension, pro
|
|
87 |
feature = np.random.normal(0, 1.0, (1, 512))
|
88 |
else:
|
89 |
input_image = np.transpose(input_image, (2, 0, 1))
|
90 |
-
input_image = torch.from_numpy(input_image).unsqueeze(0).float()
|
91 |
input_image.div_(255).sub_(0.5).div_(0.5)
|
92 |
feature = id_model(input_image).clone().detach().cpu().numpy()
|
93 |
|
@@ -99,14 +100,14 @@ def image_generation(input_image, quality, use_target_pose, pose, dimension, pro
|
|
99 |
updated_feature[0][dimension] = feature[0][dimension] + i
|
100 |
updated_feature = updated_feature / np.linalg.norm(updated_feature, 2, 1, True) * norm
|
101 |
features.append(updated_feature)
|
102 |
-
features = torch.tensor(np.vstack(features)).float()
|
103 |
if quality > 25:
|
104 |
images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality)
|
105 |
else:
|
106 |
_, _, images, *_ = generator(features)
|
107 |
else:
|
108 |
features = torch.repeat_interleave(torch.tensor(feature), 3, dim=0)
|
109 |
-
features = sample_nearby_vectors(features, [0.7], [1]).float()
|
110 |
if quality > 25 or pose > 20:
|
111 |
images, _ = generator.gen_image(features, quality_model, id_model, pose_model=pose_model,
|
112 |
q_target=quality, pose=pose, class_rep=features)
|
@@ -154,6 +155,7 @@ def toggle_inputs(use_pose):
|
|
154 |
]
|
155 |
|
156 |
|
|
|
157 |
def main():
|
158 |
with gr.Blocks() as demo:
|
159 |
title = r"""
|
@@ -167,8 +169,7 @@ def main():
|
|
167 |
1. Upload an image with a cropped face image or directly click <b>Submit</b> button, three images will be shown on the right.
|
168 |
2. You can control the image quality, image pose, and modify the values in the target dimensions to change the output images.
|
169 |
3. The output results will shown three results of dimension modification or pose images.
|
170 |
-
4.
|
171 |
-
5. Enjoy! 😊
|
172 |
"""
|
173 |
|
174 |
gr.Markdown(title)
|
|
|
13 |
from time import time
|
14 |
|
15 |
|
16 |
+
device = "cuda"
|
17 |
+
|
18 |
def clear_image():
|
19 |
return None
|
20 |
|
|
|
46 |
|
47 |
|
48 |
def initialize_models():
|
|
|
49 |
pose_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/6DRepNet_300W_LP_AFLW2000.pth", local_dir="./")
|
50 |
id_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/arcface-r100-glint360k.pth", local_dir="./")
|
51 |
quality_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/magface-r100-glint360k.pth", local_dir="./")
|
|
|
88 |
feature = np.random.normal(0, 1.0, (1, 512))
|
89 |
else:
|
90 |
input_image = np.transpose(input_image, (2, 0, 1))
|
91 |
+
input_image = torch.from_numpy(input_image).unsqueeze(0).float().to(device)
|
92 |
input_image.div_(255).sub_(0.5).div_(0.5)
|
93 |
feature = id_model(input_image).clone().detach().cpu().numpy()
|
94 |
|
|
|
100 |
updated_feature[0][dimension] = feature[0][dimension] + i
|
101 |
updated_feature = updated_feature / np.linalg.norm(updated_feature, 2, 1, True) * norm
|
102 |
features.append(updated_feature)
|
103 |
+
features = torch.tensor(np.vstack(features)).float().to(device)
|
104 |
if quality > 25:
|
105 |
images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality)
|
106 |
else:
|
107 |
_, _, images, *_ = generator(features)
|
108 |
else:
|
109 |
features = torch.repeat_interleave(torch.tensor(feature), 3, dim=0)
|
110 |
+
features = sample_nearby_vectors(features, [0.7], [1]).float().to(device)
|
111 |
if quality > 25 or pose > 20:
|
112 |
images, _ = generator.gen_image(features, quality_model, id_model, pose_model=pose_model,
|
113 |
q_target=quality, pose=pose, class_rep=features)
|
|
|
155 |
]
|
156 |
|
157 |
|
158 |
+
# 4. Since the demo is CPU-based, higher quality and larger pose need longer time to run.
|
159 |
def main():
|
160 |
with gr.Blocks() as demo:
|
161 |
title = r"""
|
|
|
169 |
1. Upload an image with a cropped face image or directly click <b>Submit</b> button, three images will be shown on the right.
|
170 |
2. You can control the image quality, image pose, and modify the values in the target dimensions to change the output images.
|
171 |
3. The output results will shown three results of dimension modification or pose images.
|
172 |
+
4. Enjoy! 😊
|
|
|
173 |
"""
|
174 |
|
175 |
gr.Markdown(title)
|