Haiyu Wu committed on
Commit 918e8a0 · 1 Parent(s): ae82d2a

vec2face demo

app.py ADDED
@@ -0,0 +1,247 @@
1
+ import sys
2
+ sys.path.append('./')
3
+ import gradio as gr
4
+ import random
5
+ import numpy as np
6
+ from PIL import Image
7
+ from huggingface_hub import hf_hub_download
8
+ from models import iresnet
9
+ from sixdrepnet.model import SixDRepNet
10
+ import pixel_generator.vec2face.model_vec2face as model_vec2face
11
+ MAX_SEED = np.iinfo(np.int32).max
12
+ import torch
13
+
14
+
15
+ def sample_nearby_vectors(base_vector, epsilons=[0.3, 0.5, 0.7], percentages=[0.4, 0.4, 0.2]):
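+ # Perturb each row of base_vector with Gaussian noise drawn at several scales (epsilons, mixed according to percentages), then rescale back to the original L2 norm so the identity-vector magnitude is preserved.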
16
+ row, col = base_vector.shape
17
+ norm = torch.norm(base_vector, 2, 1, True)
18
+ diff = []
19
+ for i, eps in enumerate(epsilons):
20
+ diff.append(np.random.normal(0, eps, (int(row * percentages[i]), col)))
21
+ diff = np.vstack(diff)
22
+ np.random.shuffle(diff)
23
+ diff = torch.tensor(diff)
24
+ generated_samples = base_vector + diff
25
+ generated_samples = generated_samples / torch.norm(generated_samples, 2, 1, True) * norm
26
+ return generated_samples
27
+
28
+
29
+ def initialize_models():
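+ # Download the pretrained weights from the BooBooWu/Vec2Face Hugging Face repo and build the Vec2Face generator, the ArcFace ID model, the MagFace quality model, and the 6DRepNet pose model on CPU.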
30
+ device = torch.device('cpu')
31
+ pose_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/6DRepNet_300W_LP_AFLW2000.pth", local_dir="./")
32
+ id_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/arcface-r100-glint360k.pth", local_dir="./")
33
+ quality_model_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/magface-r100-glint360k.pth", local_dir="./")
34
+ generator_weights = hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/vec2face_generator.pth", local_dir="./")
35
+ generator = model_vec2face.__dict__["vec2face_vit_base_patch16"](mask_ratio_mu=0.15, mask_ratio_std=0.25,
36
+ mask_ratio_min=0.1, mask_ratio_max=0.5,
37
+ use_rep=True,
38
+ rep_dim=512,
39
+ rep_drop_prob=0.,
40
+ use_class_label=False)
41
+ generator = generator.to(device)
42
+ checkpoint = torch.load(generator_weights, map_location='cpu')
43
+ generator.load_state_dict(checkpoint['model_vec2face'])
44
+ generator.eval()
45
+
46
+ id_model = iresnet("100", fp16=True).to(device)
47
+ id_model.load_state_dict(torch.load(id_model_weights, map_location='cpu'))
48
+ id_model.eval()
49
+
50
+ quality_model = iresnet("100", fp16=True).to(device)
51
+ quality_model.load_state_dict(torch.load(quality_model_weights, map_location='cpu'))
52
+ quality_model.eval()
53
+
54
+ pose_model = SixDRepNet(backbone_name='RepVGG-B1g2',
55
+ backbone_file='',
56
+ deploy=True,
57
+ pretrained=False
58
+ ).to(device)
59
+ pose_model.load_state_dict(torch.load(pose_model_weights, map_location='cpu'))
60
+ pose_model.eval()
61
+
62
+ return generator, id_model, pose_model, quality_model
63
+
64
+
65
+ def image_generation(input_image, quality, use_target_pose, pose, dimension):
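+ # Sample a random 512-d identity vector (or extract one from the uploaded face with the ID model), then decode six images with the Vec2Face generator, optionally optimizing toward the quality and pose targets.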
66
+ generator, id_model, pose_model, quality_model = initialize_models()
67
+
68
+ generated_images = []
69
+ if input_image is None:
70
+ feature = np.random.normal(0, 1.0, (1, 512))
71
+ else:
72
+ input_image = np.transpose(input_image, (2, 0, 1))
73
+ input_image = torch.from_numpy(input_image).unsqueeze(0).float()
74
+ input_image.div_(255).sub_(0.5).div_(0.5)
75
+ feature = id_model(input_image).clone().detach().cpu().numpy()
76
+
77
+ if not use_target_pose:
78
+ features = []
79
+ norm = np.linalg.norm(feature, 2, 1, True)
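+ # Walk the selected dimensions through a sequence of increasing offsets (np.arange(0, 4.8, 0.8)) and renormalize each edited vector to the original feature norm.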
80
+ for i in np.arange(0, 4.8, 0.8):
81
+ updated_feature = feature.copy()
82
+ updated_feature[0][dimension] = feature[0][dimension] + i
83
+
84
+ updated_feature = updated_feature / np.linalg.norm(updated_feature, 2, 1, True) * norm
85
+
86
+ features.append(updated_feature)
87
+ features = torch.tensor(np.vstack(features)).float()
88
+ if quality > 25:
89
+ images, _ = generator.gen_image(features, quality_model, id_model, q_target=quality)
90
+ else:
91
+ _, _, images, *_ = generator(features)
92
+ else:
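+ # Pose branch: duplicate the identity vector six times and jitter each copy slightly; gen_image then optimizes toward the target yaw when the quality and pose targets are high enough.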
93
+ features = torch.repeat_interleave(torch.tensor(feature), 6, dim=0)
94
+ features = sample_nearby_vectors(features, [0.7], [1]).float()
95
+ if quality > 25 and pose > 20:
96
+ images, _ = generator.gen_image(features, quality_model, id_model, pose_model=pose_model,
97
+ q_target=quality, pose=pose, class_rep=features)
98
+ else:
99
+ _, _, images, *_ = generator(features)
100
+
101
+ images = ((images.permute(0, 2, 3, 1).detach().cpu().numpy() + 1) / 2 * 255).astype(np.uint8)
102
+ for image in images:
103
+ generated_images.append(Image.fromarray(image))
104
+ return generated_images
105
+
106
+
107
+ def process_input(image_input, num1, num2, num3, num4, random_seed, target_quality, use_target_pose, target_pose):
108
+ # Ensure all dimension numbers are within [0, 512)
109
+ num1, num2, num3, num4 = [max(0, min(int(n), 511)) for n in [num1, num2, num3, num4]]
110
+
111
+ # Use the provided random seed
112
+ random.seed(random_seed)
113
+ np.random.seed(random_seed)
114
+ if image_input is None:
115
+ input_data = None
116
+ else:
117
+ # Process the uploaded image
118
+ input_data = Image.open(image_input).convert("RGB")
119
+ input_data = np.array(input_data.resize((112, 112)))
120
+
121
+ generated_images = image_generation(input_data, target_quality, use_target_pose, target_pose, [num1, num2, num3, num4])
122
+
123
+ return generated_images
124
+
125
+ def select_image(value, images):
126
+ # Convert the slider value (0 to 4, step 0.8) to an integer index (0 to 5)
127
+ index = round(value / 0.8)
128
+ return images[index]
129
+
130
+ def toggle_inputs(use_pose):
131
+ return [
132
+ gr.update(visible=use_pose, interactive=use_pose), # target_pose
133
+ gr.update(interactive=not use_pose), # num1
134
+ gr.update(interactive=not use_pose), # num2
135
+ gr.update(interactive=not use_pose), # num3
136
+ gr.update(interactive=not use_pose), # num4
137
+ ]
138
+
139
+
140
+ def main():
141
+ with gr.Blocks() as demo:
142
+ title = r"""
143
+ <h1 align="center">Vec2Face: Scaling Face Dataset Generation with Loosely Constrained Vectors</h1>
144
+ """
145
+
146
+ description = r"""
147
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/HaiyuWu/vec2face' target='_blank'><b>Vec2Face: Scaling Face Dataset Generation with Loosely Constrained Vectors</b></a>.<br>
148
+
149
+ How to use:<br>
150
+ 1. Upload a cropped face image or directly click the <b>Submit</b> button; six images will be shown on the right.
151
+ 2. You can control the image quality and pose, and modify the values of the target dimensions to change the output images.
152
+ 3. The output shows six results of dimension modification or six pose-edited images.
153
+ 4. Since the demo runs on CPU, higher quality targets and larger poses take longer to generate.
154
+ 5. Enjoy! 😊
155
+ """
156
+
157
+ gr.Markdown(title)
158
+ gr.Markdown(description)
159
+ with gr.Row():
160
+ with gr.Column():
161
+ image_file = gr.Image(label="Upload an image (optional)", type="filepath")
162
+
163
+ gr.Markdown("""
164
+ ## Dimension Modification
165
+ Enter the values for the dimensions you want to modify (0-511).
166
+ """)
167
+
168
+ with gr.Row():
169
+ num1 = gr.Number(label="Dimension 1", value=0, minimum=0, maximum=511, step=1)
170
+ num2 = gr.Number(label="Dimension 2", value=0, minimum=0, maximum=511, step=1)
171
+ num3 = gr.Number(label="Dimension 3", value=0, minimum=0, maximum=511, step=1)
172
+ num4 = gr.Number(label="Dimension 4", value=0, minimum=0, maximum=511, step=1)
173
+
174
+ random_seed = gr.Number(label="Random Seed", value=42, minimum=0, maximum=MAX_SEED, step=1)
175
+ target_quality = gr.Slider(label="Minimum Quality", minimum=22, maximum=35, step=1, value=24)
176
+
177
+ with gr.Row():
178
+ use_target_pose = gr.Checkbox(label="Use Target Pose")
179
+ target_pose = gr.Slider(label="Target Pose", value=0, minimum=0, maximum=90, step=1, visible=False)
180
+
181
+ submit = gr.Button("Submit", variant="primary")
182
+
183
+ gr.Markdown("""
184
+ ## Usage tips of Vec2Face
185
+ - Directly clicking the "Submit" button will give you results from a randomly sampled vector.
186
+ - If you want to modify more dimensions, please write your own code. Code snippets in the [Vec2Face repo](https://github.com/HaiyuWu/vec2face) might be helpful.
187
+ - If you want to create extreme-pose images (e.g., >70), please do not set the image quality higher than 27.
188
+ - <span style="color: red;">!</span> <span style="color: red;">!</span> <span style="color: red;">!</span> **Due to the limitations of SixDRepNet (the pose estimator), pose editing results might be corrupted or incorrect. For better performance, you can integrate other pose estimators.** <span style="color: red;">!</span> <span style="color: red;">!</span> <span style="color: red;">!</span>
189
+ - For a better experience, we suggest running the code on a GPU machine.
190
+ """)
191
+
192
+ with gr.Column():
193
+ gallery = gr.Image(label="Generated Image")
194
+ incremental_value_slider = gr.Slider(
195
+ label="Result of dimension modification or pose editing",
196
+ minimum=0, maximum=4, step=0.8, value=0
197
+ )
198
+ gr.Markdown("""
199
+ - These values are added to the selected dimensions (before normalization); **please ignore this slider if pose editing is on**.
200
+ """)
201
+
202
+ use_target_pose.change(
203
+ fn=toggle_inputs,
204
+ inputs=[use_target_pose],
205
+ outputs=[target_pose, num1, num2, num3, num4]
206
+ )
207
+
208
+ generated_images = gr.State([])
209
+
210
+ submit.click(
211
+ fn=process_input,
212
+ inputs=[image_file, num1, num2, num3, num4, random_seed, target_quality, use_target_pose, target_pose],
213
+ outputs=[generated_images]
214
+ ).then(
215
+ fn=select_image,
216
+ inputs=[incremental_value_slider, generated_images],
217
+ outputs=[gallery]
218
+ )
219
+
220
+ incremental_value_slider.change(
221
+ fn=select_image,
222
+ inputs=[incremental_value_slider, generated_images],
223
+ outputs=[gallery]
224
+ )
225
+ article = r"""
226
+ ---
227
+ 📝 **Citation**
228
+ <br>
229
+ If our work is helpful for your research or applications, please cite us via:
230
+ ```bibtex
231
+ @article{wu2024vec2face,
232
+ title={Vec2Face: Scaling Face Dataset Generation with Loosely Constrained Vectors},
233
+ author={Wu, Haiyu and Singh, Jaskirat and Tian, Sicong and Zheng, Liang and Bowyer, Kevin W.},
234
+ year={2024}
235
+ }
236
+ ```
237
+ 📧 **Contact**
238
+ <br>
239
+ If you have any questions, please feel free to open an issue or reach out to us directly at <b>[email protected]</b>.
240
+ """
241
+ gr.Markdown(article)
242
+
243
+ demo.launch(share=True)
244
+
245
+
246
+ if __name__ == "__main__":
247
+ main()
configs/vec2face/vqgan.yaml ADDED
@@ -0,0 +1,16 @@
1
+ model:
2
+ target: pixel_generator.vec2face.taming.models.vqgan.VQModel
3
+ params:
4
+ embed_dim: 256
5
+ n_embed: 1024
6
+ ddconfig:
7
+ double_z: False
8
+ z_channels: 256
9
+ resolution: 112
10
+ in_channels: 3
11
+ out_ch: 3
12
+ ch: 128
13
+ ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
14
+ num_res_blocks: 2
15
+ attn_resolutions: [16]
16
+ dropout: 0.0
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .iresnet import iresnet
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (140 Bytes)
 
models/__pycache__/iresnet.cpython-38.pyc ADDED
Binary file (4.21 kB)
 
models/iresnet.py ADDED
@@ -0,0 +1,150 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ using_ckpt = False
6
+
7
+
8
+ def conv3x3(in_planes, out_planes, stride=1, groups=1):
9
+ """3x3 convolution with padding"""
10
+ return nn.Conv2d(in_planes,
11
+ out_planes,
12
+ kernel_size=3,
13
+ stride=stride,
14
+ padding=1,
15
+ groups=groups,
16
+ bias=False)
17
+
18
+
19
+ def conv1x1(in_planes, out_planes, stride=1):
20
+ """1x1 convolution"""
21
+ return nn.Conv2d(in_planes,
22
+ out_planes,
23
+ kernel_size=1,
24
+ stride=stride,
25
+ bias=False)
26
+
27
+
28
+ class IBasicBlock(nn.Module):
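+ # Improved residual block (BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv -> BN) with an identity or downsample shortcut, as used in ArcFace-style IResNet backbones.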
29
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
30
+ super(IBasicBlock, self).__init__()
31
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
32
+ self.conv1 = conv3x3(inplanes, planes)
33
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
34
+ self.prelu = nn.PReLU(planes)
35
+ self.conv2 = conv3x3(planes, planes, stride)
36
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
37
+ self.downsample = downsample
38
+ self.stride = stride
39
+
40
+ def forward_impl(self, x):
41
+ identity = x
42
+ out = self.bn1(x)
43
+ out = self.conv1(out)
44
+ out = self.bn2(out)
45
+ out = self.prelu(out)
46
+ out = self.conv2(out)
47
+ out = self.bn3(out)
48
+ if self.downsample is not None:
49
+ identity = self.downsample(x)
50
+ out += identity
51
+ return out
52
+
53
+ def forward(self, x):
54
+ if self.training and using_ckpt:
55
+ return checkpoint(self.forward_impl, x)
56
+ else:
57
+ return self.forward_impl(x)
58
+
59
+
60
+ class IResNet(nn.Module):
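+ # IResNet backbone that maps a 112x112 aligned face to a num_features-dim (512 by default) embedding; the demo uses it for both the ArcFace ID model and the MagFace quality model.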
61
+ def __init__(self,
62
+ block, layers, dropout=0.4, num_features=512, zero_init_residual=False,
63
+ groups=1, fp16=False):
64
+ super(IResNet, self).__init__()
65
+ self.extra_gflops = 0.0
66
+ self.fp16 = fp16
67
+ self.inplanes = 64
68
+
69
+ self.groups = groups
70
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
71
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
72
+ self.prelu = nn.PReLU(self.inplanes)
73
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
74
+ self.layer2 = self._make_layer(block,
75
+ 128,
76
+ layers[1],
77
+ stride=2)
78
+ self.layer3 = self._make_layer(block,
79
+ 256,
80
+ layers[2],
81
+ stride=2)
82
+ self.layer4 = self._make_layer(block,
83
+ 512,
84
+ layers[3],
85
+ stride=2)
86
+ self.bn2 = nn.BatchNorm2d(512, eps=1e-05,)
87
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
88
+ self.fc = nn.Linear(512 * 7 * 7, num_features)
89
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
90
+ nn.init.constant_(self.features.weight, 1.0)
91
+ self.features.weight.requires_grad = False
92
+
93
+ for m in self.modules():
94
+ if isinstance(m, nn.Conv2d):
95
+ nn.init.normal_(m.weight, 0, 0.1)
96
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
97
+ nn.init.constant_(m.weight, 1)
98
+ nn.init.constant_(m.bias, 0)
99
+
100
+ if zero_init_residual:
101
+ for m in self.modules():
102
+ if isinstance(m, IBasicBlock):
103
+ nn.init.constant_(m.bn2.weight, 0)
104
+
105
+ def _make_layer(self, block, planes, blocks, stride=1):
106
+ downsample = None
107
+ if stride != 1 or self.inplanes != planes:
108
+ downsample = nn.Sequential(
109
+ conv1x1(self.inplanes, planes, stride),
110
+ nn.BatchNorm2d(planes, eps=1e-05, ),
111
+ )
112
+ layers = []
113
+ layers.append(
114
+ block(self.inplanes, planes, stride, downsample))
115
+ self.inplanes = planes
116
+ for _ in range(1, blocks):
117
+ layers.append(
118
+ block(self.inplanes,
119
+ planes))
120
+
121
+ return nn.Sequential(*layers)
122
+
123
+ def forward(self, x):
124
+ with torch.cuda.amp.autocast(self.fp16):
125
+ x = self.conv1(x)
126
+ x = self.bn1(x)
127
+ x = self.prelu(x)
128
+ x = self.layer1(x)
129
+ x = self.layer2(x)
130
+ x = self.layer3(x)
131
+ x = self.layer4(x)
132
+ x = self.bn2(x)
133
+ x = torch.flatten(x, 1)
134
+ x = self.dropout(x)
135
+ x = self.fc(x.float() if self.fp16 else x)
136
+ x = self.features(x)
137
+ return x
138
+
139
+
140
+ def iresnet(arch, pretrained=False, **kwargs):
141
+ layer_dict = {"18": [2, 2, 2, 2],
142
+ "34": [3, 4, 6, 3],
143
+ "50": [3, 4, 14, 3],
144
+ "100": [3, 13, 30, 3],
145
+ "152": [3, 8, 36, 3],
146
+ "200": [3, 13, 30, 3]}
147
+ model = IResNet(IBasicBlock, layer_dict[arch], **kwargs)
148
+ if pretrained:
149
+ raise ValueError()
150
+ return model
pixel_generator/vec2face/__pycache__/im_decoder.cpython-38.pyc ADDED
Binary file (4.74 kB)
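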
 
pixel_generator/vec2face/__pycache__/model_vec2face.cpython-38.pyc ADDED
Binary file (12.3 kB)
 
pixel_generator/vec2face/im_decoder.py ADDED
@@ -0,0 +1,209 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def nonlinearity(x):
6
+ # swish
7
+ return x*torch.sigmoid(x)
8
+
9
+
10
+ def Normalize(in_channels):
11
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
12
+
13
+
14
+ class Upsample(nn.Module):
15
+ def __init__(self, in_channels, with_conv):
16
+ super().__init__()
17
+ self.with_conv = with_conv
18
+ if self.with_conv:
19
+ self.conv = torch.nn.Conv2d(in_channels,
20
+ in_channels,
21
+ kernel_size=3,
22
+ stride=1,
23
+ padding=1)
24
+
25
+ def forward(self, x):
26
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
27
+ if self.with_conv:
28
+ x = self.conv(x)
29
+ return x
30
+
31
+
32
+ class Downsample(nn.Module):
33
+ def __init__(self, in_channels, with_conv):
34
+ super().__init__()
35
+ self.with_conv = with_conv
36
+ if self.with_conv:
37
+ # no asymmetric padding in torch conv, must do it ourselves
38
+ self.conv = torch.nn.Conv2d(in_channels,
39
+ in_channels,
40
+ kernel_size=3,
41
+ stride=2,
42
+ padding=0)
43
+
44
+ def forward(self, x):
45
+ if self.with_conv:
46
+ pad = (0,1,0,1)
47
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
48
+ x = self.conv(x)
49
+ else:
50
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
51
+ return x
52
+
53
+
54
+ class ResnetBlock(nn.Module):
55
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
56
+ dropout, temb_channels=512):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ out_channels = in_channels if out_channels is None else out_channels
60
+ self.out_channels = out_channels
61
+ self.use_conv_shortcut = conv_shortcut
62
+
63
+ self.norm1 = Normalize(in_channels)
64
+ self.conv1 = torch.nn.Conv2d(in_channels,
65
+ out_channels,
66
+ kernel_size=3,
67
+ stride=1,
68
+ padding=1,
69
+ bias=False)
70
+ if temb_channels > 0:
71
+ self.temb_proj = torch.nn.Linear(temb_channels,
72
+ out_channels)
73
+ self.norm2 = Normalize(out_channels)
74
+ self.dropout = torch.nn.Dropout(dropout)
75
+ self.conv2 = torch.nn.Conv2d(out_channels,
76
+ out_channels,
77
+ kernel_size=3,
78
+ stride=1,
79
+ padding=1,
80
+ bias=False)
81
+ if self.in_channels != self.out_channels:
82
+ if self.use_conv_shortcut:
83
+ self.conv_shortcut = torch.nn.Conv2d(out_channels,
84
+ out_channels,
85
+ kernel_size=3,
86
+ stride=1,
87
+ padding=1,
88
+ bias=False)
89
+ else:
90
+ self.nin_shortcut = torch.nn.Conv2d(out_channels,
91
+ out_channels,
92
+ kernel_size=1,
93
+ stride=1,
94
+ padding=0,
95
+ bias=False)
96
+
97
+ def forward(self, x, temb):
98
+ h = x
99
+ h = self.norm1(h)
100
+ h = nonlinearity(h)
101
+ h = self.conv1(h)
102
+
103
+ if temb is not None:
104
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
105
+
106
+ h = self.norm2(h)
107
+ h = nonlinearity(h)
108
+ h = self.dropout(h)
109
+ h = self.conv2(h)
110
+
111
+ if self.in_channels != self.out_channels:
112
+ if self.use_conv_shortcut:
113
+ x = self.conv_shortcut(h)
114
+ else:
115
+ x = self.nin_shortcut(h)
116
+
117
+ return x+h
118
+
119
+
120
+ class Decoder(nn.Module):
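+ # VQGAN-style decoder: conv_in projects the latent map to block_in channels, two mid ResnetBlocks are applied, then progressive upsampling restores the output resolution and a final conv produces out_ch channels.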
121
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
122
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
123
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
124
+ super().__init__()
125
+ self.ch = ch
126
+ self.temb_ch = 0
127
+ self.num_resolutions = len(ch_mult)
128
+ self.num_res_blocks = num_res_blocks
129
+ self.resolution = resolution
130
+ self.in_channels = in_channels
131
+ self.give_pre_end = give_pre_end
132
+
133
+ # compute in_ch_mult, block_in and curr_res at lowest res
134
+ in_ch_mult = (1,)+tuple(ch_mult)
135
+ block_in = ch*ch_mult[self.num_resolutions-1]
136
+ curr_res = resolution // 2**(self.num_resolutions-1)
137
+
138
+ # z to block_in
139
+ self.conv_in = torch.nn.Conv2d(z_channels,
140
+ block_in,
141
+ kernel_size=3,
142
+ stride=1,
143
+ padding=1)
144
+
145
+ # middle
146
+ self.mid = nn.Module()
147
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
148
+ out_channels=block_in,
149
+ temb_channels=self.temb_ch,
150
+ dropout=dropout)
151
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
152
+ out_channels=block_in,
153
+ temb_channels=self.temb_ch,
154
+ dropout=dropout)
155
+
156
+ # upsampling
157
+ self.up = nn.ModuleList()
158
+ for i_level in reversed(range(self.num_resolutions)):
159
+ block = nn.ModuleList()
160
+ block_out = ch*ch_mult[i_level]
161
+ for i_block in range(self.num_res_blocks):
162
+ block.append(ResnetBlock(in_channels=block_in,
163
+ out_channels=block_out,
164
+ temb_channels=self.temb_ch,
165
+ dropout=dropout))
166
+ block_in = block_out
167
+ up = nn.Module()
168
+ up.block = block
169
+ if i_level != 0:
170
+ up.upsample = Upsample(block_in, resamp_with_conv)
171
+ curr_res = curr_res * 2
172
+ self.up.insert(0, up) # prepend to get consistent order
173
+
174
+ # end
175
+ self.norm_out = Normalize(block_in)
176
+ self.conv_out = torch.nn.Conv2d(block_in,
177
+ out_ch,
178
+ kernel_size=3,
179
+ stride=1,
180
+ padding=1)
181
+
182
+ def forward(self, z):
183
+ self.last_z_shape = z.shape
184
+
185
+ # timestep embedding
186
+ temb = None
187
+
188
+ # z to block_in
189
+ h = self.conv_in(z)
190
+
191
+ # middle
192
+ h = self.mid.block_1(h, temb)
193
+ h = self.mid.block_2(h, temb)
194
+
195
+ # upsampling
196
+ for i_level in reversed(range(self.num_resolutions)):
197
+ for i_block in range(self.num_res_blocks):
198
+ h = self.up[i_level].block[i_block](h, temb)
199
+ if i_level != 0:
200
+ h = self.up[i_level].upsample(h)
201
+
202
+ # end
203
+ if self.give_pre_end:
204
+ return h
205
+
206
+ h = self.norm_out(h)
207
+ h = nonlinearity(h)
208
+ h = self.conv_out(h)
209
+ return h
pixel_generator/vec2face/model_vec2face.py ADDED
@@ -0,0 +1,357 @@
1
+ from functools import partial
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp
6
+ from omegaconf import OmegaConf
7
+ import numpy as np
8
+ import scipy.stats as stats
9
+ from pixel_generator.vec2face.im_decoder import Decoder
10
+ from sixdrepnet.model import utils
11
+
12
+
13
+ class Attention(nn.Module):
14
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
15
+ super().__init__()
16
+ self.num_heads = num_heads
17
+ head_dim = dim // num_heads
18
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
19
+ self.scale = qk_scale or head_dim ** -0.5
20
+
21
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
22
+ self.attn_drop = nn.Dropout(attn_drop)
23
+ self.proj = nn.Linear(dim, dim)
24
+ self.proj_drop = nn.Dropout(proj_drop)
25
+
26
+ def forward(self, x):
27
+ B, N, C = x.shape
28
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
29
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
30
+ attn = (q.float() @ k.float().transpose(-2, -1)) * self.scale
31
+ attn = attn - torch.max(attn, dim=-1, keepdim=True)[0]
32
+ attn = attn.softmax(dim=-1)
33
+ attn = self.attn_drop(attn)
34
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
35
+ x = self.proj(x)
36
+ x = self.proj_drop(x)
37
+ return x, attn
38
+
39
+
40
+ class Block(nn.Module):
41
+
42
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
43
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
44
+ super().__init__()
45
+ self.norm1 = norm_layer(dim)
46
+ self.attn = Attention(
47
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
48
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
49
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
50
+ self.norm2 = norm_layer(dim)
51
+ mlp_hidden_dim = int(dim * mlp_ratio)
52
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
53
+
54
+ def forward(self, x, return_attention=False):
55
+ with torch.cuda.amp.autocast(enabled=False):
56
+ if return_attention:
57
+ _, attn = self.attn(self.norm1(x))
58
+ return attn
59
+ else:
60
+ y, _ = self.attn(self.norm1(x))
61
+ x = x + self.drop_path(y)
62
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
63
+ return x
64
+
65
+
66
+ class LabelSmoothingCrossEntropy(nn.Module):
67
+ """ NLL loss with label smoothing.
68
+ """
69
+
70
+ def __init__(self, smoothing=0.1):
71
+ super(LabelSmoothingCrossEntropy, self).__init__()
72
+ assert smoothing < 1.0
73
+ self.smoothing = smoothing
74
+ self.confidence = 1. - smoothing
75
+
76
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
77
+ logprobs = torch.nn.functional.log_softmax(x, dim=-1)
78
+ nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
79
+ nll_loss = nll_loss.squeeze(1)
80
+ smooth_loss = -logprobs.mean(dim=-1)
81
+ loss = self.confidence * nll_loss + self.smoothing * smooth_loss
82
+ return loss
83
+
84
+
85
+ class BertEmbeddings(nn.Module):
86
+ """Construct the embeddings from word, position and token_type embeddings."""
87
+
88
+ def __init__(self, hidden_size, max_position_embeddings, dropout=0.1):
89
+ super().__init__()
90
+ self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
91
+
92
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
93
+ # any TensorFlow checkpoint file
94
+ self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
95
+ self.dropout = nn.Dropout(dropout)
96
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
97
+ self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))
98
+
99
+ torch.nn.init.normal_(self.position_embeddings.weight, std=.02)
100
+
101
+ def forward(
102
+ self, input_ids
103
+ ):
104
+ input_shape = input_ids.size()
105
+
106
+ seq_length = input_shape[1]
107
+
108
+ position_ids = self.position_ids[:, :seq_length]
109
+
110
+ position_embeddings = self.position_embeddings(position_ids)
111
+ embeddings = input_ids + position_embeddings
112
+
113
+ embeddings = self.LayerNorm(embeddings)
114
+ embeddings = self.dropout(embeddings)
115
+ return embeddings
116
+
117
+
118
+ class MaskedGenerativeEncoderViT(nn.Module):
119
+ """ Masked Autoencoder with VisionTransformer backbone
120
+ """
121
+
122
+ def __init__(self, img_size=112, patch_size=7, in_chans=3,
123
+ embed_dim=1024, depth=24, num_heads=16,
124
+ decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
125
+ mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
126
+ mask_ratio_min=0.5, mask_ratio_max=1.0, mask_ratio_mu=0.55, mask_ratio_std=0.25,
127
+ use_rep=True, rep_dim=512,
128
+ rep_drop_prob=0.0,
129
+ use_class_label=False):
130
+ super().__init__()
131
+ assert not (use_rep and use_class_label)
132
+
133
+ # --------------------------------------------------------------------------
134
+ vqgan_config = OmegaConf.load('configs/vec2face/vqgan.yaml').model
135
+ self.token_emb = BertEmbeddings(hidden_size=embed_dim,
136
+ max_position_embeddings=49 + 1,
137
+ dropout=0.1)
138
+ self.use_rep = use_rep
139
+ self.use_class_label = use_class_label
140
+ if self.use_rep:
141
+ print("Use representation as condition!")
142
+ self.latent_prior_proj_f = nn.Linear(rep_dim, embed_dim, bias=True)
143
+ # CFG config
144
+ self.rep_drop_prob = rep_drop_prob
145
+ self.feature_token = nn.Linear(1, 49, bias=True)
146
+ self.center_token = nn.Linear(embed_dim, 49, bias=True)
147
+ self.im_decoder = Decoder(**vqgan_config.params.ddconfig)
148
+ self.im_decoder_proj = nn.Linear(embed_dim, vqgan_config.params.ddconfig.z_channels)
149
+
150
+ # Vec2Face variant masking ratio
151
+ self.mask_ratio_min = mask_ratio_min
152
+ self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - mask_ratio_mu) / mask_ratio_std,
153
+ (mask_ratio_max - mask_ratio_mu) / mask_ratio_std,
154
+ loc=mask_ratio_mu, scale=mask_ratio_std)
155
+ # --------------------------------------------------------------------------
156
+ # Vec2Face encoder specifics
157
+ dropout_rate = 0.1
158
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
159
+ num_patches = self.patch_embed.num_patches
160
+
161
+ self.blocks = nn.ModuleList([
162
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
163
+ drop=dropout_rate, attn_drop=dropout_rate)
164
+ for i in range(depth)])
165
+ self.norm = norm_layer(embed_dim)
166
+
167
+ # --------------------------------------------------------------------------
168
+ # Vec2Face decoder specifics
169
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
170
+ self.pad_with_cls_token = True
171
+
172
+ self.decoder_pos_embed_learned = nn.Parameter(
173
+ torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=True) # learnable pos embedding
174
+
175
+ self.decoder_blocks = nn.ModuleList([
176
+ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
177
+ drop=dropout_rate, attn_drop=dropout_rate)
178
+ for i in range(decoder_depth)])
179
+
180
+ self.decoder_norm = norm_layer(decoder_embed_dim)
181
+ # --------------------------------------------------------------------------
182
+ self.initialize_weights()
183
+
184
+ def initialize_weights(self):
185
+ w = self.patch_embed.proj.weight.data
186
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
187
+ torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
188
+ torch.nn.init.xavier_uniform_(self.feature_token.weight)
189
+ torch.nn.init.xavier_uniform_(self.center_token.weight)
190
+ torch.nn.init.xavier_uniform_(self.latent_prior_proj_f.weight)
191
+ torch.nn.init.xavier_uniform_(self.decoder_embed.weight)
192
+ self.apply(self._init_weights)
193
+
194
+ def _init_weights(self, m):
195
+ if isinstance(m, nn.Linear):
196
+ # we use xavier_uniform following official JAX ViT:
197
+ torch.nn.init.xavier_uniform_(m.weight)
198
+ if isinstance(m, nn.Linear) and m.bias is not None:
199
+ nn.init.constant_(m.bias, 0)
200
+ elif isinstance(m, nn.LayerNorm):
201
+ nn.init.constant_(m.bias, 0)
202
+ nn.init.constant_(m.weight, 1.0)
203
+
204
+ def forward_encoder(self, rep):
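+ # Project the input identity vector to the token dimension, expand it into a 7x7 (49-token) map, randomly drop/mask a fraction of the tokens, and encode the rest with the ViT blocks.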
205
+ # expand to feature map
206
+ device = rep.device
207
+ encode_feature = self.latent_prior_proj_f(rep)
208
+ feature_token = self.feature_token(encode_feature.unsqueeze(-1)).permute(0, 2, 1)
209
+
210
+ gt_indices = torch.cat((encode_feature.unsqueeze(1), feature_token), dim=1).clone().detach()
211
+
212
+ # masked row indices
213
+ bsz, seq_len, _ = feature_token.size()
214
+ mask_ratio_min = self.mask_ratio_min
215
+ mask_rate = self.mask_ratio_generator.rvs(1)[0]
216
+
217
+ num_dropped_tokens = int(np.ceil(seq_len * mask_ratio_min))
218
+ num_masked_tokens = int(np.ceil(seq_len * mask_rate))
219
+
220
+ # it is possible that two elements of the noise are the same, so retry in a while loop to avoid it
221
+ while True:
222
+ noise = torch.rand(bsz, seq_len, device=rep.device) # noise in [0, 1]
223
+ sorted_noise, _ = torch.sort(noise, dim=1) # ascending: small values are dropped/masked, large values are kept
224
+ cutoff_drop = sorted_noise[:, num_dropped_tokens - 1:num_dropped_tokens]
225
+ cutoff_mask = sorted_noise[:, num_masked_tokens - 1:num_masked_tokens]
226
+ token_drop_mask = (noise <= cutoff_drop).float()
227
+ token_all_mask = (noise <= cutoff_mask).float()
228
+ if token_drop_mask.sum() == bsz * num_dropped_tokens and \
229
+ token_all_mask.sum() == bsz * num_masked_tokens:
230
+ break
231
+ else:
232
+ print("Rerandom the noise!")
233
+ token_all_mask_bool = token_all_mask.bool()
234
+ encode_feature_expanded = encode_feature.unsqueeze(1).expand(-1, feature_token.shape[1], -1)
235
+ feature_token[token_all_mask_bool] = encode_feature_expanded[token_all_mask_bool]
236
+
237
+ # concatenate with image feature
238
+ feature_token = torch.cat([encode_feature.unsqueeze(1), feature_token], dim=1)
239
+ token_drop_mask = torch.cat([torch.zeros(feature_token.size(0), 1).to(device), token_drop_mask], dim=1)
240
+ token_all_mask = torch.cat([torch.zeros(feature_token.size(0), 1).to(device), token_all_mask], dim=1)
241
+
242
+ # bert embedding
243
+ input_embeddings = self.token_emb(feature_token)
244
+
245
+ bsz, seq_len, emb_dim = input_embeddings.shape
246
+
247
+ # dropping
248
+ token_keep_mask = 1 - token_drop_mask
249
+ input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim)
250
+
251
+ # apply Transformer blocks
252
+ x = input_embeddings_after_drop
253
+ for blk in self.blocks:
254
+ x = blk(x)
255
+ x = self.norm(x)
256
+ return x, gt_indices, token_drop_mask, token_all_mask
257
+
258
+ def forward_decoder(self, x, token_drop_mask, token_all_mask):
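+ # Re-pad dropped positions with the first (identity) token, add the learned positional embedding, run the decoder blocks, and project the 7x7 token map to the image decoder's latent channels.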
259
+ # embed incomplete feature map
260
+ x = self.decoder_embed(x)
261
+ # fill masked positions with image feature
262
+ mask_tokens = x[:, 0:1].repeat(1, token_all_mask.shape[1], 1)
263
+ x_after_pad = mask_tokens.clone()
264
+ x_after_pad[(1 - token_drop_mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
265
+ x_after_pad = torch.where(token_all_mask.unsqueeze(-1).bool(), mask_tokens, x_after_pad)
266
+ # add pos embed
267
+ x = x_after_pad + self.decoder_pos_embed_learned
268
+
269
+ # apply Transformer blocks
270
+ for blk in self.decoder_blocks:
271
+ x = blk(x)
272
+
273
+ logits = self.decoder_norm(x)
274
+ bsz, _, emb_dim = logits.shape
275
+ # an image decoder
276
+ decoder_proj = self.im_decoder_proj(logits[:, 1:, :].reshape(bsz, 7, 7, emb_dim)).permute(0, 3, 1, 2)
277
+ return decoder_proj, logits
278
+
279
+ def get_last_layer(self):
280
+ return self.im_decoder.conv_out.weight
281
+
282
+ def forward(self, rep):
283
+ last_layer = self.get_last_layer()
284
+ latent, gt_indices, token_drop_mask, token_all_mask = self.forward_encoder(rep)
285
+ decoder_proj, logits = self.forward_decoder(latent, token_drop_mask, token_all_mask)
286
+ image = self.im_decoder(decoder_proj)
287
+
288
+ return gt_indices, logits, image, last_layer, token_all_mask
289
+
290
+ def gen_image(self, rep, quality_model, fr_model, pose_model=None, age_model=None, class_rep=None,
291
+ num_iter=1, lr=1e-1, q_target=27, pose=60):
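+ # Iteratively optimize the input vector with Adam so the decoded image keeps the target identity (cosine similarity), reaches the quality target (feature norm of the quality model), and, if a pose model is given, approaches the target yaw.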
292
+ rep_copy = rep.clone().detach().requires_grad_(True)
293
+ optm = optim.Adam([rep_copy], lr=lr)
294
+
295
+ i = 0
296
+ while i < num_iter:
297
+ latent, _, token_drop_mask, token_all_mask = self.forward_encoder(rep_copy)
298
+ decoder_proj, _ = self.forward_decoder(latent, token_drop_mask, token_all_mask)
299
+ image = self.im_decoder(decoder_proj).clip(max=1., min=-1.)
300
+ # feature comparison
301
+ out_feature = fr_model(image)
302
+ if class_rep is None:
303
+ id_loss = torch.mean(1 - torch.cosine_similarity(out_feature, rep))
304
+ else:
305
+ distance = 1 - torch.cosine_similarity(out_feature, class_rep)
306
+ id_loss = torch.mean(torch.where(distance > 0.5, distance, torch.zeros_like(distance)))
307
+ quality = quality_model(image)
308
+ norm = torch.norm(quality, 2, 1, True)
309
+ q_loss = torch.where(norm < q_target, q_target - norm, torch.zeros_like(norm))
310
+
311
+ pose_loss = 0
312
+ if pose_model is not None:
313
+ # sixdrepnet
314
+ bgr_img = image[:, [2, 1, 0], :, :]
315
+ pose_info = pose_model(((bgr_img + 1) / 2))
316
+ pose_info = utils.compute_euler_angles_from_rotation_matrices(
317
+ pose_info) * 180 / np.pi
318
+ yaw_loss = torch.abs(pose - torch.abs(pose_info[:, 1].clip(min=-90, max=90)))
319
+ pose_loss = torch.mean(yaw_loss)
320
+ q_loss = torch.mean(q_loss)
321
+ if pose_loss > 5 or id_loss > 0.4 or q_loss > 1:
322
+ i -= 1
323
+ loss = id_loss * 100 + q_loss + pose_loss
324
+ optm.zero_grad()
325
+ loss.backward(retain_graph=True)
326
+ optm.step()
327
+ i += 1
328
+
329
+ latent, _, token_drop_mask, token_all_mask = self.forward_encoder(rep_copy)
330
+ decoder_proj, _ = self.forward_decoder(latent, token_drop_mask, token_all_mask)
331
+ image = self.im_decoder(decoder_proj).clip(max=1., min=-1.)
332
+
333
+ return image, rep_copy.detach()
334
+
335
+
336
+ def vec2face_vit_base_patch16(**kwargs):
337
+ model = MaskedGenerativeEncoderViT(
338
+ patch_size=16, embed_dim=768, depth=12, num_heads=12,
339
+ decoder_embed_dim=768, decoder_depth=8, decoder_num_heads=16,
340
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
341
+ return model
342
+
343
+
344
+ def vec2face_vit_large_patch16(**kwargs):
345
+ model = MaskedGenerativeEncoderViT(
346
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16,
347
+ decoder_embed_dim=1024, decoder_depth=8, decoder_num_heads=16,
348
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
349
+ return model
350
+
351
+
352
+ def vec2face_vit_huge_patch16(**kwargs):
353
+ model = MaskedGenerativeEncoderViT(
354
+ patch_size=16, embed_dim=1280, depth=32, num_heads=16,
355
+ decoder_embed_dim=1280, decoder_depth=8, decoder_num_heads=16,
356
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
357
+ return model
pixel_generator/vec2face/taming/models/__pycache__/vqgan.cpython-37.pyc ADDED
Binary file (2.45 kB)
 
pixel_generator/vec2face/taming/models/__pycache__/vqgan.cpython-38.pyc ADDED
Binary file (2.48 kB)
 
pixel_generator/vec2face/taming/models/vqgan.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch
2
+ import pytorch_lightning as pl
3
+
4
+ from pixel_generator.mage.taming.modules.diffusionmodules.model import Encoder, Decoder
5
+ from pixel_generator.mage.taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
6
+
7
+
8
+ class VQModel(pl.LightningModule):
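+ # VQGAN autoencoder from taming-transformers (encoder, vector quantizer, decoder); its ddconfig settings are what Vec2Face reuses for the image decoder.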
9
+ def __init__(self,
10
+ ddconfig,
11
+ n_embed,
12
+ embed_dim,
13
+ ckpt_path=None,
14
+ ignore_keys=[],
15
+ image_key="image",
16
+ colorize_nlabels=None,
17
+ monitor=None,
18
+ remap=None,
19
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
20
+ ):
21
+ super().__init__()
22
+ self.image_key = image_key
23
+ self.encoder = Encoder(**ddconfig)
24
+ self.decoder = Decoder(**ddconfig)
25
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
26
+ remap=remap, sane_index_shape=sane_index_shape)
27
+ if ckpt_path is not None:
28
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
29
+ self.image_key = image_key
30
+ if colorize_nlabels is not None:
31
+ assert type(colorize_nlabels)==int
32
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
33
+ if monitor is not None:
34
+ self.monitor = monitor
35
+
36
+ def init_from_ckpt(self, path, ignore_keys=list()):
37
+ sd = torch.load(path, map_location="cpu")
38
+ if "state_dict" in sd.keys():
39
+ sd = sd["state_dict"]
40
+ keys = list(sd.keys())
41
+ for k in keys:
42
+ for ik in ignore_keys:
43
+ if k.startswith(ik):
44
+ print("Deleting key {} from state_dict.".format(k))
45
+ del sd[k]
46
+ print("Strict load")
47
+ self.load_state_dict(sd, strict=True)
48
+ print(f"Restored from {path}")
49
+
50
+ def encode(self, x):
51
+ h = self.encoder(x)
52
+ quant, emb_loss, info = self.quantize(h)
53
+ return quant, emb_loss, info
54
+
55
+ def decode(self, quant):
56
+ dec = self.decoder(quant)
57
+ return dec
58
+
59
+ def decode_code(self, code_b):
60
+ quant_b = self.quantize.embed_code(code_b)
61
+ dec = self.decode(quant_b)
62
+ return dec
63
+
64
+ def forward(self, input):
65
+ quant, diff, _ = self.encode(input)
66
+ dec = self.decode(quant)
67
+ return dec, diff
pixel_generator/vec2face/taming/modules/__pycache__/discriminator_loss.cpython-38.pyc ADDED
Binary file (4.47 kB)
 
pixel_generator/vec2face/taming/modules/__pycache__/discriminator_loss.cpython-39.pyc ADDED
Binary file (4.46 kB)
 
pixel_generator/vec2face/taming/modules/discriminator/__pycache__/model.cpython-38.pyc ADDED
Binary file (4.1 kB)
 
pixel_generator/vec2face/taming/modules/discriminator/model.py ADDED
@@ -0,0 +1,113 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ ####################################ViT-VQGAN########################################
5
+ # https://github.com/lucidrains/parti-pytorch/blob/main/parti_pytorch/vit_vqgan.py#L171
6
+ #####################################################################################
7
+ def default(val, d):
8
+ return val if exists(val) else d
9
+
10
+ def exists(val):
11
+ return val is not None
12
+
13
+ def leaky_relu(p = 0.1):
14
+ return nn.LeakyReLU(0.1)
15
+
16
+ class CrossEmbedLayer(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim_in,
20
+ kernel_sizes,
21
+ dim_out = None,
22
+ stride = 2
23
+ ):
24
+ super().__init__()
25
+ assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
26
+ dim_out = default(dim_out, dim_in)
27
+
28
+ kernel_sizes = sorted(kernel_sizes)
29
+ num_scales = len(kernel_sizes)
30
+
31
+ # calculate the dimension at each scale
32
+ dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
33
+ dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
34
+
35
+ self.convs = nn.ModuleList([])
36
+ for kernel, dim_scale in zip(kernel_sizes, dim_scales):
37
+ self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
38
+
39
+ def forward(self, x):
40
+ fmaps = tuple(map(lambda conv: conv(x), self.convs))
41
+ return torch.cat(fmaps, dim = 1)
42
+
43
+ class Block(nn.Module):
44
+ def __init__(
45
+ self,
46
+ dim,
47
+ dim_out,
48
+ groups = 8
49
+ ):
50
+ super().__init__()
51
+ self.groupnorm = nn.GroupNorm(groups, dim)
52
+ self.activation = leaky_relu()
53
+ self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
54
+
55
+ def forward(self, x, scale_shift = None):
56
+ x = self.groupnorm(x)
57
+ x = self.activation(x)
58
+ return self.project(x)
59
+
60
+ class ResnetBlock(nn.Module):
61
+ def __init__(
62
+ self,
63
+ dim,
64
+ dim_out = None,
65
+ *,
66
+ groups = 8
67
+ ):
68
+ super().__init__()
69
+ dim_out = default(dim_out, dim)
70
+ self.block = Block(dim, dim_out, groups = groups)
71
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
72
+
73
+ def forward(self, x):
74
+ h = self.block(x)
75
+ return h + self.res_conv(x)
76
+
77
+
78
+ class Discriminator(nn.Module):
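+ # Patch-style discriminator: a cross-embed stem followed by strided-conv + ResnetBlock stages, ending in a small grid of real/fake logits (PatchGAN-esque).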
79
+ def __init__(
80
+ self,
81
+ dims,
82
+ channels = 3,
83
+ groups = 8,
84
+ init_kernel_size = 5,
85
+ cross_embed_kernel_sizes = (3, 7, 15)
86
+ ):
87
+ super().__init__()
88
+ init_dim, *_, final_dim = dims
89
+ dim_pairs = zip(dims[:-1], dims[1:])
90
+
91
+ self.layers = nn.ModuleList([nn.Sequential(
92
+ CrossEmbedLayer(channels, cross_embed_kernel_sizes, init_dim, stride = 1),
93
+ leaky_relu()
94
+ )])
95
+
96
+ for dim_in, dim_out in dim_pairs:
97
+ self.layers.append(nn.Sequential(
98
+ nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
99
+ leaky_relu(),
100
+ nn.GroupNorm(groups, dim_out),
101
+ ResnetBlock(dim_out, dim_out),
102
+ ))
103
+
104
+ self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
105
+ nn.Conv2d(final_dim, final_dim, 1),
106
+ leaky_relu(),
107
+ nn.Conv2d(final_dim, 1, 4)
108
+ )
109
+
110
+ def forward(self, x):
111
+ for net in self.layers:
112
+ x = net(x)
113
+ return self.to_logits(x)
pixel_generator/vec2face/taming/modules/discriminator_loss.py ADDED
@@ -0,0 +1,128 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from models import iresnet
5
+ from lpips.lpips import LPIPS
6
+ from pytorch_msssim import SSIM
7
+
8
+
9
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
10
+ if global_step < threshold:
11
+ weight = value
12
+ return weight
13
+
14
+
15
+ def hinge_d_loss(logits_real, logits_fake):
16
+ loss_real = torch.mean(F.relu(1. - logits_real))
17
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
18
+ d_loss = 0.5 * (loss_real + loss_fake)
19
+ return d_loss
20
+
21
+
22
+ def mse_d_loss(logits_real, logits_fake):
23
+ loss_real = torch.mean((logits_real - 1.) ** 2)
24
+ loss_fake = torch.mean(logits_fake ** 2)
25
+ d_loss = 0.5 * (loss_real + loss_fake)
26
+ return d_loss
27
+
28
+
29
+ def vanilla_d_loss(logits_real, logits_fake):
30
+ d_loss = 0.5 * (
31
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
32
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
33
+ return d_loss
34
+
35
+
36
+ def create_fr_model(model_path, depth="100"):
37
+ model = iresnet(depth)
38
+ model.load_state_dict(torch.load(model_path))
39
+ # model.half()
40
+ return model
41
+
42
+
43
+ def downscale(img: torch.tensor):
44
+ half_size = img.shape[-1] // 8
45
+ img = F.interpolate(img, size=(half_size, half_size), mode='bicubic', align_corners=False)
46
+ return img
47
+
48
+
49
+ class VQLPIPSWithDiscriminator(nn.Module):
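+ # Training loss combining pixel reconstruction, LPIPS perceptual, SSIM, identity-feature, and token losses with an adaptively weighted GAN term, in the style of the taming-transformers VQGAN loss.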
50
+ def __init__(self, disc_start=1000, disc_factor=1.0, disc_weight=1.0,
51
+ disc_conditional=False, disc_loss="mse", id_loss="mse",
52
+ fr_model="./models/arcface-r100-glint360k.pth"):
53
+ super().__init__()
54
+ assert disc_loss in ["hinge", "vanilla", "mse", "smooth"]
55
+ self.loss_name = disc_loss
56
+ self.perceptual_loss = LPIPS().eval()
57
+ self.discriminator_iter_start = disc_start
58
+ if disc_loss == "hinge":
59
+ self.disc_loss = hinge_d_loss
60
+ elif disc_loss == "vanilla":
61
+ self.disc_loss = vanilla_d_loss
62
+ elif disc_loss == "mse":
63
+ self.disc_loss = mse_d_loss
64
+ else:
65
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
66
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
67
+ self.fr_model = create_fr_model(fr_model).eval()
68
+ if id_loss == "mse":
69
+ self.feature_loss = nn.MSELoss()
70
+ elif id_loss == "cosine":
71
+ self.feature_loss = nn.CosineSimilarity()
72
+ self.disc_factor = disc_factor
73
+ self.discriminator_weight = disc_weight
74
+ self.disc_conditional = disc_conditional
75
+ self.ssim_loss = SSIM(data_range=1, size_average=True, channel=3)
76
+
77
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
78
+ if last_layer is not None:
79
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
80
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
81
+ else:
82
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
83
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
84
+
85
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
86
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
87
+ d_weight = d_weight * self.discriminator_weight
88
+ return d_weight
89
+
90
+ def forward(self, im_features, gt_indices, logits, gt_img, image, discriminator, emb_loss,
91
+ epoch, last_layer=None, cond=None, mask=None):
92
+ rec_loss = (image - gt_img) ** 2
93
+
94
+ if epoch >= 0:
95
+ gen_feature = self.fr_model(image)
96
+ feature_loss = torch.mean(1 - torch.cosine_similarity(im_features, gen_feature))
97
+ else:
98
+ feature_loss = 0
99
+
100
+ p_loss = self.perceptual_loss(image, gt_img) * 2
101
+
102
+ with torch.cuda.amp.autocast(enabled=False):
103
+ ssim_loss = 1 - self.ssim_loss((image.float() + 1) / 2, (gt_img + 1) / 2)
104
+ logits_fake = discriminator(image)
105
+ logits_real_d = discriminator(gt_img.detach())
106
+ logits_fake_d = discriminator(image.detach())
107
+
108
+ if mask is None:
109
+ token_loss = (logits[:, 1:, :] - gt_indices[:, 1:, :])
110
+ token_loss = torch.mean(token_loss ** 2)
111
+ else:
112
+ token_loss = torch.abs((logits[:, 1:, :] - gt_indices[:, 1:, :])) * mask[:, 1:, None]
113
+ token_loss = token_loss.sum() / mask[:, 1:].sum()
114
+ # token_loss = 0
115
+ nll_loss = torch.mean(rec_loss + p_loss) + \
116
+ ssim_loss + \
117
+ token_loss + feature_loss + emb_loss
118
+ # generator update
119
+ g_loss = -torch.mean(logits_fake)
120
+
121
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
122
+ disc_factor = adopt_weight(self.disc_factor, epoch, threshold=self.discriminator_iter_start)
123
+ ae_loss = nll_loss + d_weight * disc_factor * g_loss
124
+
125
+ # second pass for discriminator update
126
+ disc_factor = adopt_weight(self.disc_factor, epoch, threshold=self.discriminator_iter_start)
127
+ d_loss = disc_factor * self.disc_loss(logits_real_d, logits_fake_d)
128
+ return ae_loss, d_loss, token_loss, rec_loss, ssim_loss, p_loss, feature_loss
pixel_generator/vec2face/taming/modules/util.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def count_params(model):
6
+ total_params = sum(p.numel() for p in model.parameters())
7
+ return total_params
8
+
9
+
10
+ class ActNorm(nn.Module):
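+ # Activation normalization: per-channel scale and shift whose parameters are initialized from the first batch's statistics during training; optionally returns the log-determinant.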
11
+ def __init__(self, num_features, logdet=False, affine=True,
12
+ allow_reverse_init=False):
13
+ assert affine
14
+ super().__init__()
15
+ self.logdet = logdet
16
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18
+ self.allow_reverse_init = allow_reverse_init
19
+
20
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21
+
22
+ def initialize(self, input):
23
+ with torch.no_grad():
24
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25
+ mean = (
26
+ flatten.mean(1)
27
+ .unsqueeze(1)
28
+ .unsqueeze(2)
29
+ .unsqueeze(3)
30
+ .permute(1, 0, 2, 3)
31
+ )
32
+ std = (
33
+ flatten.std(1)
34
+ .unsqueeze(1)
35
+ .unsqueeze(2)
36
+ .unsqueeze(3)
37
+ .permute(1, 0, 2, 3)
38
+ )
39
+
40
+ self.loc.data.copy_(-mean)
41
+ self.scale.data.copy_(1 / (std + 1e-6))
42
+
43
+ def forward(self, input, reverse=False):
44
+ if reverse:
45
+ return self.reverse(input)
46
+ if len(input.shape) == 2:
47
+ input = input[:,:,None,None]
48
+ squeeze = True
49
+ else:
50
+ squeeze = False
51
+
52
+ _, _, height, width = input.shape
53
+
54
+ if self.training and self.initialized.item() == 0:
55
+ self.initialize(input)
56
+ self.initialized.fill_(1)
57
+
58
+ h = self.scale * (input + self.loc)
59
+
60
+ if squeeze:
61
+ h = h.squeeze(-1).squeeze(-1)
62
+
63
+ if self.logdet:
64
+ log_abs = torch.log(torch.abs(self.scale))
65
+ logdet = height*width*torch.sum(log_abs)
66
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
67
+ return h, logdet
68
+
69
+ return h
70
+
71
+ def reverse(self, output):
72
+ if self.training and self.initialized.item() == 0:
73
+ if not self.allow_reverse_init:
74
+ raise RuntimeError(
75
+ "Initializing ActNorm in reverse direction is "
76
+ "disabled by default. Use allow_reverse_init=True to enable."
77
+ )
78
+ else:
79
+ self.initialize(output)
80
+ self.initialized.fill_(1)
81
+
82
+ if len(output.shape) == 2:
83
+ output = output[:,:,None,None]
84
+ squeeze = True
85
+ else:
86
+ squeeze = False
87
+
88
+ h = output / self.scale - self.loc
89
+
90
+ if squeeze:
91
+ h = h.squeeze(-1).squeeze(-1)
92
+ return h
93
+
94
+
95
+ class AbstractEncoder(nn.Module):
96
+ def __init__(self):
97
+ super().__init__()
98
+
99
+ def encode(self, *args, **kwargs):
100
+ raise NotImplementedError
101
+
102
+
103
+ class Labelator(AbstractEncoder):
104
+ """Net2Net Interface for Class-Conditional Model"""
105
+ def __init__(self, n_classes, quantize_interface=True):
106
+ super().__init__()
107
+ self.n_classes = n_classes
108
+ self.quantize_interface = quantize_interface
109
+
110
+ def encode(self, c):
111
+ c = c[:,None]
112
+ if self.quantize_interface:
113
+ return c, None, [None, None, c.long()]
114
+ return c
115
+
116
+
117
+ class SOSProvider(AbstractEncoder):
118
+ # for unconditional training
119
+ def __init__(self, sos_token, quantize_interface=True):
120
+ super().__init__()
121
+ self.sos_token = sos_token
122
+ self.quantize_interface = quantize_interface
123
+
124
+ def encode(self, x):
125
+ # get batch size from data and replicate sos_token
126
+ c = torch.ones(x.shape[0], 1)*self.sos_token
127
+ c = c.long().to(x.device)
128
+ if self.quantize_interface:
129
+ return c, None, [None, None, c]
130
+ return c
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch==1.12.0
2
+ numpy==1.24.3
3
+ torchvision==0.13.0
4
+ imageio==2.9.0
5
+ omegaconf==2.1.1
6
+ scipy==1.10.1
7
+ sixdrepnet==0.1.6
8
+ timm==0.9.16
9
+ gradio==4.42.0
10
+ huggingface-hub==0.24.6