Spaces:
Running
Running
File size: 9,388 Bytes
9b2107c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
import torch
from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal as MVN
from torch.nn import functional as F
class CapacitronVAE(nn.Module):
    """Variational prosody embedding with a capacity-controlled KL term.

    Implements the latent module of "Effective Use of Variational Embedding
    Capacity for prosody transfer", see https://arxiv.org/abs/1906.03402

    A reference spectrogram (optionally together with a text summary and a
    speaker embedding) is mapped to the parameters of an approximate
    posterior q(z|x), from which the embedding is sampled. Without a
    reference the embedding is sampled from the standard-normal prior p(z).
    """

    def __init__(
        self,
        num_mel,
        capacitron_VAE_embedding_dim,
        encoder_output_dim=256,
        reference_encoder_out_dim=128,
        speaker_embedding_dim=None,
        text_summary_embedding_dim=None,
    ):
        """Build the reference encoder, optional text summary net and posterior MLP.

        Args:
            num_mel: number of mel channels of the reference spectrogram.
            capacitron_VAE_embedding_dim: size of the latent embedding z.
            encoder_output_dim: feature size of the text-encoder outputs fed to
                the text summary network.
            reference_encoder_out_dim: output size of the reference encoder.
            speaker_embedding_dim: size of the speaker embedding concatenated
                to the encoder output, or None when no speaker embedding is used.
            text_summary_embedding_dim: size of the text summary embedding, or
                None to disable the text summary network.
        """
        super().__init__()
        # Prior p(z) = N(0, I). It is a plain attribute holding CPU tensors,
        # which is why forward() moves the posterior parameters to the CPU.
        self.prior_distribution = MVN(
            torch.zeros(capacitron_VAE_embedding_dim), torch.eye(capacitron_VAE_embedding_dim)
        )
        # Set by forward() whenever a reference is encoded.
        self.approximate_posterior_distribution = None

        self.encoder = ReferenceEncoder(num_mel, out_dim=reference_encoder_out_dim)

        # Lagrange-like multiplier for the KL term, initialised to log(e - 1),
        # i.e. the value whose softplus equals 1.
        # NOTE(review): presumably the loss applies softplus to beta - confirm against the caller.
        self.beta = torch.nn.Parameter(torch.log(torch.exp(torch.Tensor([1.0])) - 1), requires_grad=True)

        # The posterior MLP consumes the reference encoding plus every
        # optional conditioning vector that gets concatenated to it.
        mlp_input_dimension = reference_encoder_out_dim
        if text_summary_embedding_dim is not None:
            self.text_summary_net = TextSummary(text_summary_embedding_dim, encoder_output_dim=encoder_output_dim)
            mlp_input_dimension += text_summary_embedding_dim
        if speaker_embedding_dim is not None:
            # TODO: Test a multispeaker model!
            mlp_input_dimension += speaker_embedding_dim
        self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim)

    def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None):
        """Sample the VAE embedding from the posterior (with reference) or the prior.

        Args:
            reference_mel_info: optional pair ``(mels, mel_lengths)`` with mels
                of shape [batch_size, num_frames, num_mels] and lengths of
                shape [batch_size]. When None the prior is sampled instead.
            text_info: optional pair ``(text_inputs, input_lengths)`` with
                text_inputs of shape [batch_size, num_characters, num_embedding].
                Only used together with a reference.
            speaker_embedding: optional speaker embedding, concatenated to the
                encoder output. Only used together with a reference.

        Returns:
            Tuple ``(embedding, posterior, prior, beta)``. ``embedding`` has a
            singleton axis inserted at position 1. ``posterior`` is the
            distribution built by the most recent call that encoded a
            reference (None if there has been none so far).
        """
        if reference_mel_info is None:
            # Inference without a reference bypasses encoding: z ~ p(z)
            VAE_embedding = self.prior_distribution.sample().unsqueeze(0)
        else:
            reference_mels = reference_mel_info[0]  # [batch_size, num_frames, num_mels]
            mel_lengths = reference_mel_info[1]  # [batch_size]
            enc_out = self.encoder(reference_mels, mel_lengths)

            # Concatenate the text summary embedding and/or speaker embedding
            if text_info is not None:
                text_inputs = text_info[0]  # [batch_size, num_characters, num_embedding]
                input_lengths = text_info[1]
                text_summary_out = self.text_summary_net(text_inputs, input_lengths).to(reference_mels.device)
                enc_out = torch.cat([enc_out, text_summary_out], dim=-1)
            if speaker_embedding is not None:
                speaker_embedding = torch.squeeze(speaker_embedding)
                # squeeze() without `dim` also drops a batch axis of size 1,
                # which made the concatenation below fail for single-item
                # batches. Restore the batch axis in that case.
                if speaker_embedding.dim() < 2:
                    speaker_embedding = speaker_embedding.reshape(1, -1)
                enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)

            # Feed the reference encoding and the text/speaker information into
            # an MLP producing the parameters of the approximate posterior.
            mu, sigma = self.post_encoder_mlp(enc_out)

            # The prior lives on the CPU, so keep the posterior there as well.
            mu = mu.cpu()
            sigma = sigma.cpu()

            # z ~ q(z|x); sigma is used as the diagonal of the covariance matrix.
            self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma))
            VAE_embedding = self.approximate_posterior_distribution.rsample()

        # reshape to [batch_size, 1, capacitron_VAE_embedding_dim]
        return (
            VAE_embedding.unsqueeze(1),
            self.approximate_posterior_distribution,
            self.prior_distribution,
            self.beta,
        )
