Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch import einsum | |
from tortoise.models.arch_util import AttentionBlock | |
from tortoise.models.xtransformers import ContinuousTransformerWrapper, Encoder | |
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
def masked_mean(t, mask):
    """Average ``t`` over dim 1, counting only entries where ``mask`` is True.

    Args:
        t: Tensor to reduce; needs at least two dimensions.
        mask: Boolean tensor marking the entries to keep. Within this file
            it is always built with the same shape as ``t``.

    Returns:
        ``t`` with masked-out entries zeroed, summed over dim 1 and divided
        by the number of kept entries along dim 1. Slots where no entry is
        kept evaluate to 0 instead of NaN.
    """
    t = t.masked_fill(~mask, 0.0)
    # Clamp the count so a fully masked slot divides by 1 (yielding 0) rather
    # than computing 0/0 = NaN, which would propagate into any loss built on
    # top of this value. Counts >= 1 are unaffected.
    kept = mask.sum(dim=1).clamp(min=1)
    return t.sum(dim=1) / kept
class CollapsingTransformer(nn.Module):
    """Transformer encoder whose output sequence is collapsed by a mean over dim 1.

    The input goes through a ``ContinuousTransformerWrapper`` (no absolute
    positional embedding, rotary embeddings enabled in the encoder), then a
    small 1x1-conv / attention / 1x1-conv stack, and is finally averaged over
    dim 1 with ``masked_mean``.

    Args:
        model_dim: Width passed to the encoder as ``dim``.
        output_dims: Channel count produced by ``pre_combiner``.
        heads: Number of attention heads for the encoder and the
            ``AttentionBlock``.
        dropout: Used for both feed-forward and attention dropout.
        depth: Encoder depth.
        mask_percentage: In training mode, entries whose uniform random draw
            is not greater than this value are left out of the mean.
        **encoder_kwargs: Forwarded verbatim to ``Encoder``.
    """

    def __init__(
        self,
        model_dim,
        output_dims,
        heads,
        dropout,
        depth,
        mask_percentage=0,
        **encoder_kwargs
    ):
        super().__init__()

        encoder = Encoder(
            dim=model_dim,
            depth=depth,
            heads=heads,
            ff_dropout=dropout,
            ff_mult=1,
            attn_dropout=dropout,
            use_rmsnorm=True,
            ff_glu=True,
            rotary_pos_emb=True,
            **encoder_kwargs,
        )
        self.transformer = ContinuousTransformerWrapper(
            max_seq_len=-1,
            use_pos_emb=False,
            attn_layers=encoder,
        )

        # Kernel-size-1 convolutions around an attention block; Conv1d works
        # channels-first, hence the permutes in forward().
        project_in = nn.Conv1d(model_dim, output_dims, 1)
        mixer = AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False)
        project_out = nn.Conv1d(output_dims, output_dims, 1)
        self.pre_combiner = nn.Sequential(project_in, mixer, project_out)

        self.mask_percentage = mask_percentage

    def forward(self, x, **transformer_kwargs):
        """Encode ``x`` and reduce it over dim 1.

        Args:
            x: Input handed to the wrapped transformer
                (presumably (batch, seq, model_dim) -- TODO confirm against
                ``ContinuousTransformerWrapper``).
            **transformer_kwargs: Forwarded to the wrapped transformer.

        Returns:
            The ``masked_mean`` over dim 1 of the combined features.
        """
        encoded = self.transformer(x, **transformer_kwargs)
        # Swap the last two axes for the Conv1d stack, then swap them back.
        combined = self.pre_combiner(encoded.permute(0, 2, 1)).permute(0, 2, 1)

        if not self.training:
            # Evaluation: every entry takes part in the mean.
            keep = torch.ones_like(combined.float()).bool()
            return masked_mean(combined, keep)

        # Training: randomly exclude entries according to mask_percentage.
        keep = torch.rand_like(combined.float()) > self.mask_percentage
        return masked_mean(combined, keep)
