import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange

from utils.commons.base_task import BaseTask
# NOTE: BaseConcatDataset and WeightedConcatDataset are used in
# train_dataloader() below but were not imported; they are assumed to live
# in utils.commons.dataset_utils (adjust the path if they are defined
# elsewhere in this repo).
from utils.commons.dataset_utils import data_loader, BaseConcatDataset, WeightedConcatDataset
from utils.commons.hparams import hparams
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np
from utils.nn.model_utils import print_arch, get_device_of_model, not_requires_grad, num_params
from utils.nn.schedulers import ExponentialSchedule
from utils.nn.grad import get_grad_norm
from utils.commons.face_alignment_utils import mouth_idx_in_mediapipe_mesh

from modules.audio2motion.vae import VAEModel, PitchContourVAEModel
from tasks.os_avatar.dataset_utils.audio2motion_dataset import Audio2Motion_Dataset
from data_util.face3d_helper import Face3DHelper
from data_gen.utils.mp_feature_extractors.face_landmarker import index_lm68_from_lm478
from modules.syncnet.models import LandmarkHubertSyncNet


class Audio2MotionTask(BaseTask):
    def __init__(self):
        super().__init__()
        self.dataset_cls = Audio2Motion_Dataset
        # in_out_dim is the dimension of the predicted motion code:
        # 80-dim identity + 64-dim expression, or expression only.
        if hparams["motion_type"] == 'id_exp':
            self.in_out_dim = 80 + 64
        elif hparams["motion_type"] == 'exp':
            self.in_out_dim = 64
        else:
            raise ValueError(f"Unknown motion_type: {hparams['motion_type']}")

    def build_model(self):
        if hparams['audio_type'] == 'hubert':
            audio_in_dim = 1024
        elif hparams['audio_type'] == 'mfcc':
            audio_in_dim = 13
        else:
            raise ValueError(f"Unknown audio_type: {hparams['audio_type']}")

        if hparams.get("use_pitch", False) is True:
            self.model = PitchContourVAEModel(hparams, in_out_dim=self.in_out_dim, audio_in_dim=audio_in_dim,
                                              use_prior_flow=hparams.get("use_flow", True))
        else:
            self.model = VAEModel(in_out_dim=self.in_out_dim, audio_in_dim=audio_in_dim,
                                  use_prior_flow=hparams.get("use_flow", True))

        if hparams.get('init_from_ckpt', '') != '':
            ckpt_dir = hparams.get('init_from_ckpt', '')
            load_ckpt(self.model, ckpt_dir, model_name='model', strict=True)

        self.face3d_helper = Face3DHelper(keypoint_mode='mediapipe', use_gpu=False)
        lm_dim = 468 * 3  # 468 MediaPipe landmarks, 3 coordinates each

        # Lip-sync network used as a sync critic in run_model(). Note:
        # these three hparams are hard-coded here and override any values
        # from the config.
        hparams['syncnet_num_layers_per_block'] = 3
        hparams['syncnet_base_hid_size'] = 128
        hparams['syncnet_out_hid_size'] = 1024
        self.syncnet = LandmarkHubertSyncNet(lm_dim, audio_in_dim,
                                             num_layers_per_block=hparams['syncnet_num_layers_per_block'],
                                             base_hid_size=hparams['syncnet_base_hid_size'],
                                             out_dim=hparams['syncnet_out_hid_size'])
        if hparams.get('syncnet_ckpt_dir', ''):
            load_ckpt(self.syncnet, hparams['syncnet_ckpt_dir'])
        return self.model

    def on_train_start(self):
        # Print parameter counts for the model and its VAE sub-modules.
        for n, m in self.model.named_children():
            num_params(m, model_name=n)
        for n, m in self.model.vae.named_children():
            num_params(m, model_name='vae.' + n)

    def build_optimizer(self, model):
        self.optimizer = optimizer = torch.optim.Adam(
            model.parameters(),
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']))
        return optimizer

    def build_scheduler(self, optimizer):
        return ExponentialSchedule(optimizer, hparams['lr'], hparams['warmup_updates'])

    @data_loader
    def train_dataloader(self):
        if hparams['ds_name'] == 'Concat_voxceleb2_CMLR':
            train_dataset1 = self.dataset_cls(prefix='train', data_dir='data/binary/voxceleb2_audio2motion_kv')
            train_dataset2 = self.dataset_cls(prefix='train', data_dir='data/binary/CMLR_audio2motion_kv')
            train_dataset = BaseConcatDataset([train_dataset1, train_dataset2], prefix='train')
        elif hparams['ds_name'] == 'Weighted_Concat_voxceleb2_CMLR':
            train_dataset1 = self.dataset_cls(prefix='train', data_dir='data/binary/voxceleb2_audio2motion_kv')
            train_dataset2 = self.dataset_cls(prefix='train', data_dir='data/binary/CMLR_audio2motion_kv')
            # Sample the two corpora with equal probability.
            train_dataset = WeightedConcatDataset([train_dataset1, train_dataset2], [0.5, 0.5], prefix='train')
        else:
            train_dataset = self.dataset_cls(prefix='train')
        self.train_dl = train_dataset.get_dataloader()
        return self.train_dl

    @data_loader
    def val_dataloader(self):
        val_dataset = self.dataset_cls(prefix='val')
        self.val_dl = val_dataset.get_dataloader()
        return self.val_dl

    @data_loader
    def test_dataloader(self):
        # The test loader reuses the validation split.
        val_dataset = self.dataset_cls(prefix='val')
        self.val_dl = val_dataset.get_dataloader()
        return self.val_dl

    def run_model(self, sample, infer=False, temperature=1.0, sync_batch_size=1024):
        """
        Run the audio-to-motion model on a batch.
        :param sample: a batch of data
        :param infer: bool, run in inference mode
        :return:
            if not infer:
                return losses, model_out
            if infer:
                return model_out
        """
        # sync_batch_size is currently unused in this method.
        model_out = {}
        # Select the audio feature stream according to the configured type.
        if hparams['audio_type'] == 'hubert':
            sample['audio'] = sample['hubert']
        elif hparams['audio_type'] == 'mfcc':
            sample['audio'] = sample['mfcc'] / 100
        elif hparams['audio_type'] == 'mel':
            sample['audio'] = sample['mel']

        if hparams.get("blink_mode", 'none') != 'none':
            # Upsample the blink signal 2x along time so it matches the
            # audio feature rate (audio features run at twice the motion
            # frame rate).
            blink = F.interpolate(sample[hparams['blink_mode']].permute(0, 2, 1).float(), scale_factor=2).permute(0, 2, 1).long()
            sample['blink'] = blink
        bs = sample['audio'].shape[0]

        if infer:
            self.model(sample, model_out, train=False, temperature=temperature)
            return model_out

        else:
            losses_out = {}
            if hparams["motion_type"] == 'id_exp':
                # Jointly predict identity (80-dim) and expression (64-dim).
                x_gt = torch.cat([sample['id'], sample['exp']], dim=-1)
                sample['y'] = x_gt
                self.model(sample, model_out, train=True)
                x_pred = model_out['pred'].reshape([bs, -1, 80 + 64])
                x_mask = model_out['mask'].reshape([bs, -1])
                losses_out['kl'] = model_out['loss_kl']
                id_pred = x_pred[:, :, :80]
                exp_pred = x_pred[:, :, 80:]
                # Temporal smoothness and regularization terms.
                losses_out['lap_id'] = self.lap_loss(id_pred, x_mask)
                losses_out['lap_exp'] = self.lap_loss(exp_pred, x_mask)
                # Supervise in 3D landmark space rather than coefficient space.
                pred_idexp_lm3d = self.face3d_helper.reconstruct_idexp_lm3d(id_pred, exp_pred).reshape([bs, x_mask.shape[1], -1])
                gt_idexp_lm3d = self.face3d_helper.reconstruct_idexp_lm3d(sample['id'], sample['exp']).reshape([bs, x_mask.shape[1], -1])
                losses_out['mse_idexp_lm3d'] = self.lm468_mse_loss(gt_idexp_lm3d, pred_idexp_lm3d, x_mask)
                losses_out['l2_reg_id'] = self.l2_reg_loss(id_pred, x_mask)
                losses_out['l2_reg_exp'] = self.l2_reg_loss(exp_pred, x_mask)

                # Projected 2D landmark loss, using the ground-truth pose.
                gt_lm2d = self.face3d_helper.reconstruct_lm2d(sample['id'], sample['exp'], sample['euler'], sample['trans']).reshape([bs, x_mask.shape[1], -1])
                pred_lm2d = self.face3d_helper.reconstruct_lm2d(id_pred, exp_pred, sample['euler'], sample['trans']).reshape([bs, x_mask.shape[1], -1])
                losses_out['mse_lm2d'] = self.lm468_mse_loss(gt_lm2d, pred_lm2d, x_mask)

            elif hparams["motion_type"] == 'exp':
                # Predict expression only; identity comes from a reference frame.
                x_gt = sample['exp']
                sample['y'] = x_gt
                self.model(sample, model_out, train=True)
                x_pred = model_out['pred'].reshape([bs, -1, 64])
                x_mask = model_out['mask'].reshape([bs, -1])
                losses_out['kl'] = model_out['loss_kl']
                exp_pred = x_pred
                losses_out['lap_exp'] = self.lap_loss(exp_pred, x_mask)
                if hparams.get("ref_id_mode", 'first_frame') == 'first_frame':
                    id_pred = sample['id'][:, 0:1, :].repeat([1, exp_pred.shape[1], 1])
                elif hparams.get("ref_id_mode", 'first_frame') == 'random_frame':
                    # Pick a frame index that is valid for every item in the batch.
                    max_y = int(x_mask.sum(dim=1).min().item())
                    idx = random.randint(0, max_y - 1)
                    id_pred = sample['id'][:, idx:idx + 1, :].repeat([1, exp_pred.shape[1], 1])
                gt_idexp_lm3d = self.face3d_helper.reconstruct_idexp_lm3d(sample['id'], sample['exp']).reshape([bs, x_mask.shape[1], -1])
                pred_idexp_lm3d = self.face3d_helper.reconstruct_idexp_lm3d(id_pred, exp_pred).reshape([bs, x_mask.shape[1], -1])
                losses_out['mse_exp'] = self.mse_loss(x_gt, x_pred, x_mask)
                losses_out['mse_idexp_lm3d'] = self.lm468_mse_loss(gt_idexp_lm3d, pred_idexp_lm3d, x_mask)
                losses_out['l2_reg_exp'] = self.l2_reg_loss(exp_pred, x_mask)

                # Projected 2D landmark loss, mirroring the id_exp branch.
                # (The original computed gt_lm2d/pred_lm2d without using them;
                # lambda_mse_lm2d defaults to 0, so this term is inert unless
                # enabled in the config.)
                gt_lm2d = self.face3d_helper.reconstruct_lm2d(sample['id'], sample['exp'], sample['euler'], sample['trans']).reshape([bs, x_mask.shape[1], -1])
                pred_lm2d = self.face3d_helper.reconstruct_lm2d(id_pred, exp_pred, sample['euler'], sample['trans']).reshape([bs, x_mask.shape[1], -1])
                losses_out['mse_lm2d'] = self.lm468_mse_loss(gt_lm2d, pred_lm2d, x_mask)

            # ---- SyncNet lip-sync loss (applies to both motion types) ----
            # Despite the name, mouth_lm3d holds all 468 predicted landmarks;
            # the syncnet was built with lm_dim = 468 * 3 accordingly.
            mouth_lm3d = pred_idexp_lm3d.reshape([bs, x_pred.shape[1], 468 * 3])

            if hparams['audio_type'] == 'hubert':
                mel = sample['hubert']
            elif hparams['audio_type'] == 'mfcc':
                mel = sample['mfcc'] / 100
            elif hparams['audio_type'] == 'mel':
                mel = sample['mel']

            # Sample random short clips to score with the syncnet. Audio
            # features run at twice the motion frame rate, so a 5-frame
            # motion window pairs with a 10-frame audio window.
            num_clips_for_syncnet = 8096
            len_mouth_slice = 5
            len_mel_slice = len_mouth_slice * 2
            num_iters = max(1, num_clips_for_syncnet // len(mouth_lm3d))
            mouth_clip_lst = []
            mel_clip_lst = []
            x_mask_clip_lst = []
            for i in range(num_iters):
                t_start = random.randint(0, x_pred.shape[1] - len_mouth_slice - 1)
                mouth_clip = mouth_lm3d[:, t_start: t_start + len_mouth_slice]
                x_mask_clip = x_mask[:, t_start: t_start + len_mouth_slice]
                assert mouth_clip.shape[1] == len_mouth_slice
                mel_clip = mel[:, t_start * 2: t_start * 2 + len_mel_slice]
                mouth_clip_lst.append(mouth_clip)
                mel_clip_lst.append(mel_clip)
                x_mask_clip_lst.append(x_mask_clip)
            mouth_clips = torch.cat(mouth_clip_lst)
            mel_clips = torch.cat(mel_clip_lst)
            x_mask_clips = torch.cat(x_mask_clip_lst)
            # A clip only counts if every frame in it is valid.
            x_mask_clips = (x_mask_clips.sum(dim=1) == x_mask_clips.shape[1]).float()
            audio_embedding, mouth_embedding = self.syncnet.forward(mel_clips, mouth_clips)
            sync_loss, _ = self.syncnet.cal_sync_loss(audio_embedding, mouth_embedding, 1., reduction='none')
            losses_out['sync_lip_lm3d'] = (sync_loss * x_mask_clips).sum() / x_mask_clips.sum()
            return losses_out, model_out
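
    # Expected batch layout for run_model, inferred from the usage above
    # (shapes follow this task's conventions; verify against the dataset):
    #   'hubert'/'mfcc'/'mel': [B, 2T, C_audio]  (audio at 2x motion rate)
    #   'id':    [B, T, 80]    'exp':   [B, T, 64]
    #   'euler': [B, T, 3]     'trans': [B, T, 3]
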
    def kl_annealing(self, num_updates, max_lambda=0.4, t1=2000, t2=2000):
        """
        Cyclical Annealing Schedule: A Simple Approach to Mitigating KL Vanishing
        https://aclanthology.org/N19-1021.pdf
        """
        T = t1 + t2
        num_updates = num_updates % T
        if num_updates < t1:
            # Linear ramp from 0 to max_lambda over the first t1 updates.
            return num_updates / t1 * max_lambda
        else:
            # Hold at max_lambda for the remaining t2 updates of the cycle.
            return max_lambda
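
    # Worked example (illustrative): with max_lambda=0.4 and t1=t2=2000 the
    # cycle length is 4000 updates. The KL weight is 0.2 at step 1000, holds
    # at 0.4 from step 2000 through 3999, then the cycle restarts (e.g. 0.1
    # at step 4500).
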
    def _training_step(self, sample, batch_idx, optimizer_idx):
        loss_output, model_out = self.run_model(sample)
        loss_weights = {
            'kl': self.kl_annealing(self.global_step, max_lambda=hparams['lambda_kl'], t1=hparams['lambda_kl_t1'], t2=hparams['lambda_kl_t2']),
            'mse_exp': hparams.get("lambda_mse_exp", 0.1),
            'mse_idexp_lm3d': hparams.get("lambda_mse_lm3d", 1.),
            'lap_id': hparams.get("lambda_lap_id", 1.),
            'lap_exp': hparams.get("lambda_lap_exp", 1.),
            'l2_reg_id': hparams.get("lambda_l2_reg_id", 0.),
            'l2_reg_exp': hparams.get("lambda_l2_reg_exp", 0.),
            'sync_lip_lm3d': hparams.get("lambda_sync_lm3d", 0.2),
            'mse_lm2d': hparams.get("lambda_mse_lm2d", 0.),
        }

        # Losses without an explicit weight default to 1; losses that carry
        # no gradient are excluded from the total.
        total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items()
                          if isinstance(v, torch.Tensor) and v.requires_grad])

        return total_loss, loss_output

    def validation_start(self):
        pass

    @torch.no_grad()
    def validation_step(self, sample, batch_idx):
        outputs = {}
        outputs['losses'], model_out = self.run_model(sample, infer=False, sync_batch_size=10000)
        outputs = tensors_to_scalars(outputs)
        return outputs

    def validation_end(self, outputs):
        return super().validation_end(outputs)

    def test_start(self):
        self.gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
        os.makedirs(self.gen_dir, exist_ok=True)

    @torch.no_grad()
    def test_step(self, sample, batch_idx):
        """Save the predicted (and optionally ground-truth) expressions as .npy."""
        outputs = {}
        # In infer mode run_model returns only model_out (the original
        # tuple-unpacked it, which would fail on a dict).
        model_out = self.run_model(sample, infer=True)
        pred_exp = model_out['pred']
        self.save_result(pred_exp, "pred_exp_val", self.gen_dir)
        if hparams['save_gt']:
            base_fn = "gt_exp_val"
            self.save_result(sample['exp'], base_fn, self.gen_dir)
        return outputs

    def test_end(self, outputs):
        pass

    @staticmethod
    def save_result(exp_arr, base_fname, gen_dir):
        exp_arr = convert_to_np(exp_arr)
        np.save(f"{gen_dir}/{base_fname}.npy", exp_arr)

    def get_grad(self, opt_idx):
        grad_dict = {
            'grad/model': get_grad_norm(self.model),
        }
        return grad_dict

    def lm468_mse_loss(self, proj_lan, gt_lan, x_mask):
        """Weighted MSE over the 468 MediaPipe landmarks (symmetric in its
        first two arguments)."""
        b, t, c = proj_lan.shape

        loss = ((proj_lan - gt_lan) ** 2) * x_mask[:, :, None]
        loss = loss.reshape([b, t, 468, -1])

        # MediaPipe landmark index groups.
        unmatch_mask = [93, 127, 132, 234, 323, 356, 361, 454]  # excluded (weight 0)
        upper_eye = [161, 160, 159, 158, 157] + [388, 387, 386, 385, 384]
        eye = [33, 246, 161, 160, 159, 158, 157, 173, 133, 155, 154, 153, 145, 144, 163, 7] + [263, 466, 388, 387, 386, 385, 384, 398, 362, 382, 381, 380, 374, 373, 390, 249]
        inner_lip = [78, 191, 80, 81, 82, 13, 312, 311, 310, 415, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95]
        outer_lip = [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291, 375, 321, 405, 314, 17, 84, 181, 91, 146]

        # Emphasize perceptually important regions: eyes 3x, upper eyelids
        # 20x, lips 5x.
        weights = torch.ones_like(loss)
        weights[:, :, eye] = 3
        weights[:, :, upper_eye] = 20
        weights[:, :, inner_lip] = 5
        weights[:, :, outer_lip] = 5
        weights[:, :, unmatch_mask] = 0

        loss = loss.reshape([b, t, c])
        weights = weights.reshape([b, t, c])
        return (loss * weights).sum() / (x_mask.sum() * c)

    def lm68_mse_loss(self, proj_lan, gt_lan, x_mask):
        """Weighted MSE over 68 landmarks, emphasizing eyes, inner mouth,
        and nose bridge (5x each)."""
        b, t, c = proj_lan.shape

        loss = ((proj_lan - gt_lan) ** 2) * x_mask[:, :, None]
        loss = loss.reshape([b, t, 68, 3])

        weights = torch.ones_like(loss)
        weights[:, :, 36:48, :] = 5   # eyes
        weights[:, :, -8:, :] = 5     # inner mouth
        weights[:, :, 28:31, :] = 5   # nose bridge
        loss = loss.reshape([b, t, c])
        weights = weights.reshape([b, t, c])
        return (loss * weights).sum() / (x_mask.sum() * c)

    def l2_reg_loss(self, x_pred, x_mask):
        # Note: normalized by self.in_out_dim even when x_pred has fewer
        # channels (e.g. the 80-dim id slice in id_exp mode).
        error = (x_pred ** 2) * x_mask[:, :, None]
        num_frame = x_mask.sum()
        return error.sum() / (num_frame * self.in_out_dim)

    def lap_loss(self, in_tensor, x_mask):
        # Temporal smoothness: convolve each channel with a Laplacian-like
        # kernel along time and penalize the squared response.
        b, t, c = in_tensor.shape
        in_tensor = F.pad(in_tensor, pad=(0, 0, 1, 1))  # zero-pad the time axis
        in_tensor = rearrange(in_tensor, "b t c -> (b c) t").unsqueeze(1)
        lap_kernel = torch.Tensor((-0.5, 1.0, -0.5)).reshape([1, 1, 3]).float().to(in_tensor.device)
        out_tensor = F.conv1d(in_tensor, lap_kernel)
        out_tensor = out_tensor.squeeze(1)
        out_tensor = rearrange(out_tensor, "(b c) t -> b t c", b=b, t=t)
        loss_lap = (out_tensor ** 2) * x_mask.unsqueeze(-1)
        return loss_lap.sum() / (x_mask.sum() * c)
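
    # Note on the kernel above: [-0.5, 1, -0.5] computes
    # x[t] - (x[t-1] + x[t+1]) / 2, which is zero wherever the motion
    # changes at constant velocity, so the loss penalizes acceleration and
    # jitter rather than motion itself (up to the zero-padded boundaries).
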
    def mse_loss(self, x_gt, x_pred, x_mask):
        error = (x_pred - x_gt) * x_mask[:, :, None]
        num_frame = x_mask.sum()
        return (error ** 2).sum() / (num_frame * self.in_out_dim)

    def mae_loss(self, x_gt, x_pred, x_mask):
        error = (x_pred - x_gt) * x_mask[:, :, None]
        num_frame = x_mask.sum()
        return error.abs().sum() / (num_frame * self.in_out_dim)

    def vel_loss(self, x_pred, x_mask):
        # First-order temporal difference penalty.
        error = (x_pred[:, 1:] - x_pred[:, :-1]) * x_mask[:, 1:, None]
        num_frame = x_mask.sum()
        return error.abs().sum() / (num_frame * self.in_out_dim)

    def continuity_loss(self, x_gt, x_pred, x_mask):
        # Match the frame-to-frame deltas of the prediction to those of the
        # ground truth, plus an anchor term on the first frame.
        diff_x_pred = x_pred[:, 1:] - x_pred[:, :-1]
        diff_x_gt = x_gt[:, 1:] - x_gt[:, :-1]
        error = (diff_x_pred - diff_x_gt) * x_mask[:, 1:, None]
        init_error = x_pred[:, 0, :] - x_gt[:, 0, :]
        num_frame = x_mask.sum()
        return (error.pow(2).sum() + init_error.pow(2).sum()) / (num_frame * self.in_out_dim)
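

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch only; actually training
    # this task requires hparams and binarized data). lap_loss reads no
    # instance state, so it can be called unbound on dummy tensors.
    b, t, c = 2, 25, 64
    x_mask = torch.ones([b, t])
    zeros = torch.zeros([b, t, c])
    noisy = torch.randn([b, t, c])
    print('lap_loss(zeros):', Audio2MotionTask.lap_loss(None, zeros, x_mask).item())  # exactly 0
    print('lap_loss(noisy):', Audio2MotionTask.lap_loss(None, noisy, x_mask).item())  # > 0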