Update src/facerender/modules/make_animation.py
src/facerender/modules/make_animation.py
CHANGED
@@ -192,6 +192,55 @@ def make_animation(source_image, source_semantics, target_semantics,
 
     # return predictions_ts
 
+import torch
+from torch.cuda.amp import autocast
+from tqdm import tqdm
+
+def make_animation(source_image, source_semantics, target_semantics,
+                   generator, kp_detector, he_estimator, mapping,
+                   yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
+                   use_exp=True, batch_size=8):
+    device = 'cuda'
+    source_image = source_image.to(device)
+    source_semantics = source_semantics.to(device)
+    target_semantics = target_semantics.to(device)
+
+    with torch.no_grad():
+        predictions = []
+        kp_canonical = kp_detector(source_image)
+        he_source = mapping(source_semantics)
+        kp_source = keypoint_transformation(kp_canonical, he_source)
+
+        num_frames = target_semantics.shape[1]
+        for start_idx in tqdm(range(0, num_frames, batch_size), desc='Face Renderer:', unit='batch'):
+            end_idx = min(start_idx + batch_size, num_frames)
+            target_semantics_batch = target_semantics[:, start_idx:end_idx]
+            he_driving = mapping(target_semantics_batch)
+
+            if yaw_c_seq is not None:
+                he_driving['yaw_in'] = yaw_c_seq[:, start_idx:end_idx]
+            if pitch_c_seq is not None:
+                he_driving['pitch_in'] = pitch_c_seq[:, start_idx:end_idx]
+            if roll_c_seq is not None:
+                he_driving['roll_in'] = roll_c_seq[:, start_idx:end_idx]
+
+            kp_driving = keypoint_transformation(kp_canonical, he_driving)
+            kp_norm = kp_driving
+
+            with autocast():
+                out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
+
+            predictions.append(out['prediction'])
+
+        # Optional: Explicitly synchronize (use only if necessary)
+        torch.cuda.synchronize()
+
+        # Stack predictions into a single tensor
+        predictions_ts = torch.stack(predictions, dim=1)
+
+    return predictions_ts
+
+
 
 
 class AnimateModel(torch.nn.Module):
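For reference, the snippet below is a minimal, self-contained sketch of the chunked-inference pattern the new make_animation uses: frames are processed in batches under torch.no_grad(), with torch.cuda.amp.autocast enabling mixed precision on GPU. The model, tensor shapes, and variable names here are placeholders for illustration only, not modules from this repository.

# Minimal sketch of batched inference with autocast.
# Assumption: a stand-in nn.Linear model and dummy tensors, not the
# repository's generator / keypoint modules.
import torch
from torch.cuda.amp import autocast
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(64, 64).to(device)         # placeholder for the face renderer
frames = torch.randn(1, 200, 64, device=device)    # (B, T, C): 200 driving "frames"
batch_size = 8

outputs = []
with torch.no_grad():
    num_frames = frames.shape[1]
    for start_idx in tqdm(range(0, num_frames, batch_size), desc='Renderer:', unit='batch'):
        end_idx = min(start_idx + batch_size, num_frames)
        chunk = frames[:, start_idx:end_idx]        # slice a batch of frames
        with autocast(enabled=(device == 'cuda')):  # mixed precision only on GPU
            out = model(chunk)
        outputs.append(out.float())                 # cast back so all chunks share a dtype

result = torch.cat(outputs, dim=1)                  # (B, T, C): reassembled along the frame axis

The batch_size knob trades GPU memory for fewer forward passes; the sketch concatenates the variable-length chunks back along the frame axis after the loop.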