Haiyu Wu committed
Commit: 918e8a0
Parent(s): ae82d2a

vec2face demo
Browse files
- app.py +247 -0
- configs/vec2face/vqgan.yaml +16 -0
- models/__init__.py +1 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/iresnet.cpython-38.pyc +0 -0
- models/iresnet.py +150 -0
- pixel_generator/vec2face/__pycache__/im_decoder.cpython-38.pyc +0 -0
- pixel_generator/vec2face/__pycache__/model_vec2face.cpython-38.pyc +0 -0
- pixel_generator/vec2face/im_decoder.py +209 -0
- pixel_generator/vec2face/model_vec2face.py +357 -0
- pixel_generator/vec2face/taming/models/__pycache__/vqgan.cpython-37.pyc +0 -0
- pixel_generator/vec2face/taming/models/__pycache__/vqgan.cpython-38.pyc +0 -0
- pixel_generator/vec2face/taming/models/vqgan.py +67 -0
- pixel_generator/vec2face/taming/modules/__pycache__/discriminator_loss.cpython-38.pyc +0 -0
- pixel_generator/vec2face/taming/modules/__pycache__/discriminator_loss.cpython-39.pyc +0 -0
- pixel_generator/vec2face/taming/modules/discriminator/__pycache__/model.cpython-38.pyc +0 -0
- pixel_generator/vec2face/taming/modules/discriminator/model.py +113 -0
- pixel_generator/vec2face/taming/modules/discriminator_loss.py +128 -0
- pixel_generator/vec2face/taming/modules/util.py +130 -0
- requirements.txt +10 -0
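
The functions added in `app.py` below can also be driven without the Gradio UI. A minimal sketch (not part of the commit), assuming the repository root is the working directory so the weight downloads and `sys.path` tweak in `app.py` behave as they do in the demo:

```python
# Hypothetical driver script; it only reuses functions defined in app.py below.
# image_generation() internally calls initialize_models(), which downloads the
# generator, ArcFace, MagFace and SixDRepNet checkpoints from BooBooWu/Vec2Face.
from app import image_generation

# input_image=None -> a random 512-D identity vector is sampled.
# Arguments: image, minimum quality, use_target_pose, target pose, dimensions to perturb.
images = image_generation(None, 24, False, 0, [0, 1, 2, 3])
for i, im in enumerate(images):  # six PIL images
    im.save(f"sample_{i}.png")
```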
app.py
ADDED
@@ -0,0 +1,247 @@
import sys
sys.path.append('./')
import gradio as gr
import random
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from models import iresnet
from sixdrepnet.model import SixDRepNet
import pixel_generator.vec2face.model_vec2face as model_vec2face
MAX_SEED = np.iinfo(np.int32).max
import torch


def sample_nearby_vectors(base_vector, epsilons=[0.3, 0.5, 0.7], percentages=[0.4, 0.4, 0.2]):
    row, col = base_vector.shape
    norm = torch.norm(base_vector, 2, 1, True)
    diff = []
    for i, eps in enumerate(epsilons):
        diff.append(np.random.normal(0, eps, (int(row * percentages[i]), col)))
    diff = np.vstack(diff)
    np.random.shuffle(diff)
    diff = torch.tensor(diff)
    generated_samples = base_vector + diff
    generated_samples = generated_samples / torch.norm(generated_samples, 2, 1, True) * norm
    return generated_samples


def initialize_models():
    device = torch.device('cpu')
    pose_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/6DRepNet_300W_LP_AFLW2000.pth", local_dir="./")
    id_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/arcface-r100-glint360k.pth", local_dir="./")
    quality_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/magface-r100-glint360k.pth", local_dir="./")
    generator_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/vec2face_generator.pth", local_dir="./")
    generator = model_vec2face.__dict__["vec2face_vit_base_patch16"](mask_ratio_mu=0.15, mask_ratio_std=0.25,
                                                                     mask_ratio_min=0.1, mask_ratio_max=0.5,
                                                                     use_rep=True,
                                                                     rep_dim=512,
                                                                     rep_drop_prob=0.,
                                                                     use_class_label=False)
    generator = generator.to(device)
    checkpoint = torch.load(generator_weights, map_location='cpu')
    generator.load_state_dict(checkpoint['model_vec2face'])
    generator.eval()

    id_model = iresnet("100", fp16=True).to(device)
    id_model.load_state_dict(torch.load(id_model_weights, map_location='cpu'))
    id_model.eval()

    quality_model = iresnet("100", fp16=True).to(device)
    quality_model.load_state_dict(torch.load(quality_model_weights, map_location='cpu'))
    quality_model.eval()

    pose_model = SixDRepNet(backbone_name='RepVGG-B1g2',
                            backbone_file='',
                            deploy=True,
                            pretrained=False
                            ).to(device)
    pose_model.load_state_dict(torch.load(pose_model_weights))
    pose_model.eval()

    return generator, id_model, pose_model, quality_model


def image_generation(input_image, quality, use_target_pose, pose, dimension):
    generator, id_model, pose_model, quality_model = initialize_models()

    generated_images = []
    if input_image is None:
        feature = np.random.normal(0, 1.0, (1, 512))
    else:
        input_image = np.transpose(input_image, (2, 0, 1))
        input_image = torch.from_numpy(input_image).unsqueeze(0).float()
        input_image.div_(255).sub_(0.5).div_(0.5)
        feature = id_model(input_image).clone().detach().cpu().numpy()

    if not use_target_pose:
        features = []
        norm = np.linalg.norm(feature, 2, 1, True)
        for i in np.arange(0, 4.8, 0.8):
            updated_feature = feature
            updated_feature[0][dimension] = feature[0][dimension] + i

            updated_feature = updated_feature / np.linalg.norm(updated_feature, 2, 1, True) * norm

            features.append(updated_feature)
        features = torch.tensor(np.vstack(features)).float()
        if quality > 25:
            images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality)
        else:
            _, _, images, *_ = generator(features)
    else:
        features = torch.repeat_interleave(torch.tensor(feature), 6, dim=0)
        features = sample_nearby_vectors(features, [0.7], [1]).float()
        if quality > 25 and pose > 20:
            images, _ = generator.gen_image(features, quality_model, id_model, pose_model=pose_model,
                                            q_target=quality, pose=pose, class_rep=features)
        else:
            _, _, images, *_ = generator(features)

    images = ((images.permute(0, 2, 3, 1).detach().cpu().numpy() + 1) / 2 * 255).astype(np.uint8)
    for image in images:
        generated_images.append(Image.fromarray(image))
    return generated_images


def process_input(image_input, num1, num2, num3, num4, random_seed, target_quality, use_target_pose, target_pose):
    # Ensure all dimension numbers are within [0, 512)
    num1, num2, num3, num4 = [max(0, min(int(n), 511)) for n in [num1, num2, num3, num4]]

    # Use the provided random seed
    random.seed(random_seed)
    np.random.seed(random_seed)
    if image_input is None:
        input_data = None
    else:
        # Process the uploaded image
        input_data = Image.open(image_input)
        input_data = np.array(input_data.resize((112, 112)))

    generated_images = image_generation(input_data, target_quality, use_target_pose, target_pose, [num1, num2, num3, num4])

    return generated_images


def select_image(value, images):
    # Convert the float value (0 to 4) to an integer index (0 to 5)
    index = int(value / 0.8)
    return images[index]


def toggle_inputs(use_pose):
    return [
        gr.update(visible=use_pose, interactive=use_pose),  # target_pose
        gr.update(interactive=not use_pose),  # num1
        gr.update(interactive=not use_pose),  # num2
        gr.update(interactive=not use_pose),  # num3
        gr.update(interactive=not use_pose),  # num4
    ]


def main():
    with gr.Blocks() as demo:
        title = r"""
        <h1 align="center">Vec2Face: Scaling Face Dataset Generation with Loosely Constrained Vectors</h1>
        """

        description = r"""
        <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/HaiyuWu/vec2face' target='_blank'><b>Vec2Face: Scaling Face Dataset Generation with Loosely Constrained Vectors</b></a>.<br>

        How to use:<br>
        1. Upload an image with a cropped face or directly click the <b>Submit</b> button; six images will be shown on the right.
        2. You can control the image quality and pose, and modify the values in the target dimensions to change the output images.
        3. The output shows six results of dimension modification or six pose-edited images.
        4. Since the demo is CPU-based, higher quality and larger poses take longer to run.
        5. Enjoy! 😊
        """

        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                image_file = gr.Image(label="Upload an image (optional)", type="filepath")

                gr.Markdown("""
                ## Dimension Modification
                Enter the values for the dimensions you want to modify (0-511).
                """)

                with gr.Row():
                    num1 = gr.Number(label="Dimension 1", value=0, minimum=0, maximum=511, step=1)
                    num2 = gr.Number(label="Dimension 2", value=0, minimum=0, maximum=511, step=1)
                    num3 = gr.Number(label="Dimension 3", value=0, minimum=0, maximum=511, step=1)
                    num4 = gr.Number(label="Dimension 4", value=0, minimum=0, maximum=511, step=1)

                random_seed = gr.Number(label="Random Seed", value=42, minimum=0, maximum=MAX_SEED, step=1)
                target_quality = gr.Slider(label="Minimum Quality", minimum=22, maximum=35, step=1, value=24)

                with gr.Row():
                    use_target_pose = gr.Checkbox(label="Use Target Pose")
                    target_pose = gr.Slider(label="Target Pose", value=0, minimum=0, maximum=90, step=1, visible=False)

                submit = gr.Button("Submit", variant="primary")

                gr.Markdown("""
                ## Usage tips for Vec2Face
                - Directly clicking the "Submit" button will give you results from a randomly sampled vector.
                - If you want to modify more dimensions, please write your own code. Code snippets in the [Vec2Face repo](https://github.com/HaiyuWu/vec2face) might be helpful.
                - If you want to create extreme-pose images (e.g., >70), please do not set the image quality higher than 27.
                - <span style="color: red;">!</span> <span style="color: red;">!</span> <span style="color: red;">!</span> **Due to the limitations of SixDRepNet (the pose estimator), pose editing results might be corrupted/incorrect. For better performance, you can integrate other pose estimators.** <span style="color: red;">!</span> <span style="color: red;">!</span> <span style="color: red;">!</span>
                - For a better experience, we suggest running the code on a GPU machine.
                """)

            with gr.Column():
                gallery = gr.Image(label="Generated Image")
                incremental_value_slider = gr.Slider(
                    label="Result of dimension modification or results of pose images",
                    minimum=0, maximum=4, step=0.8, value=0
                )
                gr.Markdown("""
                - These values are added to the chosen dimensions (before normalization); **please ignore this slider if pose editing is on**.
                """)

        use_target_pose.change(
            fn=toggle_inputs,
            inputs=[use_target_pose],
            outputs=[target_pose, num1, num2, num3, num4]
        )

        generated_images = gr.State([])

        submit.click(
            fn=process_input,
            inputs=[image_file, num1, num2, num3, num4, random_seed, target_quality, use_target_pose, target_pose],
            outputs=[generated_images]
        ).then(
            fn=select_image,
            inputs=[incremental_value_slider, generated_images],
            outputs=[gallery]
        )

        incremental_value_slider.change(
            fn=select_image,
            inputs=[incremental_value_slider, generated_images],
            outputs=[gallery]
        )
        article = r"""
        ---
        📝 **Citation**
        <br>
        If our work is helpful for your research or applications, please cite us via:
        ```bibtex
        @article{wu2024vec2face,
          title={Vec2Face: Scaling Face Dataset Generation with Loosely Constrained Vectors},
          author={Wu, Haiyu and Singh, Jaskirat and Tian, Sicong and Zheng, Liang and Bowyer, Kevin W.},
          year={2024}
        }
        ```
        📧 **Contact**
        <br>
        If you have any questions, please feel free to open an issue or reach us directly at <b>[email protected]</b>.
        """
        gr.Markdown(article)

    demo.launch(share=True)


if __name__ == "__main__":
    main()
configs/vec2face/vqgan.yaml
ADDED
@@ -0,0 +1,16 @@
model:
  target: pixel_generator.vec2face.taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    ddconfig:
      double_z: False
      z_channels: 256
      resolution: 112
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0
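
This YAML is not used to build a full VQGAN here; `model_vec2face.py` (added below) reads only the `ddconfig` block to configure the pixel decoder. A short sketch of that wiring, mirroring the lines in `MaskedGenerativeEncoderViT.__init__`:

```python
# Sketch of how the config above is consumed (see model_vec2face.py in this commit).
from omegaconf import OmegaConf
from pixel_generator.vec2face.im_decoder import Decoder

vqgan_config = OmegaConf.load('configs/vec2face/vqgan.yaml').model
# ddconfig: 112x112 output, z_channels=256, ch_mult=[1,1,2,2,4] -> the Decoder
# upsamples a 7x7 latent map back to the image resolution.
im_decoder = Decoder(**vqgan_config.params.ddconfig)
```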
models/__init__.py
ADDED
@@ -0,0 +1 @@
from .iresnet import iresnet
models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (140 Bytes).
models/__pycache__/iresnet.cpython-38.pyc
ADDED
Binary file (4.21 kB).
models/iresnet.py
ADDED
@@ -0,0 +1,150 @@
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

using_ckpt = False


def conv3x3(in_planes, out_planes, stride=1, groups=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=1,
                     groups=groups,
                     bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(IBasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.downsample = downsample
        self.stride = stride

    def forward_impl(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out

    def forward(self, x):
        if self.training and using_ckpt:
            return checkpoint(self.forward_impl, x)
        else:
            return self.forward_impl(x)


class IResNet(nn.Module):
    def __init__(self,
                 block, layers, dropout=0.4, num_features=512, zero_init_residual=False,
                 groups=1, fp16=False):
        super(IResNet, self).__init__()
        self.extra_gflops = 0.0
        self.fp16 = fp16
        self.inplanes = 64

        self.groups = groups
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2)
        self.bn2 = nn.BatchNorm2d(512, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * 7 * 7, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes, stride),
                nn.BatchNorm2d(planes, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def iresnet(arch, pretrained=False, **kwargs):
    layer_dict = {"18": [2, 2, 2, 2],
                  "34": [3, 4, 6, 3],
                  "50": [3, 4, 14, 3],
                  "100": [3, 13, 30, 3],
                  "152": [3, 8, 36, 3],
                  "200": [3, 13, 30, 3]}
    model = IResNet(IBasicBlock, layer_dict[arch], **kwargs)
    if pretrained:
        raise ValueError()
    return model
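
For reference, a sketch of how `app.py` uses this network to turn a face crop into the 512-D identity vector that conditions Vec2Face; the weight path matches what the demo downloads, while `face.jpg` is just a placeholder filename:

```python
# Sketch: 112x112 face crop -> 512-D ArcFace feature, using app.py's preprocessing.
import numpy as np
import torch
from PIL import Image
from models import iresnet

id_model = iresnet("100", fp16=True)
id_model.load_state_dict(torch.load("weights/arcface-r100-glint360k.pth", map_location="cpu"))
id_model.eval()

img = np.array(Image.open("face.jpg").resize((112, 112)))        # HWC uint8
x = torch.from_numpy(np.transpose(img, (2, 0, 1))).unsqueeze(0).float()
x.div_(255).sub_(0.5).div_(0.5)                                   # scale to [-1, 1]
with torch.no_grad():
    feature = id_model(x)                                         # shape (1, 512)
```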
pixel_generator/vec2face/__pycache__/im_decoder.cpython-38.pyc
ADDED
Binary file (4.74 kB).
pixel_generator/vec2face/__pycache__/model_vec2face.cpython-38.pyc
ADDED
Binary file (12.3 kB).
pixel_generator/vec2face/im_decoder.py
ADDED
@@ -0,0 +1,209 @@
import torch
import torch.nn as nn


def nonlinearity(x):
    # swish
    return x*torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0,1,0,1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1,
                                     bias=False)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1,
                                     bias=False)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(out_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1,
                                                     bias=False)
            else:
                self.nin_shortcut = torch.nn.Conv2d(out_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0,
                                                    bias=False)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(h)
            else:
                x = self.nin_shortcut(h)

        return x+h


class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
            up = nn.Module()
            up.block = block
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks):
                h = self.up[i_level].block[i_block](h, temb)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
pixel_generator/vec2face/model_vec2face.py
ADDED
@@ -0,0 +1,357 @@
from functools import partial
import torch
import torch.nn as nn
import torch.optim as optim
from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp
from omegaconf import OmegaConf
import numpy as np
import scipy.stats as stats
from pixel_generator.vec2face.im_decoder import Decoder
from sixdrepnet.model import utils


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q.float() @ k.float().transpose(-2, -1)) * self.scale
        attn = attn - torch.max(attn, dim=-1, keepdim=True)[0]
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        with torch.cuda.amp.autocast(enabled=False):
            if return_attention:
                _, attn = self.attn(self.norm1(x))
                return attn
            else:
                y, _ = self.attn(self.norm1(x))
                x = x + self.drop_path(y)
                x = x + self.drop_path(self.mlp(self.norm2(x)))
                return x


class LabelSmoothingCrossEntropy(nn.Module):
    """ NLL loss with label smoothing.
    """

    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, hidden_size, max_position_embeddings, dropout=0.1):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))

        torch.nn.init.normal_(self.position_embeddings.weight, std=.02)

    def forward(
            self, input_ids
    ):
        input_shape = input_ids.size()

        seq_length = input_shape[1]

        position_ids = self.position_ids[:, :seq_length]

        position_embeddings = self.position_embeddings(position_ids)
        embeddings = input_ids + position_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class MaskedGenerativeEncoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """

    def __init__(self, img_size=112, patch_size=7, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 mask_ratio_min=0.5, mask_ratio_max=1.0, mask_ratio_mu=0.55, mask_ratio_std=0.25,
                 use_rep=True, rep_dim=512,
                 rep_drop_prob=0.0,
                 use_class_label=False):
        super().__init__()
        assert not (use_rep and use_class_label)

        # --------------------------------------------------------------------------
        vqgan_config = OmegaConf.load('configs/vec2face/vqgan.yaml').model
        self.token_emb = BertEmbeddings(hidden_size=embed_dim,
                                        max_position_embeddings=49 + 1,
                                        dropout=0.1)
        self.use_rep = use_rep
        self.use_class_label = use_class_label
        if self.use_rep:
            print("Use representation as condition!")
            self.latent_prior_proj_f = nn.Linear(rep_dim, embed_dim, bias=True)
        # CFG config
        self.rep_drop_prob = rep_drop_prob
        self.feature_token = nn.Linear(1, 49, bias=True)
        self.center_token = nn.Linear(embed_dim, 49, bias=True)
        self.im_decoder = Decoder(**vqgan_config.params.ddconfig)
        self.im_decoder_proj = nn.Linear(embed_dim, vqgan_config.params.ddconfig.z_channels)

        # Vec2Face variant masking ratio
        self.mask_ratio_min = mask_ratio_min
        self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - mask_ratio_mu) / mask_ratio_std,
                                                    (mask_ratio_max - mask_ratio_mu) / mask_ratio_std,
                                                    loc=mask_ratio_mu, scale=mask_ratio_std)
        # --------------------------------------------------------------------------
        # Vec2Face encoder specifics
        dropout_rate = 0.1
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # --------------------------------------------------------------------------
        # Vec2Face decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.pad_with_cls_token = True

        self.decoder_pos_embed_learned = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=True)  # learnable pos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        # --------------------------------------------------------------------------
        self.initialize_weights()

    def initialize_weights(self):
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.xavier_uniform_(self.feature_token.weight)
        torch.nn.init.xavier_uniform_(self.center_token.weight)
        torch.nn.init.xavier_uniform_(self.latent_prior_proj_f.weight)
        torch.nn.init.xavier_uniform_(self.decoder_embed.weight)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_encoder(self, rep):
        # expand to feature map
        device = rep.device
        encode_feature = self.latent_prior_proj_f(rep)
        feature_token = self.feature_token(encode_feature.unsqueeze(-1)).permute(0, 2, 1)

        gt_indices = torch.cat((encode_feature.unsqueeze(1), feature_token), dim=1).clone().detach()

        # masked row indices
        bsz, seq_len, _ = feature_token.size()
        mask_ratio_min = self.mask_ratio_min
        mask_rate = self.mask_ratio_generator.rvs(1)[0]

        num_dropped_tokens = int(np.ceil(seq_len * mask_ratio_min))
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))

        # it is possible that two elements of the noise is the same, so do a while loop to avoid it
        while True:
            noise = torch.rand(bsz, seq_len, device=rep.device)  # noise in [0, 1]
            sorted_noise, _ = torch.sort(noise, dim=1)  # ascend: small is remove, large is keep
            cutoff_drop = sorted_noise[:, num_dropped_tokens - 1:num_dropped_tokens]
            cutoff_mask = sorted_noise[:, num_masked_tokens - 1:num_masked_tokens]
            token_drop_mask = (noise <= cutoff_drop).float()
            token_all_mask = (noise <= cutoff_mask).float()
            if token_drop_mask.sum() == bsz * num_dropped_tokens and \
                    token_all_mask.sum() == bsz * num_masked_tokens:
                break
            else:
                print("Rerandom the noise!")
        token_all_mask_bool = token_all_mask.bool()
        encode_feature_expanded = encode_feature.unsqueeze(1).expand(-1, feature_token.shape[1], -1)
        feature_token[token_all_mask_bool] = encode_feature_expanded[token_all_mask_bool]

        # concatenate with image feature
        feature_token = torch.cat([encode_feature.unsqueeze(1), feature_token], dim=1)
        token_drop_mask = torch.cat([torch.zeros(feature_token.size(0), 1).to(device), token_drop_mask], dim=1)
        token_all_mask = torch.cat([torch.zeros(feature_token.size(0), 1).to(device), token_all_mask], dim=1)

        # bert embedding
        input_embeddings = self.token_emb(feature_token)

        bsz, seq_len, emb_dim = input_embeddings.shape

        # dropping
        token_keep_mask = 1 - token_drop_mask
        input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim)

        # apply Transformer blocks
        x = input_embeddings_after_drop
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x, gt_indices, token_drop_mask, token_all_mask

    def forward_decoder(self, x, token_drop_mask, token_all_mask):
        # embed incomplete feature map
        x = self.decoder_embed(x)
        # fill masked positions with image feature
        mask_tokens = x[:, 0:1].repeat(1, token_all_mask.shape[1], 1)
        x_after_pad = mask_tokens.clone()
        x_after_pad[(1 - token_drop_mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
        x_after_pad = torch.where(token_all_mask.unsqueeze(-1).bool(), mask_tokens, x_after_pad)
        # add pos embed
        x = x_after_pad + self.decoder_pos_embed_learned

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)

        logits = self.decoder_norm(x)
        bsz, _, emb_dim = logits.shape
        # an image decoder
        decoder_proj = self.im_decoder_proj(logits[:, 1:, :].reshape(bsz, 7, 7, emb_dim)).permute(0, 3, 1, 2)
        return decoder_proj, logits

    def get_last_layer(self):
        return self.im_decoder.conv_out.weight

    def forward(self, rep):
        last_layer = self.get_last_layer()
        latent, gt_indices, token_drop_mask, token_all_mask = self.forward_encoder(rep)
        decoder_proj, logits = self.forward_decoder(latent, token_drop_mask, token_all_mask)
        image = self.im_decoder(decoder_proj)

        return gt_indices, logits, image, last_layer, token_all_mask

    def gen_image(self, rep, quality_model, fr_model, pose_model=None, age_model=None, class_rep=None,
                  num_iter=1, lr=1e-1, q_target=27, pose=60):
        rep_copy = rep.clone().detach().requires_grad_(True)
        optm = optim.Adam([rep_copy], lr=lr)

        i = 0
        while i < num_iter:
            latent, _, token_drop_mask, token_all_mask = self.forward_encoder(rep_copy)
            decoder_proj, _ = self.forward_decoder(latent, token_drop_mask, token_all_mask)
            image = self.im_decoder(decoder_proj).clip(max=1., min=-1.)
            # feature comparison
            out_feature = fr_model(image)
            if class_rep is None:
                id_loss = torch.mean(1 - torch.cosine_similarity(out_feature, rep))
            else:
                distance = 1 - torch.cosine_similarity(out_feature, class_rep)
                id_loss = torch.mean(torch.where(distance > 0.5, distance, torch.zeros_like(distance)))
            quality = quality_model(image)
            norm = torch.norm(quality, 2, 1, True)
            q_loss = torch.where(norm < q_target, q_target - norm, torch.zeros_like(norm))

            pose_loss = 0
            if pose_model is not None:
                # sixdrepnet
                bgr_img = image[:, [2, 1, 0], :, :]
                pose_info = pose_model(((bgr_img + 1) / 2))
                pose_info = utils.compute_euler_angles_from_rotation_matrices(
                    pose_info) * 180 / np.pi
                yaw_loss = torch.abs(pose - torch.abs(pose_info[:, 1].clip(min=-90, max=90)))
                pose_loss = torch.mean(yaw_loss)
            q_loss = torch.mean(q_loss)
            if pose_loss > 5 or id_loss > 0.4 or q_loss > 1:
                i -= 1
            loss = id_loss * 100 + q_loss + pose_loss
            optm.zero_grad()
            loss.backward(retain_graph=True)
            optm.step()
            i += 1

        latent, _, token_drop_mask, token_all_mask = self.forward_encoder(rep_copy)
        decoder_proj, _ = self.forward_decoder(latent, token_drop_mask, token_all_mask)
        image = self.im_decoder(decoder_proj).clip(max=1., min=-1.)

        return image, rep_copy.detach()


def vec2face_vit_base_patch16(**kwargs):
    model = MaskedGenerativeEncoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=768, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vec2face_vit_large_patch16(**kwargs):
    model = MaskedGenerativeEncoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=1024, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vec2face_vit_huge_patch16(**kwargs):
    model = MaskedGenerativeEncoderViT(
        patch_size=16, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=1280, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
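
A sketch of the generator entry points as `app.py` exercises them: the plain `forward` decodes identity vectors directly, while `gen_image` additionally optimizes the vectors toward a quality (and optionally pose) target. The checkpoint path and key mirror the demo; the random vectors are illustrative only:

```python
# Sketch: decode loosely constrained 512-D vectors into 112x112 face images.
import torch
import pixel_generator.vec2face.model_vec2face as model_vec2face

generator = model_vec2face.vec2face_vit_base_patch16(
    mask_ratio_mu=0.15, mask_ratio_std=0.25, mask_ratio_min=0.1, mask_ratio_max=0.5,
    use_rep=True, rep_dim=512, rep_drop_prob=0., use_class_label=False)
ckpt = torch.load("weights/vec2face_generator.pth", map_location="cpu")
generator.load_state_dict(ckpt["model_vec2face"])
generator.eval()

features = torch.randn(4, 512)           # stand-in identity vectors
_, _, images, *_ = generator(features)   # images in [-1, 1], NCHW
```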
pixel_generator/vec2face/taming/models/__pycache__/vqgan.cpython-37.pyc
ADDED
Binary file (2.45 kB).
pixel_generator/vec2face/taming/models/__pycache__/vqgan.cpython-38.pyc
ADDED
Binary file (2.48 kB).
pixel_generator/vec2face/taming/models/vqgan.py
ADDED
@@ -0,0 +1,67 @@
import torch
import pytorch_lightning as pl

from pixel_generator.mage.taming.modules.diffusionmodules.model import Encoder, Decoder
from pixel_generator.mage.taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer


class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap, sane_index_shape=sane_index_shape)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.image_key = image_key
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in sd.keys():
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        print("Strict load")
        self.load_state_dict(sd, strict=True)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff
pixel_generator/vec2face/taming/modules/__pycache__/discriminator_loss.cpython-38.pyc
ADDED
Binary file (4.47 kB).
pixel_generator/vec2face/taming/modules/__pycache__/discriminator_loss.cpython-39.pyc
ADDED
Binary file (4.46 kB).
pixel_generator/vec2face/taming/modules/discriminator/__pycache__/model.cpython-38.pyc
ADDED
Binary file (4.1 kB).
pixel_generator/vec2face/taming/modules/discriminator/model.py
ADDED
@@ -0,0 +1,113 @@
import torch.nn as nn
import torch

####################################ViT-VQGAN########################################
# https://github.com/lucidrains/parti-pytorch/blob/main/parti_pytorch/vit_vqgan.py#L171
#####################################################################################
def default(val, d):
    return val if exists(val) else d

def exists(val):
    return val is not None

def leaky_relu(p = 0.1):
    return nn.LeakyReLU(0.1)

class CrossEmbedLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # calculate the dimension at each scale
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        return torch.cat(fmaps, dim = 1)

class Block(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim)
        self.activation = leaky_relu()
        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None):
        x = self.groupnorm(x)
        x = self.activation(x)
        return self.project(x)

class ResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        groups = 8
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.block = Block(dim, dim_out, groups = groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x):
        h = self.block(x)
        return h + self.res_conv(x)


class Discriminator(nn.Module):
    def __init__(
        self,
        dims,
        channels = 3,
        groups = 8,
        init_kernel_size = 5,
        cross_embed_kernel_sizes = (3, 7, 15)
    ):
        super().__init__()
        init_dim, *_, final_dim = dims
        dim_pairs = zip(dims[:-1], dims[1:])

        self.layers = nn.ModuleList([nn.Sequential(
            CrossEmbedLayer(channels, cross_embed_kernel_sizes, init_dim, stride = 1),
            leaky_relu()
        )])

        for dim_in, dim_out in dim_pairs:
            self.layers.append(nn.Sequential(
                nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
                leaky_relu(),
                nn.GroupNorm(groups, dim_out),
                ResnetBlock(dim_out, dim_out),
            ))

        self.to_logits = nn.Sequential(  # return 5 x 5, for PatchGAN-esque training
            nn.Conv2d(final_dim, final_dim, 1),
            leaky_relu(),
            nn.Conv2d(final_dim, 1, 4)
        )

    def forward(self, x):
        for net in self.layers:
            x = net(x)
        return self.to_logits(x)
pixel_generator/vec2face/taming/modules/discriminator_loss.py
ADDED
@@ -0,0 +1,128 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from models import iresnet
+from lpips.lpips import LPIPS
+from pytorch_msssim import SSIM
+
+
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+    if global_step < threshold:
+        weight = value
+    return weight
+
+
+def hinge_d_loss(logits_real, logits_fake):
+    loss_real = torch.mean(F.relu(1. - logits_real))
+    loss_fake = torch.mean(F.relu(1. + logits_fake))
+    d_loss = 0.5 * (loss_real + loss_fake)
+    return d_loss
+
+
+def mse_d_loss(logits_real, logits_fake):
+    loss_real = torch.mean((logits_real - 1.) ** 2)
+    loss_fake = torch.mean(logits_fake ** 2)
+    d_loss = 0.5 * (loss_real + loss_fake)
+    return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+    d_loss = 0.5 * (
+        torch.mean(torch.nn.functional.softplus(-logits_real)) +
+        torch.mean(torch.nn.functional.softplus(logits_fake)))
+    return d_loss
+
+
+def create_fr_model(model_path, depth="100"):
+    model = iresnet(depth)
+    model.load_state_dict(torch.load(model_path))
+    # model.half()
+    return model
+
+
+def downscale(img: torch.tensor):
+    half_size = img.shape[-1] // 8
+    img = F.interpolate(img, size=(half_size, half_size), mode='bicubic', align_corners=False)
+    return img
+
+
+class VQLPIPSWithDiscriminator(nn.Module):
+    def __init__(self, disc_start=1000, disc_factor=1.0, disc_weight=1.0,
+                 disc_conditional=False, disc_loss="mse", id_loss="mse",
+                 fr_model="./models/arcface-r100-glint360k.pth"):
+        super().__init__()
+        assert disc_loss in ["hinge", "vanilla", "mse", "smooth"]
+        self.loss_name = disc_loss
+        self.perceptual_loss = LPIPS().eval()
+        self.discriminator_iter_start = disc_start
+        if disc_loss == "hinge":
+            self.disc_loss = hinge_d_loss
+        elif disc_loss == "vanilla":
+            self.disc_loss = vanilla_d_loss
+        elif disc_loss == "mse":
+            self.disc_loss = mse_d_loss
+        else:
+            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+        self.fr_model = create_fr_model(fr_model).eval()
+        if id_loss == "mse":
+            self.feature_loss = nn.MSELoss()
+        elif id_loss == "cosine":
+            self.feature_loss = nn.CosineSimilarity()
+        self.disc_factor = disc_factor
+        self.discriminator_weight = disc_weight
+        self.disc_conditional = disc_conditional
+        self.ssim_loss = SSIM(data_range=1, size_average=True, channel=3)
+
+    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+        if last_layer is not None:
+            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+        else:
+            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+        d_weight = d_weight * self.discriminator_weight
+        return d_weight
+
+    def forward(self, im_features, gt_indices, logits, gt_img, image, discriminator, emb_loss,
+                epoch, last_layer=None, cond=None, mask=None):
+        rec_loss = (image - gt_img) ** 2
+
+        if epoch >= 0:
+            gen_feature = self.fr_model(image)
+            feature_loss = torch.mean(1 - torch.cosine_similarity(im_features, gen_feature))
+        else:
+            feature_loss = 0
+
+        p_loss = self.perceptual_loss(image, gt_img) * 2
+
+        with torch.cuda.amp.autocast(enabled=False):
+            ssim_loss = 1 - self.ssim_loss((image.float() + 1) / 2, (gt_img + 1) / 2)
+        logits_fake = discriminator(image)
+        logits_real_d = discriminator(gt_img.detach())
+        logits_fake_d = discriminator(image.detach())
+
+        if mask is None:
+            token_loss = (logits[:, 1:, :] - gt_indices[:, 1:, :])
+            token_loss = torch.mean(token_loss ** 2)
+        else:
+            token_loss = torch.abs((logits[:, 1:, :] - gt_indices[:, 1:, :])) * mask[:, 1:, None]
+            token_loss = token_loss.sum() / mask[:, 1:].sum()
+        # token_loss = 0
+        nll_loss = torch.mean(rec_loss + p_loss) + \
+                   ssim_loss + \
+                   token_loss + feature_loss + emb_loss
+        # generator update
+        g_loss = -torch.mean(logits_fake)
+
+        d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+        disc_factor = adopt_weight(self.disc_factor, epoch, threshold=self.discriminator_iter_start)
+        ae_loss = nll_loss + d_weight * disc_factor * g_loss
+
+        # second pass for discriminator update
+        disc_factor = adopt_weight(self.disc_factor, epoch, threshold=self.discriminator_iter_start)
+        d_loss = disc_factor * self.disc_loss(logits_real_d, logits_fake_d)
+        return ae_loss, d_loss, token_loss, rec_loss, ssim_loss, p_loss, feature_loss
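For intuition about the pieces combined above: hinge_d_loss only penalizes the discriminator when real logits fall below +1 or fake logits rise above -1, and adopt_weight zeroes the adversarial term until the step/epoch counter passes disc_start. A small self-contained check with toy tensors (not the repo's training loop; the 5x5 logit maps are illustrative):

import torch
import torch.nn.functional as F

def adopt_weight(weight, global_step, threshold=0, value=0.):
    # same behaviour as the helper above: disable the GAN term before `threshold`
    return value if global_step < threshold else weight

def hinge_d_loss(logits_real, logits_fake):
    # same formula as the helper above
    return 0.5 * (torch.mean(F.relu(1. - logits_real)) + torch.mean(F.relu(1. + logits_fake)))

real = torch.full((4, 1, 5, 5), 2.0)    # confidently "real" patch logits
fake = torch.full((4, 1, 5, 5), -2.0)   # confidently "fake" patch logits
print(hinge_d_loss(real, fake))                              # tensor(0.) - both margins satisfied
print(adopt_weight(1.0, global_step=500, threshold=1000))    # 0.0 - adversarial term still off
print(adopt_weight(1.0, global_step=1500, threshold=1000))   # 1.0 - adversarial term active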
pixel_generator/vec2face/taming/modules/util.py
ADDED
@@ -0,0 +1,130 @@
+import torch
+import torch.nn as nn
+
+
+def count_params(model):
+    total_params = sum(p.numel() for p in model.parameters())
+    return total_params
+
+
+class ActNorm(nn.Module):
+    def __init__(self, num_features, logdet=False, affine=True,
+                 allow_reverse_init=False):
+        assert affine
+        super().__init__()
+        self.logdet = logdet
+        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+        self.allow_reverse_init = allow_reverse_init
+
+        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+    def initialize(self, input):
+        with torch.no_grad():
+            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+            mean = (
+                flatten.mean(1)
+                .unsqueeze(1)
+                .unsqueeze(2)
+                .unsqueeze(3)
+                .permute(1, 0, 2, 3)
+            )
+            std = (
+                flatten.std(1)
+                .unsqueeze(1)
+                .unsqueeze(2)
+                .unsqueeze(3)
+                .permute(1, 0, 2, 3)
+            )
+
+            self.loc.data.copy_(-mean)
+            self.scale.data.copy_(1 / (std + 1e-6))
+
+    def forward(self, input, reverse=False):
+        if reverse:
+            return self.reverse(input)
+        if len(input.shape) == 2:
+            input = input[:,:,None,None]
+            squeeze = True
+        else:
+            squeeze = False
+
+        _, _, height, width = input.shape
+
+        if self.training and self.initialized.item() == 0:
+            self.initialize(input)
+            self.initialized.fill_(1)
+
+        h = self.scale * (input + self.loc)
+
+        if squeeze:
+            h = h.squeeze(-1).squeeze(-1)
+
+        if self.logdet:
+            log_abs = torch.log(torch.abs(self.scale))
+            logdet = height*width*torch.sum(log_abs)
+            logdet = logdet * torch.ones(input.shape[0]).to(input)
+            return h, logdet
+
+        return h
+
+    def reverse(self, output):
+        if self.training and self.initialized.item() == 0:
+            if not self.allow_reverse_init:
+                raise RuntimeError(
+                    "Initializing ActNorm in reverse direction is "
+                    "disabled by default. Use allow_reverse_init=True to enable."
+                )
+            else:
+                self.initialize(output)
+                self.initialized.fill_(1)
+
+        if len(output.shape) == 2:
+            output = output[:,:,None,None]
+            squeeze = True
+        else:
+            squeeze = False
+
+        h = output / self.scale - self.loc
+
+        if squeeze:
+            h = h.squeeze(-1).squeeze(-1)
+        return h
+
+
+class AbstractEncoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def encode(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class Labelator(AbstractEncoder):
+    """Net2Net Interface for Class-Conditional Model"""
+    def __init__(self, n_classes, quantize_interface=True):
+        super().__init__()
+        self.n_classes = n_classes
+        self.quantize_interface = quantize_interface
+
+    def encode(self, c):
+        c = c[:,None]
+        if self.quantize_interface:
+            return c, None, [None, None, c.long()]
+        return c
+
+
+class SOSProvider(AbstractEncoder):
+    # for unconditional training
+    def __init__(self, sos_token, quantize_interface=True):
+        super().__init__()
+        self.sos_token = sos_token
+        self.quantize_interface = quantize_interface
+
+    def encode(self, x):
+        # get batch size from data and replicate sos_token
+        c = torch.ones(x.shape[0], 1)*self.sos_token
+        c = c.long().to(x.device)
+        if self.quantize_interface:
+            return c, None, [None, None, c]
+        return c
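A brief usage sketch of the ActNorm module defined above (tensor shapes are illustrative, and the import path assumes the package layout of this commit): the first forward pass in training mode performs data-dependent initialization so each channel comes out roughly zero-mean and unit-variance, and reverse=True inverts the affine map.

import torch
from pixel_generator.vec2face.taming.modules.util import ActNorm  # path assumed from this commit

norm = ActNorm(num_features=64)
x = torch.randn(8, 64, 16, 16) * 3.0 + 1.0     # deliberately shifted/scaled activations
norm.train()
y = norm(x)                                    # first call initializes loc/scale from this batch
print(y.mean().item(), y.std().item())         # roughly 0 and 1 after initialization
x_rec = norm(y, reverse=True)                  # undo the affine transform
print(torch.allclose(x, x_rec, atol=1e-4))     # True up to numerical precision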
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+torch==1.12.0
+numpy==1.24.3
+torchvision==0.13.0
+imageio==2.9.0
+omegaconf==2.1.1
+scipy==1.10.1
+sixdrepnet==0.1.6
+timm==0.9.16
+gradio==4.42.0
+huggingface-hub==0.24.6
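These pins cover the CPU demo in app.py and can be installed with pip install -r requirements.txt; note that lpips and pytorch_msssim, which discriminator_loss.py imports, are not pinned here, presumably because the Gradio demo does not exercise the training loss.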