class ConvFormatEmbedding(nn.Module):
    """``nn.Embedding`` whose output has its last two axes swapped.

    All constructor arguments are forwarded unchanged to ``nn.Embedding``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.emb = nn.Embedding(*args, **kwargs)

    def forward(self, x):
        """Embed ``x`` and return the result permuted with ``(0, 2, 1)``.

        The permute requires the embedded tensor to be 3-D, so ``x`` must be
        a 2-D index tensor.
        """
        return self.emb(x).permute(0, 2, 1)
class CVVP(nn.Module):
    """Contrastive model pairing a conditioning mel clip with a speech input.

    Each input is embedded, run through its own ``CollapsingTransformer``,
    projected by a bias-free linear layer and L2-normalised. Similarities are
    dot products of the normalised latents scaled by ``exp(temperature)``.

    Args:
        model_dim: Transformer width.
        transformer_heads: Attention heads in both transformers.
        dropout: Dropout rate handed to both transformers.
        conditioning_enc_depth: Depth of the conditioning transformer.
        cond_mask_percentage: ``mask_percentage`` of the conditioning
            transformer.
        mel_channels: Input channels of the mel convolutions.
        mel_codes: If ``None`` the speech input is embedded with a ``Conv1d``
            over ``mel_channels``; otherwise with an embedding table of
            ``mel_codes`` entries.
        speech_enc_depth: Depth of the speech transformer.
        speech_mask_percentage: ``mask_percentage`` of the speech transformer.
        latent_multiplier: The latent width is ``latent_multiplier * model_dim``.
    """

    def __init__(
        self,
        model_dim=512,
        transformer_heads=8,
        dropout=0.1,
        conditioning_enc_depth=8,
        cond_mask_percentage=0,
        mel_channels=80,
        mel_codes=None,
        speech_enc_depth=8,
        speech_mask_percentage=0,
        latent_multiplier=1,
    ):
        super().__init__()
        latent_dim = latent_multiplier * model_dim
        # Stored in log space: forward() applies exp() before scaling.
        self.temperature = nn.Parameter(torch.tensor(1.0))

        # Two strided convolutions (stride 2 each) over the conditioning mel.
        self.cond_emb = nn.Sequential(
            nn.Conv1d(mel_channels, model_dim // 2, kernel_size=5, stride=2, padding=2),
            nn.Conv1d(model_dim // 2, model_dim, kernel_size=3, stride=2, padding=1),
        )
        # The transformer must emit latent_dim features because the projection
        # below consumes latent_dim inputs. Passing model_dim here (as the
        # code previously did) only worked when latent_multiplier == 1 and
        # raised a shape mismatch otherwise. This now mirrors the speech branch.
        self.conditioning_transformer = CollapsingTransformer(
            model_dim,
            latent_dim,
            transformer_heads,
            dropout,
            conditioning_enc_depth,
            cond_mask_percentage,
        )
        self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)

        if mel_codes is None:
            self.speech_emb = nn.Conv1d(
                mel_channels, model_dim, kernel_size=5, padding=2
            )
        else:
            self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
        self.speech_transformer = CollapsingTransformer(
            model_dim,
            latent_dim,
            transformer_heads,
            dropout,
            speech_enc_depth,
            speech_mask_percentage,
        )
        self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)

    def get_grad_norm_parameter_groups(self):
        """Return the parameters of each transformer, keyed by branch name."""
        return {
            "conditioning": list(self.conditioning_transformer.parameters()),
            "speech": list(self.speech_transformer.parameters()),
        }

    def forward(self, mel_cond, mel_input, return_loss=False):
        """Score conditioning clips against speech inputs.

        Args:
            mel_cond: Conditioning mel fed to ``Conv1d`` layers, i.e. laid out
                as (batch, mel_channels, time).
            mel_input: Speech input. With ``mel_codes=None`` it goes through a
                ``Conv1d`` and uses the same layout as ``mel_cond``; otherwise
                it is looked up in an embedding table (integer indices,
                presumably (batch, time) -- TODO confirm with callers).
            return_loss: Selects the return value, see below.

        Returns:
            If ``return_loss`` is False, a 1-D tensor with one scaled
            similarity per batch element (row ``i`` of each input paired with
            row ``i`` of the other). If True, the mean of the two
            cross-entropy losses over the full similarity matrix and its
            transpose, using the diagonal as the target.
        """
        cond_emb = self.cond_emb(mel_cond).permute(0, 2, 1)
        enc_cond = self.conditioning_transformer(cond_emb)
        cond_latents = self.to_conditioning_latent(enc_cond)

        speech_emb = self.speech_emb(mel_input).permute(0, 2, 1)
        enc_speech = self.speech_transformer(speech_emb)
        speech_latents = self.to_speech_latent(enc_speech)

        # Unit-normalise so the dot products below are cosine similarities.
        cond_latents = F.normalize(cond_latents, p=2, dim=-1)
        speech_latents = F.normalize(speech_latents, p=2, dim=-1)

        temp = self.temperature.exp()

        if not return_loss:
            # Row-wise similarity of matching pairs only.
            sim = einsum("n d, n d -> n", cond_latents, speech_latents) * temp
            return sim

        # All-pairs similarity; entry (i, j) compares cond i with speech j.
        sim = einsum("i d, j d -> i j", cond_latents, speech_latents) * temp
        # The matching pair for row i sits on the diagonal.
        labels = torch.arange(cond_latents.shape[0], device=mel_input.device)
        loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
        return loss
if __name__ == "__main__":
    # Smoke test: build a default model and run one loss computation on a
    # random batch of two 80-channel mels of differing lengths.
    model = CVVP()
    conditioning_mel = torch.randn(2, 80, 100)
    speech_mel = torch.randn(2, 80, 95)
    model(conditioning_mel, speech_mel, return_loss=True)