import os
import pdb
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from base import BaseModel
from models.lib.wav2vec import Wav2Vec2Model
from models.utils import init_biased_mask, enc_dec_mask, PeriodicPositionalEncoding

#os.environ['PYOPENGL_PLATFORM'] = 'osmesa'


class CodeTalker(BaseModel):
    """
    Forward inputs:
        audio: (batch_size, raw_wav)
        template: (batch_size, V*3)
        vertice: (batch_size, seq_len, V*3)
    """
    def __init__(self, args):
        super(CodeTalker, self).__init__()
        self.args = args
        self.dataset = args.dataset

        # wav2vec 2.0 audio encoder, initialized from pretrained weights;
        # the convolutional feature extractor is kept frozen
        self.audio_encoder = Wav2Vec2Model.from_pretrained(args.wav2vec2model_path)
        self.audio_encoder.feature_extractor._freeze_parameters()
        # map audio hidden states to the decoder feature dimension
        # (1024-dim for the multilingual/large wav2vec 2.0 encoder; the base model outputs 768)
        #self.audio_feature_map = nn.Linear(768, args.feature_dim)
        self.audio_feature_map = nn.Linear(1024, args.feature_dim)

        # motion encoder
        self.vertice_map = nn.Linear(args.vertice_dim, args.feature_dim)
        # periodic positional encoding
        self.PPE = PeriodicPositionalEncoding(args.feature_dim, period=args.period)
        # temporal bias (biased causal attention mask)
        self.biased_mask = init_biased_mask(n_head=4, max_seq_len=600, period=args.period)

        # cross-modal transformer decoder: motion queries attend to audio features
        decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=args.n_head,
                                                   dim_feedforward=2*args.feature_dim, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=args.num_layers)

        # motion decoder: maps decoder outputs into the VQ feature space (face_quan_num x zquant_dim)
        self.feat_map = nn.Linear(args.feature_dim, args.face_quan_num*args.zquant_dim, bias=False)
        # style embedding, one entry per training subject
        self.learnable_style_emb = nn.Embedding(len(args.train_subjects.split()), args.feature_dim)
        self.device = args.device
        nn.init.constant_(self.feat_map.weight, 0)
        # nn.init.constant_(self.feat_map.bias, 0)

        # pretrained stage-1 VQ-VAE motion prior, loaded once and kept frozen
        from models.stage1_vocaset import VQAutoEncoder
        self.autoencoder = VQAutoEncoder(args)
        state_dict = torch.load(args.vqvae_pretrained_path)['state_dict']
        self.autoencoder.load_state_dict(state_dict)
        for param in self.autoencoder.parameters():
            param.requires_grad = False

    def forward(self, audio_name, audio, template, vertice, one_hot, criterion):
        # tgt_mask: :math:`(T, T)`.
        # memory_mask: :math:`(T, S)`.
        template = template.unsqueeze(1)  # (batch_size, 1, V*3)
        # style embedding looked up from the one-hot subject label
        obj_embedding = self.learnable_style_emb(torch.argmax(one_hot, dim=1))
        obj_embedding = obj_embedding.unsqueeze(1)
        frame_num = vertice.shape[1]
        # audio feature extraction
        hidden_states = self.audio_encoder(audio, self.dataset, frame_num=frame_num).last_hidden_state
        if self.dataset == "BIWI" or self.dataset == "multi":
            if hidden_states.shape[1]