|
'''
Not an exact reimplementation of the official repository, but it produces
comparable results.
'''
|
import os
import sys

sys.path.append(os.getcwd())

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta

from .wav2vec import Wav2Vec2Model
from nets.layers import SeqTranslator1D, ConvNormRelu
|
|
|
|
|
""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """ |
|
|
|
|
|
def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000): |
|
""" |
|
:param audio: 1 x T tensor containing a 16kHz audio signal |
|
:param frame_rate: frame rate for video (we need one audio chunk per video frame) |
|
:param chunk_size: number of audio samples per chunk |
|
:return: num_chunks x chunk_size tensor containing sliced audio |
|
""" |
|
    # One chunk is extracted per video frame; pad so that chunks centered on
    # the first and last frames stay inside the signal.
    samples_per_frame = 16000 // frame_rate
    padding = (chunk_size - samples_per_frame) // 2
    audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
    anchor_points = list(range(chunk_size // 2, audio.shape[-1] - chunk_size // 2, samples_per_frame))
    audio = torch.cat([audio[:, i - chunk_size // 2:i + chunk_size // 2] for i in anchor_points], dim=0)
|
return audio |
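
# Usage sketch (illustrative, not from the original repo): 2 s of 16 kHz mono
# audio at 30 fps yields roughly 60 overlapping one-second chunks.
#   wav = torch.randn(1, 32000)      # 1 x T waveform
#   chunks = audio_chunking(wav)     # -> ~60 x 16000, one chunk per video frame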
|
|
|
|
|
class MeshtalkEncoder(nn.Module): |
|
def __init__(self, latent_dim: int = 128, model_name: str = 'audio_encoder'): |
|
""" |
|
:param latent_dim: size of the latent audio embedding |
|
:param model_name: name of the model, used to load and save the model |
|
""" |
|
super().__init__() |
|
|
|
        # 80-band mel-spectrogram front end at 16 kHz with a 10 ms hop
        # (the log is applied in forward()).
        self.melspec = ta.transforms.MelSpectrogram(
|
sample_rate=16000, n_fft=2048, win_length=800, hop_length=160, n_mels=80 |
|
) |
|
|
|
        conv_len = 5
        # Project 80 mel bins to 128 channels; track the receptive field as
        # dilated layers are added below.
        self.convert_dimensions = torch.nn.Conv1d(80, 128, kernel_size=conv_len)
        self.weights_init(self.convert_dimensions)
        self.receptive_field = conv_len
|
|
|
        # Six residual dilated convolutions (dilations cycle through 2, 4, 6);
        # each layer widens the receptive field by (conv_len - 1) * dilation.
        convs = []
        for i in range(6):
            dilation = 2 * (i % 3 + 1)
            self.receptive_field += (conv_len - 1) * dilation
            convs += [torch.nn.Conv1d(128, 128, kernel_size=conv_len, dilation=dilation)]
            self.weights_init(convs[-1])
        self.convs = torch.nn.ModuleList(convs)
        self.code = torch.nn.Linear(128, latent_dim)
|
|
|
self.apply(lambda x: self.weights_init(x)) |
|
|
|
    def weights_init(self, m):
        if isinstance(m, torch.nn.Conv1d):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)
|
|
|
def forward(self, audio: torch.Tensor): |
|
""" |
|
:param audio: B x T x 16000 Tensor containing 1 sec of audio centered around the current time frame |
|
:return: code: B x T x latent_dim Tensor containing a latent audio code/embedding |
|
""" |
|
        B, T = audio.shape[0], audio.shape[1]
        # Log mel-spectrogram per chunk; the clamp guards against log(0).
        x = self.melspec(audio).squeeze(1)
        x = torch.log(x.clamp(min=1e-10))
        if T == 1:
            x = x.unsqueeze(1)

        # Fold batch and time together so every chunk is convolved independently.
        x = x.view(-1, x.shape[2], x.shape[3])
        x = F.leaky_relu(self.convert_dimensions(x), 0.2)
|
|
|
|
|
        # Residual dilated convolutions: crop x symmetrically so the skip
        # connection matches the shrinking valid-convolution output.
        for conv in self.convs:
            x_ = F.leaky_relu(conv(x), 0.2)
            if self.training:
                x_ = F.dropout(x_, 0.2)
            l = (x.shape[2] - x_.shape[2]) // 2
            x = (x[:, :, l:-l] + x_) / 2
|
|
|
        # Average over time within each chunk, then project to the latent code.
        x = torch.mean(x, dim=-1)
        x = x.view(B, T, x.shape[-1])
        x = self.code(x)
|
|
|
return {"code": x} |
|
|
|
|
|
class AudioEncoder(nn.Module): |
|
def __init__(self, in_dim, out_dim, identity=False, num_classes=0): |
|
super().__init__() |
|
self.identity = identity |
|
if self.identity: |
|
in_dim = in_dim + 64 |
|
self.id_mlp = nn.Conv1d(num_classes, 64, 1, 1) |
|
self.first_net = SeqTranslator1D(in_dim, out_dim, |
|
min_layers_num=3, |
|
residual=True, |
|
norm='ln' |
|
) |
|
        # Note: this GRU is instantiated but never used in forward(); it is
        # presumably kept for checkpoint compatibility.
        self.grus = nn.GRU(out_dim, out_dim, 1, batch_first=True)
        self.dropout = nn.Dropout(0.1)
|
|
|
|
|
def forward(self, spectrogram, pre_state=None, id=None, time_steps=None): |
|
|
|
        spectrogram = self.dropout(spectrogram)
        if self.identity:
            # Broadcast the one-hot speaker id over time, embed it, and fuse it
            # with the spectrogram features along the channel dimension.
            id = id.reshape(id.shape[0], -1, 1).repeat(1, 1, spectrogram.shape[2]).to(torch.float32)
            id = self.id_mlp(id)
            spectrogram = torch.cat([spectrogram, id], dim=1)
|
        x1 = self.first_net(spectrogram)
        if time_steps is not None:
            # Linearly resample the feature sequence to the target frame count.
            x1 = F.interpolate(x1, size=time_steps, align_corners=False, mode='linear')
|
|
|
|
|
|
|
        # This variant produces no recurrent state (see the note on self.grus).
        hidden_state = None
|
|
|
return x1, hidden_state |
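
# Usage sketch (shapes are illustrative assumptions): maps B x C x T features
# (plus an optional one-hot speaker id) to B x out_dim x T.
#   enc = AudioEncoder(in_dim=256, out_dim=256)
#   y, _ = enc(torch.randn(2, 256, 60))   # -> 2 x 256 x 60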
|
|
|
|
|
class Generator(nn.Module): |
|
def __init__(self, |
|
n_poses, |
|
each_dim: list, |
|
dim_list: list, |
|
training=False, |
|
device=None, |
|
identity=True, |
|
num_classes=0, |
|
): |
|
super().__init__() |
|
|
|
        # Note: this overwrites nn.Module's built-in `training` flag, so
        # subsequent .train()/.eval() calls will toggle it.
        self.training = training
|
self.device = device |
|
self.gen_length = n_poses |
|
self.identity = identity |
|
|
|
norm = 'ln' |
|
in_dim = 256 |
|
out_dim = 256 |
|
|
|
        # Select the audio encoder backbone: 'meshtalk', 'faceformer'
        # (wav2vec2), or a 1D-convolutional fallback.
        self.encoder_choice = 'faceformer'
|
|
|
if self.encoder_choice == 'meshtalk': |
|
self.audio_encoder = MeshtalkEncoder(latent_dim=in_dim) |
|
        elif self.encoder_choice == 'faceformer':
            self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
            # Freeze the CNN feature extractor; only the transformer layers
            # remain trainable.
            self.audio_encoder.feature_extractor._freeze_parameters()
            self.audio_feature_map = nn.Linear(768, in_dim)
|
else: |
|
self.audio_encoder = AudioEncoder(in_dim=64, out_dim=out_dim) |
|
|
|
        # A second AudioEncoder refines the features and injects the speaker identity.
        self.audio_middle = AudioEncoder(in_dim, out_dim, identity, num_classes)
|
|
|
self.dim_list = dim_list |
|
|
|
        # Two parallel decoding heads, producing each_dim[0] and each_dim[3]
        # output channels respectively.
        self.decoder = nn.ModuleList()
        self.final_out = nn.ModuleList()
|
|
|
self.decoder.append(nn.Sequential( |
|
ConvNormRelu(out_dim, 64, norm=norm), |
|
ConvNormRelu(64, 64, norm=norm), |
|
ConvNormRelu(64, 64, norm=norm), |
|
)) |
|
self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1)) |
|
|
|
self.decoder.append(nn.Sequential( |
|
ConvNormRelu(out_dim, out_dim, norm=norm), |
|
ConvNormRelu(out_dim, out_dim, norm=norm), |
|
ConvNormRelu(out_dim, out_dim, norm=norm), |
|
)) |
|
self.final_out.append(nn.Conv1d(out_dim, each_dim[3], 1, 1)) |
|
|
|
    def forward(self, in_spec, gt_poses=None, id=None, pre_state=None, time_steps=None):
        # During training the output length follows the ground-truth poses.
        if self.training:
            time_steps = gt_poses.shape[1]
|
|
|
|
|
        if self.encoder_choice == 'meshtalk':
            in_spec = audio_chunking(in_spec.squeeze(-1), frame_rate=30, chunk_size=16000)
            feature = self.audio_encoder(in_spec.unsqueeze(0))["code"].transpose(1, 2)
        elif self.encoder_choice == 'faceformer':
            hidden_states = self.audio_encoder(in_spec.reshape(in_spec.shape[0], -1),
                                               frame_num=time_steps).last_hidden_state
            feature = self.audio_feature_map(hidden_states).transpose(1, 2)
        else:
            feature, hidden_state = self.audio_encoder(in_spec, pre_state, time_steps=time_steps)
|
|
|
|
|
|
|
feature, _ = self.audio_middle(feature, id=id) |
|
|
|
out = [] |
|
|
|
        for i in range(len(self.decoder)):
|
mid = self.decoder[i](feature) |
|
mid = self.final_out[i](mid) |
|
out.append(mid) |
|
|
|
        # Concatenate the heads along the channel axis, then return B x T x C.
        out = torch.cat(out, dim=1)
        out = out.transpose(1, 2)
|
|
|
return out, None |
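

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original repo). All shapes
    # and hyper-parameters below are illustrative assumptions, and the
    # pretrained "facebook/wav2vec2-base-960h" weights must be downloadable.
    B, T, num_classes = 2, 60, 4
    model = Generator(
        n_poses=T,
        each_dim=[15, 0, 0, 30],  # only indices 0 and 3 are used by the heads
        dim_list=[],
        identity=True,
        num_classes=num_classes,
    )
    wav = torch.randn(B, 32000)  # two seconds of 16 kHz audio per sample
    speaker = F.one_hot(torch.arange(B) % num_classes, num_classes=num_classes)
    out, _ = model(wav, id=speaker, time_steps=T)
    print(out.shape)  # expected: B x T x (15 + 30)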
|
|
|
|
|
|