import torch
import random

from utils.commons.base_task import BaseTask
from utils.commons.dataset_utils import data_loader
from utils.commons.hparams import hparams
from utils.commons.tensor_utils import tensors_to_scalars
from utils.nn.schedulers import CosineSchedule, NoneSchedule
from utils.nn.model_utils import print_arch, num_params
from utils.commons.ckpt_utils import load_ckpt
from modules.syncnet.models import LandmarkHubertSyncNet
from tasks.os_avatar.dataset_utils.syncnet_dataset import SyncNet_Dataset
from data_util.face3d_helper import Face3DHelper


class ScheduleForSyncNet(NoneSchedule):
    """Exponential step-decayed learning rate, floored at 5e-5."""
    def __init__(self, optimizer, lr):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        self.lr = constant_lr
        lr = self.lr * hparams['lr_decay_rate'] ** (num_updates // hparams['lr_decay_interval'])
        # lr = max(lr, 5e-6)
        lr = max(lr, 5e-5)
        self.optimizer.param_groups[0]['lr'] = lr
        return self.lr


class SyncNetTask(BaseTask):
    def __init__(self, hparams_=None):
        global hparams
        if hparams_ is not None:
            hparams = hparams_
        self.hparams = hparams
        super().__init__()
        self.dataset_cls = SyncNet_Dataset

    def on_train_start(self):
        for n, m in self.model.named_children():
            num_params(m, model_name=n)

    def build_model(self):
        if self.hparams is not None:
            hparams = self.hparams
        self.face3d_helper = Face3DHelper(use_gpu=False, keypoint_mode='lm68')
        if hparams.get('syncnet_keypoint_mode', 'lip') == 'lip':
            lm_dim = 20 * 3  # lip part of idexp_lm3d
        elif hparams['syncnet_keypoint_mode'] == 'lm68':
            lm_dim = 68 * 3  # full 68-point idexp_lm3d
        elif hparams['syncnet_keypoint_mode'] == 'centered_lip':
            lm_dim = 20 * 3  # lip part of idexp_lm3d, centered at the mean mouth position
        elif hparams['syncnet_keypoint_mode'] == 'centered_lip2d':
            lm_dim = 20 * 2  # 2D lip part of idexp_lm3d, centered at the mean mouth position
        elif hparams['syncnet_keypoint_mode'] == 'lm468':
            lm_dim = 468 * 3  # full 468-point MediaPipe idexp_lm3d
            self.face3d_helper = Face3DHelper(use_gpu=False, keypoint_mode='mediapipe')

        if hparams['audio_type'] == 'hubert':
            audio_dim = 1024  # hubert
        elif hparams['audio_type'] == 'mfcc':
            audio_dim = 13  # mfcc
        elif hparams['audio_type'] == 'mel':
            audio_dim = 80  # mel

        self.model = LandmarkHubertSyncNet(
            lm_dim, audio_dim,
            num_layers_per_block=hparams['syncnet_num_layers_per_block'],
            base_hid_size=hparams['syncnet_base_hid_size'],
            out_dim=hparams['syncnet_out_hid_size'])
        print_arch(self.model)
        if hparams.get('init_from_ckpt', '') != '':
            ckpt_dir = hparams.get('init_from_ckpt', '')
            load_ckpt(self.model, ckpt_dir, model_name='model', strict=False)
        return self.model

    def build_optimizer(self, model):
        if self.hparams is not None:
            hparams = self.hparams
        self.optimizer = optimizer = torch.optim.Adam(
            model.parameters(),
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']))
        return optimizer

    # def build_scheduler(self, optimizer):
    #     return CosineSchedule(optimizer, hparams['lr'], warmup_updates=0, total_updates=40_0000)

    def build_scheduler(self, optimizer):
        return ScheduleForSyncNet(optimizer, hparams['lr'])

    @data_loader
    def train_dataloader(self):
        train_dataset = self.dataset_cls(prefix='train')
        self.train_dl = train_dataset.get_dataloader()
        return self.train_dl

    @data_loader
    def val_dataloader(self):
        val_dataset = self.dataset_cls(prefix='val')
        self.val_dl = val_dataset.get_dataloader()
        return self.val_dl

    @data_loader
    def test_dataloader(self):
        val_dataset = self.dataset_cls(prefix='val')
        self.val_dl = val_dataset.get_dataloader()
        return self.val_dl
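    # Note on run_model() below: it builds SyncNet clip pairs on the fly.
    # Positive pairs slice the mouth landmarks and audio features at the
    # same timestep; negative pairs either keep the same speaker but shift
    # the audio by a small (2-5 frame) or large (5-10 frame) offset, or
    # shuffle the batch so that audio and mouth come from different
    # speakers. phase_ratio_dict controls the mixture of these phases.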
    ##########################
    # training and validation
    ##########################
    def run_model(self, sample, infer=False, batch_size=1024):
        """
        Run SyncNet on a batch: sample positive/negative clip pairs and compute the sync loss.
        :param sample: a batch of data
        :param infer: bool, run in infer mode
        :return: (losses, model_out) if not infer, else model_out
        """
        if self.hparams is not None:
            hparams = self.hparams
        if sample is None or len(sample) == 0:
            return None
        model_out = {}
        if 'idexp_lm3d' not in sample:
            with torch.no_grad():
                b, t, _ = sample['exp'].shape
                idexp_lm3d = self.face3d_helper.reconstruct_idexp_lm3d(sample['id'], sample['exp']).reshape([b, t, -1, 3])
        else:
            b, t, *_ = sample['idexp_lm3d'].shape
            idexp_lm3d = sample['idexp_lm3d']

        if hparams.get('syncnet_keypoint_mode', 'lip') == 'lip':
            mouth_lm = idexp_lm3d[:, :, 48:68, :].reshape([b, t, 20 * 3])  # [b, t, 60]
        elif hparams.get('syncnet_keypoint_mode', 'lip') == 'centered_lip':
            mouth_lm = idexp_lm3d[:, :, 48:68].reshape([b, t, 20, 3])  # [b, t, 20, 3]
            mean_mouth_lm = self.face3d_helper.key_mean_shape[48:68]
            # center the lip landmarks at the mean mouth position
            mouth_lm = mouth_lm / 10 + mean_mouth_lm.reshape([1, 1, 20, 3]) - mean_mouth_lm.reshape([1, 1, 20, 3]).mean(dim=-2)
            mouth_lm = mouth_lm.reshape([b, t, 20 * 3]) * 10
        elif hparams.get('syncnet_keypoint_mode', 'lip') == 'centered_lip2d':
            mouth_lm = idexp_lm3d[:, :, 48:68].reshape([b, t, 20, 3])  # [b, t, 20, 3]
            mean_mouth_lm = self.face3d_helper.key_mean_shape[48:68]
            # center the lip landmarks at the mean mouth position, then drop the z axis
            mouth_lm = mouth_lm / 10 + mean_mouth_lm.reshape([1, 1, 20, 3]) - mean_mouth_lm.reshape([1, 1, 20, 3]).mean(dim=-2)
            mouth_lm = mouth_lm[..., :2]
            mouth_lm = mouth_lm.reshape([b, t, 20 * 2]) * 10
        elif hparams['syncnet_keypoint_mode'] == 'lm68':
            mouth_lm = idexp_lm3d.reshape([b, t, 68 * 3])
        elif hparams['syncnet_keypoint_mode'] == 'lm468':
            mouth_lm = idexp_lm3d.reshape([b, t, 468 * 3])

        if hparams['audio_type'] == 'hubert':
            mel = sample['hubert']  # [b, 2t, 1024]
        elif hparams['audio_type'] == 'mfcc':
            mel = sample['mfcc'] / 100  # [b, 2t, 13]
        elif hparams['audio_type'] == 'mel':
            mel = sample['mfcc']  # [b, 2t, 80]

        y_mask = sample['y_mask']  # [b, t]
        y_len = y_mask.sum(dim=1).min().item()
        len_mouth_slice = 5  # 5 video frames span 0.2s, an appropriate length for a sync check
        len_mel_slice = len_mouth_slice * 2  # audio features run at twice the video frame rate

        if infer:
            phase_ratio_dict = {
                'pos': 1.0,
            }
        else:
            phase_ratio_dict = {
                'pos': 0.4,
                'neg_same_people_small_offset_ratio': 0.3,
                'neg_same_people_large_offset_ratio': 0.2,
                'neg_diff_people_random_offset_ratio': 0.1,
            }
        mouth_lst, mel_lst, label_lst = [], [], []
        for phase_key, phase_ratio in phase_ratio_dict.items():
            num_samples = int(batch_size * phase_ratio)
            if phase_key == 'pos':
                phase_mel_lst = []
                phase_mouth_lst = []
                num_iters = max(1, num_samples // len(mouth_lm))
                for i in range(num_iters):
                    # positive pairs: mouth and audio sliced at the same timestep
                    t_start = random.randint(0, y_len - len_mouth_slice - 1)
                    phase_mouth = mouth_lm[:, t_start: t_start + len_mouth_slice]
                    assert phase_mouth.shape[1] == len_mouth_slice
                    phase_mel = mel[:, t_start * 2: t_start * 2 + len_mel_slice]
                    phase_mouth_lst.append(phase_mouth)
                    phase_mel_lst.append(phase_mel)
                phase_mouth = torch.cat(phase_mouth_lst)
                phase_mel = torch.cat(phase_mel_lst)
                mouth_lst.append(phase_mouth)
                mel_lst.append(phase_mel)
                label_lst.append(torch.ones([len(phase_mel)]))  # 1 denotes positive samples
            elif phase_key in ['neg_same_people_small_offset_ratio', 'neg_same_people_large_offset_ratio']:
                phase_mel_lst = []
                phase_mouth_lst = []
                num_iters = max(1, num_samples // len(mouth_lm))
                for i in range(num_iters):
                    # negative pairs: same speaker, audio shifted by a temporal offset
                    if phase_key == 'neg_same_people_small_offset_ratio':
                        offset = random.choice([random.randint(-5, -2), random.randint(2, 5)])
                    elif phase_key == 'neg_same_people_large_offset_ratio':
                        offset = random.choice([random.randint(-10, -5), random.randint(5, 10)])
                    else:
                        raise ValueError()
                    if offset < 0:
                        t_start = random.randint(-offset, y_len - len_mouth_slice - 1)
                    else:
                        t_start = random.randint(0, y_len - len_mouth_slice - 1 - offset)
                    phase_mouth = mouth_lm[:, t_start: t_start + len_mouth_slice]
                    assert phase_mouth.shape[1] == len_mouth_slice
                    phase_mel = mel[:, (t_start + offset) * 2: (t_start + offset) * 2 + len_mel_slice]
                    phase_mouth_lst.append(phase_mouth)
                    phase_mel_lst.append(phase_mel)
                phase_mouth = torch.cat(phase_mouth_lst)
                phase_mel = torch.cat(phase_mel_lst)
                mouth_lst.append(phase_mouth)
                mel_lst.append(phase_mel)
                label_lst.append(torch.zeros([len(phase_mel)]))  # 0 denotes negative samples
            elif phase_key == 'neg_diff_people_random_offset_ratio':
                phase_mel_lst = []
                phase_mouth_lst = []
                num_iters = max(1, num_samples // len(mouth_lm))
                for i in range(num_iters):
                    offset = random.randint(-10, 10)
                    if offset < 0:
                        t_start = random.randint(-offset, y_len - len_mouth_slice - 1)
                    else:
                        t_start = random.randint(0, y_len - len_mouth_slice - 1 - offset)
                    phase_mouth = mouth_lm[:, t_start: t_start + len_mouth_slice]
                    assert phase_mouth.shape[1] == len_mouth_slice
                    # shuffle the batch indices so audio and mouth come from different samples
                    sample_idx = list(range(len(mouth_lm)))
                    random.shuffle(sample_idx)
                    phase_mel = mel[sample_idx, (t_start + offset) * 2: (t_start + offset) * 2 + len_mel_slice]
                    phase_mouth_lst.append(phase_mouth)
                    phase_mel_lst.append(phase_mel)
                phase_mouth = torch.cat(phase_mouth_lst)
                phase_mel = torch.cat(phase_mel_lst)
                mouth_lst.append(phase_mouth)
                mel_lst.append(phase_mel)
                label_lst.append(torch.zeros([len(phase_mel)]))  # 0 denotes negative samples

        mel_clips = torch.cat(mel_lst)
        mouth_clips = torch.cat(mouth_lst)
        labels = torch.cat(label_lst).float().to(mel_clips.device)
        audio_embedding, mouth_embedding = self.model(mel_clips, mouth_clips)
        sync_loss, cosine_sim = self.model.cal_sync_loss(audio_embedding, mouth_embedding, labels, reduction='mean')
        if not infer:
            losses_out = {}
            model_out = {}
            losses_out['sync_loss'] = sync_loss
            losses_out['batch_size'] = len(mel_clips)
            model_out['cosine_sim'] = cosine_sim
            return losses_out, model_out
        else:
            model_out['sync_loss'] = sync_loss
            model_out['batch_size'] = len(mel_clips)
            return model_out

    def _training_step(self, sample, batch_idx, optimizer_idx):
        ret = self.run_model(sample, infer=False, batch_size=hparams['syncnet_num_clip_pairs'])
        if ret is None:
            return None
        loss_output, model_out = ret
        loss_weights = {}
        total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items()
                          if isinstance(v, torch.Tensor) and v.requires_grad])
        return total_loss, loss_output

    def validation_start(self):
        pass

    @torch.no_grad()
    def validation_step(self, sample, batch_idx):
        outputs = {}
        outputs['losses'], model_out = self.run_model(sample, infer=False, batch_size=8000)
        outputs = tensors_to_scalars(outputs)
        return outputs

    def validation_end(self, outputs):
        return super().validation_end(outputs)

    #####################
    # Testing
    #####################
    def test_start(self):
        pass

    @torch.no_grad()
    def test_step(self, sample, batch_idx):
        pass

    def test_end(self, outputs):
        pass
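
# ----------------------------------------------------------------------
# Minimal smoke-test sketch. This block is not part of the original
# training entry point, and the 'lr_decay_rate' / 'lr_decay_interval'
# values below are illustrative assumptions, not the repo's real config.
# It exercises only ScheduleForSyncNet, which needs no dataset or
# checkpoint: the LR decays by 'lr_decay_rate' every 'lr_decay_interval'
# updates and is floored at 5e-5.
if __name__ == '__main__':
    hparams.update({'lr_decay_rate': 0.5, 'lr_decay_interval': 10000})  # assumed values
    dummy_param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.Adam([dummy_param], lr=1e-4)
    sched = ScheduleForSyncNet(opt, lr=1e-4)
    for num_updates in (0, 10000, 20000, 40000):
        sched.step(num_updates)
        # expected LRs: 1e-4, 5e-5, 5e-5 (floored), 5e-5 (floored)
        print(num_updates, opt.param_groups[0]['lr'])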