Spaces:
Running
on
Zero
Running
on
Zero
Delete src/pixel3dmm/preprocessing/MICA/micalib/models/mica.py
Browse files
src/pixel3dmm/preprocessing/MICA/micalib/models/mica.py
DELETED
@@ -1,120 +0,0 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
|
3 |
-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
-
# holder of all proprietary rights on this computer program.
|
5 |
-
# You can only use this computer program if you have closed
|
6 |
-
# a license agreement with MPG or you get the right to use the computer
|
7 |
-
# program from someone who is authorized to grant you that right.
|
8 |
-
# Any use of the computer program without a valid license is prohibited and
|
9 |
-
# liable to prosecution.
|
10 |
-
#
|
11 |
-
# Copyright©2023 Max-Planck-Gesellschaft zur Förderung
|
12 |
-
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
-
# for Intelligent Systems. All rights reserved.
|
14 |
-
#
|
15 |
-
# Contact: [email protected]
|
16 |
-
|
17 |
-
|
18 |
-
import os
|
19 |
-
import sys
|
20 |
-
|
21 |
-
sys.path.append("./nfclib")
|
22 |
-
|
23 |
-
import torch
|
24 |
-
import torch.nn.functional as F
|
25 |
-
|
26 |
-
from models.arcface import Arcface
|
27 |
-
from models.generator import Generator
|
28 |
-
from micalib.base_model import BaseModel
|
29 |
-
|
30 |
-
from loguru import logger
|
31 |
-
|
32 |
-
|
33 |
-
class MICA(BaseModel):
|
34 |
-
def __init__(self, config=None, device=None, tag='MICA'):
|
35 |
-
super(MICA, self).__init__(config, device, tag)
|
36 |
-
|
37 |
-
self.initialize()
|
38 |
-
|
39 |
-
def create_model(self, model_cfg):
|
40 |
-
mapping_layers = model_cfg.mapping_layers
|
41 |
-
pretrained_path = None
|
42 |
-
if not model_cfg.use_pretrained:
|
43 |
-
pretrained_path = model_cfg.arcface_pretrained_model
|
44 |
-
self.arcface = Arcface(pretrained_path=pretrained_path).to(self.device)
|
45 |
-
self.flameModel = Generator(512, 300, self.cfg.model.n_shape, mapping_layers, model_cfg, self.device)
|
46 |
-
|
47 |
-
def load_model(self):
|
48 |
-
model_path = os.path.join(self.cfg.output_dir, 'model.tar')
|
49 |
-
if os.path.exists(self.cfg.pretrained_model_path) and self.cfg.model.use_pretrained:
|
50 |
-
model_path = self.cfg.pretrained_model_path
|
51 |
-
if os.path.exists(model_path):
|
52 |
-
logger.info(f'[{self.tag}] Trained model found. Path: {model_path} | GPU: {self.device}')
|
53 |
-
checkpoint = torch.load(model_path)
|
54 |
-
if 'arcface' in checkpoint:
|
55 |
-
self.arcface.load_state_dict(checkpoint['arcface'])
|
56 |
-
if 'flameModel' in checkpoint:
|
57 |
-
self.flameModel.load_state_dict(checkpoint['flameModel'])
|
58 |
-
else:
|
59 |
-
logger.info(f'[{self.tag}] Checkpoint not available starting from scratch!')
|
60 |
-
|
61 |
-
def model_dict(self):
|
62 |
-
return {
|
63 |
-
'flameModel': self.flameModel.state_dict(),
|
64 |
-
'arcface': self.arcface.state_dict()
|
65 |
-
}
|
66 |
-
|
67 |
-
def parameters_to_optimize(self):
|
68 |
-
return [
|
69 |
-
{'params': self.flameModel.parameters(), 'lr': self.cfg.train.lr},
|
70 |
-
{'params': self.arcface.parameters(), 'lr': self.cfg.train.arcface_lr},
|
71 |
-
]
|
72 |
-
|
73 |
-
def encode(self, images, arcface_imgs):
|
74 |
-
codedict = {}
|
75 |
-
|
76 |
-
codedict['arcface'] = F.normalize(self.arcface(arcface_imgs))
|
77 |
-
codedict['images'] = images
|
78 |
-
|
79 |
-
return codedict
|
80 |
-
|
81 |
-
def decode(self, codedict, epoch=0):
|
82 |
-
self.epoch = epoch
|
83 |
-
|
84 |
-
flame_verts_shape = None
|
85 |
-
shapecode = None
|
86 |
-
|
87 |
-
if not self.testing:
|
88 |
-
flame = codedict['flame']
|
89 |
-
shapecode = flame['shape_params'].view(-1, flame['shape_params'].shape[2])
|
90 |
-
shapecode = shapecode.to(self.device)[:, :self.cfg.model.n_shape]
|
91 |
-
with torch.no_grad():
|
92 |
-
flame_verts_shape, _, _ = self.flame(shape_params=shapecode)
|
93 |
-
|
94 |
-
identity_code = codedict['arcface']
|
95 |
-
pred_canonical_vertices, pred_shape_code = self.flameModel(identity_code)
|
96 |
-
|
97 |
-
output = {
|
98 |
-
'flame_verts_shape': flame_verts_shape,
|
99 |
-
'flame_shape_code': shapecode,
|
100 |
-
'pred_canonical_shape_vertices': pred_canonical_vertices,
|
101 |
-
'pred_shape_code': pred_shape_code,
|
102 |
-
'faceid': codedict['arcface']
|
103 |
-
}
|
104 |
-
|
105 |
-
return output
|
106 |
-
|
107 |
-
def compute_losses(self, input, encoder_output, decoder_output):
|
108 |
-
losses = {}
|
109 |
-
|
110 |
-
pred_verts = decoder_output['pred_canonical_shape_vertices']
|
111 |
-
gt_verts = decoder_output['flame_verts_shape'].detach()
|
112 |
-
|
113 |
-
pred_verts_shape_canonical_diff = (pred_verts - gt_verts).abs()
|
114 |
-
|
115 |
-
if self.use_mask:
|
116 |
-
pred_verts_shape_canonical_diff *= self.vertices_mask
|
117 |
-
|
118 |
-
losses['pred_verts_shape_canonical_diff'] = torch.mean(pred_verts_shape_canonical_diff) * 1000.0
|
119 |
-
|
120 |
-
return losses
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|