Spanicin committed on
Commit a4906b2 · verified · 1 Parent(s): 40934d8

Delete src/facerender/modules/make_animation.py

src/facerender/modules/make_animation.py DELETED
@@ -1,221 +0,0 @@
-from scipy.spatial import ConvexHull
-import torch
-import torch.nn.functional as F
-import numpy as np
-from tqdm import tqdm
-from torch.cuda.amp import autocast  # used by make_animation below
-
-def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
-                 use_relative_movement=False, use_relative_jacobian=False):
-    if adapt_movement_scale:
-        source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
-        driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
-        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
-    else:
-        adapt_movement_scale = 1
-
-    kp_new = {k: v for k, v in kp_driving.items()}
-
-    if use_relative_movement:
-        kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
-        kp_value_diff *= adapt_movement_scale
-        kp_new['value'] = kp_value_diff + kp_source['value']
-
-        if use_relative_jacobian:
-            jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
-            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
-
-    return kp_new
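
# Editor's illustration (not part of the deleted file): with
# use_relative_movement, the driving motion is applied as an offset from the
# first driving frame, so a driving sequence that has not yet moved leaves
# the source keypoints unchanged. A minimal sketch with arbitrary shapes:
import torch
kp_src = {'value': torch.randn(1, 15, 3)}
kp_drv = {'value': kp_src['value'] + 0.1}
res = normalize_kp(kp_src, kp_drv, kp_driving_initial=kp_drv, use_relative_movement=True)
assert torch.allclose(res['value'], kp_src['value'])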
-
-def headpose_pred_to_degree(pred):
-    # decode 66-bin classification logits into a continuous angle in degrees
-    device = pred.device
-    idx_tensor = torch.arange(66, dtype=torch.float32, device=device)
-    pred = F.softmax(pred, dim=1)
-    degree = torch.sum(pred * idx_tensor, 1) * 3 - 99
-    return degree
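
# Editor's illustration (not part of the deleted file): the pose heads emit
# 66-way logits over 3-degree bins spanning roughly -99 to +99 degrees, and
# headpose_pred_to_degree decodes them via the softmax expectation. With all
# probability mass on bin 33, the decoded angle is 33 * 3 - 99 = 0 degrees:
import torch
logits = torch.full((1, 66), -1e9)
logits[0, 33] = 0.0
assert headpose_pred_to_degree(logits).item() == 0.0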
-
-def get_rotation_matrix(yaw, pitch, roll):
-    # degrees -> radians (note: 3.14 is used as an approximation of pi)
-    yaw = yaw / 180 * 3.14
-    pitch = pitch / 180 * 3.14
-    roll = roll / 180 * 3.14
-
-    roll = roll.unsqueeze(1)
-    pitch = pitch.unsqueeze(1)
-    yaw = yaw.unsqueeze(1)
-
-    # rotation about the x-axis
-    pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),
-                           torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),
-                           torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)
-    pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
-
-    # rotation about the y-axis
-    yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw),
-                         torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
-                         -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)
-    yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
-
-    # rotation about the z-axis
-    roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),
-                          torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),
-                          torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)
-    roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
-
-    # compose: R = R_pitch @ R_yaw @ R_roll
-    rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)
-
-    return rot_mat
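
# Editor's illustration (not part of the deleted file): the einsum composes
# R = R_pitch @ R_yaw @ R_roll, and any product of rotations is orthonormal
# (R @ R^T = I), which gives a cheap sanity check on the construction:
import torch
R = get_rotation_matrix(torch.tensor([30.0]), torch.tensor([10.0]), torch.tensor([-5.0]))
assert torch.allclose(torch.bmm(R, R.transpose(1, 2)), torch.eye(3).unsqueeze(0), atol=1e-5)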
-
-def keypoint_transformation(kp_canonical, he, wo_exp=False):
-    kp = kp_canonical['value']    # (bs, k, 3)
-    yaw, pitch, roll = he['yaw'], he['pitch'], he['roll']
-    yaw = headpose_pred_to_degree(yaw)
-    pitch = headpose_pred_to_degree(pitch)
-    roll = headpose_pred_to_degree(roll)
-
-    # externally supplied pose sequences override the predicted pose
-    if 'yaw_in' in he:
-        yaw = he['yaw_in']
-    if 'pitch_in' in he:
-        pitch = he['pitch_in']
-    if 'roll_in' in he:
-        roll = he['roll_in']
-
-    rot_mat = get_rotation_matrix(yaw, pitch, roll)    # (bs, 3, 3)
-
-    t, exp = he['t'], he['exp']
-    if wo_exp:
-        exp = exp * 0
-
-    # keypoint rotation
-    kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
-
-    # keypoint translation (the x and z components are zeroed out)
-    t[:, 0] = t[:, 0] * 0
-    t[:, 2] = t[:, 2] * 0
-    t = t.unsqueeze(1).repeat(1, kp.shape[1], 1)
-    kp_t = kp_rotated + t
-
-    # add expression deviation
-    exp = exp.view(exp.shape[0], -1, 3)
-    kp_transformed = kp_t + exp
-
-    return {'value': kp_transformed}
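
# Editor's illustration (not part of the deleted file): a shape walkthrough
# with dummy inputs; k = 15 keypoints and 66-bin pose logits are illustrative
# values chosen to match the code above:
import torch
bs, k = 1, 15
kp_canonical = {'value': torch.randn(bs, k, 3)}
he = {'yaw': torch.randn(bs, 66), 'pitch': torch.randn(bs, 66), 'roll': torch.randn(bs, 66),
      't': torch.randn(bs, 3), 'exp': torch.randn(bs, k * 3)}
assert keypoint_transformation(kp_canonical, he)['value'].shape == (bs, k, 3)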
-
-
-# def make_animation(source_image, source_semantics, target_semantics,
-#                    generator, kp_detector, he_estimator, mapping,
-#                    yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
-#                    use_exp=True):
-#     with torch.no_grad():
-#         predictions = []
-
-#         kp_canonical = kp_detector(source_image)
-#         he_source = mapping(source_semantics)
-#         kp_source = keypoint_transformation(kp_canonical, he_source)
-
-#         for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
-#             target_semantics_frame = target_semantics[:, frame_idx]
-#             he_driving = mapping(target_semantics_frame)
-#             if yaw_c_seq is not None:
-#                 he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
-#             if pitch_c_seq is not None:
-#                 he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
-#             if roll_c_seq is not None:
-#                 he_driving['roll_in'] = roll_c_seq[:, frame_idx]
-
-#             kp_driving = keypoint_transformation(kp_canonical, he_driving)
-
-#             # kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
-#             #                        kp_driving_initial=kp_driving_initial)
-#             kp_norm = kp_driving
-#             out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
-#             '''
-#             source_image_new = out['prediction'].squeeze(1)
-#             kp_canonical_new = kp_detector(source_image_new)
-#             he_source_new = he_estimator(source_image_new)
-#             kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True)
-#             kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True)
-#             out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new)
-#             '''
-#             predictions.append(out['prediction'])
-#             torch.cuda.empty_cache()
-#         predictions_ts = torch.stack(predictions, dim=1)
-#         return predictions_ts
-
-def make_animation(source_image, source_semantics, target_semantics,
-                   generator, kp_detector, he_estimator, mapping,
-                   yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
-                   use_exp=True):
-    # NOTE: use_exp is accepted for API compatibility but unused in this path
-
-    # device = 'cuda'
-    # # Move inputs to GPU
-    # source_image = source_image.to(device)
-    # source_semantics = source_semantics.to(device)
-    # target_semantics = target_semantics.to(device)
-
-    with torch.no_grad():    # inference only, no gradients needed
-        predictions = []
-        kp_canonical = kp_detector(source_image)
-        he_source = mapping(source_semantics)
-        kp_source = keypoint_transformation(kp_canonical, he_source)
-
-        for frame_idx in tqdm(range(target_semantics.shape[1]), desc='Face Renderer:', unit='frame'):
-            target_semantics_frame = target_semantics[:, frame_idx]
-            he_driving = mapping(target_semantics_frame)
-
-            if yaw_c_seq is not None:
-                he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
-            if pitch_c_seq is not None:
-                he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
-            if roll_c_seq is not None:
-                he_driving['roll_in'] = roll_c_seq[:, frame_idx]
-
-            kp_driving = keypoint_transformation(kp_canonical, he_driving)
-            kp_norm = kp_driving
-
-            # Use mixed precision for faster computation
-            with autocast():
-                out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
-
-            predictions.append(out['prediction'])
-
-            # Optional: explicitly synchronize (only when CUDA is in use)
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
-
-        # Stack per-frame predictions into a single (bs, frames, ...) tensor
-        predictions_ts = torch.stack(predictions, dim=1)
-
-    return predictions_ts
-
-
-class AnimateModel(torch.nn.Module):
-    """
-    Merge all generator-related modules into a single model for better multi-GPU usage.
-    """
-
-    def __init__(self, generator, kp_extractor, mapping):
-        super(AnimateModel, self).__init__()
-        self.kp_extractor = kp_extractor
-        self.generator = generator
-        self.mapping = mapping
-
-        self.kp_extractor.eval()
-        self.generator.eval()
-        self.mapping.eval()
-
-    def forward(self, x):
-        source_image = x['source_image']
-        source_semantics = x['source_semantics']
-        target_semantics = x['target_semantics']
-        yaw_c_seq = x['yaw_c_seq']
-        pitch_c_seq = x['pitch_c_seq']
-        roll_c_seq = x['roll_c_seq']
-
-        # he_estimator is not used by make_animation, so None is passed in
-        # that slot and mapping is bound explicitly
-        predictions_video = make_animation(source_image, source_semantics, target_semantics,
-                                           self.generator, self.kp_extractor,
-                                           None, self.mapping, use_exp=True,
-                                           yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq,
-                                           roll_c_seq=roll_c_seq)
-
-        return predictions_video
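
For the record, a minimal smoke test of the deleted module's calling convention. The stub networks and tensor shapes below are the editor's assumptions, chosen only to satisfy the shapes the code expects (15 keypoints, 66-bin pose logits); they are not part of this commit, and the snippet presumes the file's definitions are importable:

import torch

# Stubs standing in for SadTalker's real generator / keypoint detector /
# mapping networks, shaped only to exercise the code path above.
class _Gen(torch.nn.Module):
    def forward(self, src, kp_source, kp_driving):
        return {'prediction': src}    # echo the source frame

class _KP(torch.nn.Module):
    def forward(self, img):
        return {'value': torch.zeros(img.shape[0], 15, 3)}

class _Map(torch.nn.Module):
    def forward(self, sem):
        bs = sem.shape[0]
        return {'yaw': torch.zeros(bs, 66), 'pitch': torch.zeros(bs, 66),
                'roll': torch.zeros(bs, 66), 't': torch.zeros(bs, 3),
                'exp': torch.zeros(bs, 45)}

model = AnimateModel(_Gen(), _KP(), _Map())
batch = {'source_image': torch.zeros(1, 3, 64, 64),
         'source_semantics': torch.zeros(1, 70),
         'target_semantics': torch.zeros(1, 8, 70),
         'yaw_c_seq': None, 'pitch_c_seq': None, 'roll_c_seq': None}
assert model(batch).shape == (1, 8, 3, 64, 64)    # (bs, frames, C, H, W)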