class ReferenceEncoder(nn.Module):
    """NN module creating a fixed size prosody embedding from a spectrogram.

    A stack of six strided 2D convolutions (each followed by batch norm and
    ReLU) downsamples the spectrogram; the result is summarised by an LSTM
    whose last hidden state is returned.

    inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
    outputs: [batch_size, embedding_dim]
    """

    def __init__(self, num_mel, out_dim):
        """Create the convolution stack, batch norms and the summarising LSTM.

        Args:
            num_mel: number of mel channels of the input spectrograms.
            out_dim: hidden size of the LSTM, i.e. the size of the embedding.
        """
        super().__init__()
        self.num_mel = num_mel
        filters = [1] + [32, 32, 64, 64, 128, 128]
        num_layers = len(filters) - 1
        convs = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(2, 2),
            )
            for in_channels, out_channels in zip(filters[:-1], filters[1:])
        ]
        self.convs = nn.ModuleList(convs)
        # NOTE(review): this only flips the flag on this module (not on its
        # child layers) and is overwritten by any later .train()/.eval() call
        # - confirm whether it is still needed.
        self.training = False
        self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])

        # Must use the same kernel/stride/padding as the convolutions above.
        post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
        self.recurrence = nn.LSTM(
            input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False
        )

    def forward(self, inputs, input_lengths):
        """Encode a batch of zero-padded spectrograms.

        Args:
            inputs: mel spectrograms [batch_size, num_frames, num_mel].
            input_lengths: number of valid frames per example [batch_size].

        Returns:
            Last LSTM hidden state per example [batch_size, out_dim].
        """
        batch_size = inputs.size(0)
        x = inputs.view(batch_size, 1, -1, self.num_mel)  # [batch_size, num_channels==1, num_frames, num_mel]
        valid_lengths = input_lengths  # [batch_size]
        for conv, bn in zip(self.convs, self.bns):
            x = F.relu(bn(conv(x)))

            # Every example in the batch is zero padded, so each has its own
            # number of valid frames. For kernel_size=3, stride=2, padding=2 a
            # sequence with n valid frames yields ceil(n / 2) + 1 valid output
            # frames, e.g. 5 valid input frames give 4 valid output frames.
            # Everything AFTER that length is masked to zero so the padded
            # region does not leak made-up values into the next layer.
            valid_lengths = torch.ceil(valid_lengths.float() / 2).to(dtype=torch.int64) + 1

            # Build the frame index directly on the input device and compare
            # it with the per-example lengths: [batch_size, post_conv_width]
            frame_index = torch.arange(x.size(2), device=inputs.device)
            mask = frame_index.unsqueeze(0) < valid_lengths.to(inputs.device).unsqueeze(1)
            x = x * mask[:, None, :, None]  # mask: [batch_size, 1, post_conv_width, 1]

        x = x.transpose(1, 2)
        # x: 4D tensor [batch_size, post_conv_width,
        #               num_channels==128, post_conv_height]
        post_conv_width = x.size(1)
        x = x.contiguous().view(batch_size, post_conv_width, -1)
        # x: 3D tensor [batch_size, post_conv_width,
        #               num_channels*post_conv_height]

        # Pack the sequences so the LSTM stops at the last valid frame of each
        # example; its final hidden state is then the last valid output.
        packed_seqs = nn.utils.rnn.pack_padded_sequence(
            x, valid_lengths.tolist(), batch_first=True, enforce_sorted=False
        )
        self.recurrence.flatten_parameters()
        _, (ht, _) = self.recurrence(packed_seqs)
        last_output = ht[-1]
        return last_output.to(inputs.device)  # [batch_size, out_dim]

    @staticmethod
    def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
        """Height of spec after n convolutions with fixed kernel/stride/pad."""
        for _ in range(n_convs):
            height = (height - kernel_size + 2 * pad) // stride + 1
        return height
