Spanicin committed
Commit cf2e13b · verified · 1 Parent(s): bf35a83

Update src/facerender/modules/make_animation.py

src/facerender/modules/make_animation.py CHANGED
@@ -192,6 +192,55 @@ def make_animation(source_image, source_semantics, target_semantics,
 
     # return predictions_ts
 
+import torch
+from torch.cuda.amp import autocast
+from tqdm import tqdm
+
+def make_animation(source_image, source_semantics, target_semantics,
+                   generator, kp_detector, he_estimator, mapping,
+                   yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
+                   use_exp=True, batch_size=8):
+    device = 'cuda'
+    source_image = source_image.to(device)
+    source_semantics = source_semantics.to(device)
+    target_semantics = target_semantics.to(device)
+
+    with torch.no_grad():
+        predictions = []
+        kp_canonical = kp_detector(source_image)
+        he_source = mapping(source_semantics)
+        kp_source = keypoint_transformation(kp_canonical, he_source)
+
+        num_frames = target_semantics.shape[1]
+        for start_idx in tqdm(range(0, num_frames, batch_size), desc='Face Renderer:', unit='batch'):
+            end_idx = min(start_idx + batch_size, num_frames)
+            target_semantics_batch = target_semantics[:, start_idx:end_idx]
+            he_driving = mapping(target_semantics_batch)
+
+            if yaw_c_seq is not None:
+                he_driving['yaw_in'] = yaw_c_seq[:, start_idx:end_idx]
+            if pitch_c_seq is not None:
+                he_driving['pitch_in'] = pitch_c_seq[:, start_idx:end_idx]
+            if roll_c_seq is not None:
+                he_driving['roll_in'] = roll_c_seq[:, start_idx:end_idx]
+
+            kp_driving = keypoint_transformation(kp_canonical, he_driving)
+            kp_norm = kp_driving
+
+            with autocast():
+                out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
+
+            predictions.append(out['prediction'])
+
+        # Optional: explicitly synchronize (use only if necessary)
+        torch.cuda.synchronize()
+
+        # Stack predictions into a single tensor
+        predictions_ts = torch.stack(predictions, dim=1)
+
+        return predictions_ts
+
+
 
 
 class AnimateModel(torch.nn.Module):
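
For context, here is a minimal, self-contained sketch of the pattern this commit introduces: chunked inference under torch.no_grad() with mixed precision. The helper name render_in_batches, the generic model, and the toy [T, C, H, W] frame layout are illustrative assumptions, not code from this repository.

import torch
from torch.cuda.amp import autocast

def render_in_batches(model, frames, batch_size=8):
    # frames: [T, C, H, W]. Processing fixed-size chunks bounds peak GPU
    # memory by batch_size rather than by the total frame count.
    outputs = []
    with torch.no_grad():
        for start in range(0, frames.shape[0], batch_size):
            chunk = frames[start:start + batch_size]
            # Mixed precision only helps on CUDA; disable it elsewhere.
            with autocast(enabled=torch.cuda.is_available()):
                out = model(chunk)
            # Cast back to fp32 so all chunks share a dtype when joined.
            outputs.append(out.float())
    # cat, not stack: the final chunk may be smaller than batch_size.
    return torch.cat(outputs, dim=0)

One design point worth noting: torch.stack requires every element to share a shape, so stacking per-batch outputs only works when the last chunk matches the others, whereas torch.cat along the frame axis tolerates a smaller final chunk; that is why the sketch concatenates instead.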