class TextSummary(nn.Module):
    """Condense a padded sequence of text-encoder outputs into one vector per example.

    A unidirectional LSTM reads each sequence up to its own length; the final
    hidden state is the summary.
    """

    def __init__(self, embedding_dim, encoder_output_dim):
        """
        Args:
            embedding_dim: size of the fixed-length summary the LSTM produces.
            encoder_output_dim: feature size of the text-encoder outputs.
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=encoder_output_dim,
            hidden_size=embedding_dim,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, inputs, input_lengths):
        """Return the last valid LSTM state for every sequence in the batch.

        Args:
            inputs: padded batch, batch dimension first.
            input_lengths: tensor with the valid length of each sequence.

        Returns:
            Final hidden state of the LSTM, one row per example.
        """
        # Packing lets the LSTM skip the padding, so the final hidden state
        # corresponds to the last valid step of each sequence.
        sequence_lengths = input_lengths.tolist()
        packed_inputs = nn.utils.rnn.pack_padded_sequence(
            inputs, sequence_lengths, batch_first=True, enforce_sorted=False
        )
        self.lstm.flatten_parameters()
        _, (final_hidden, _) = self.lstm(packed_inputs)
        return final_hidden[-1]
class PostEncoderMLP(nn.Module):
    """Two-layer MLP producing the parameters of the approximate posterior.

    The output layer is twice ``hidden_size`` wide: the first half is the
    mean, the second half is passed through a softplus to obtain a strictly
    positive scale parameter.
    """

    def __init__(self, input_size, hidden_size):
        """
        Args:
            input_size: feature size of the incoming encoder output.
            hidden_size: width of the hidden layer and of each returned parameter.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),  # hidden layer
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size * 2),  # mean and scale side by side
        )
        self.softplus = nn.Softplus()

    def forward(self, _input):
        """Return ``(mu, sigma)``, both taken from dim 1 of the MLP output."""
        projected = self.net(_input)
        split_at = self.hidden_size
        # The mean is left unconstrained ...
        mu = projected[:, :split_at]
        # ... while the scale has to be positive, hence the softplus.
        raw_scale = projected[:, split_at:]
        sigma = self.softplus(raw_scale)
        return mu, sigma
|