import math
import torch
from grad.ssim import SSIM
from grad.base import BaseModule
from grad.encoder import TextEncoder
from grad.diffusion import Diffusion
from grad.utils import f0_to_coarse, rand_ids_segments, slice_segments
SpeakerLoss = torch.nn.CosineEmbeddingLoss()
SsimLoss = SSIM()
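# Module-level criteria: the cosine-embedding loss supervises the GRL speaker
# predictor, and SSIM scores structural similarity of predicted mel-spectrograms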
class GradTTS(BaseModule):
def __init__(self, n_mels, n_vecs, n_pits, n_spks, n_embs,
n_enc_channels, filter_channels,
dec_dim, beta_min, beta_max, pe_scale):
        super().__init__()
# common
self.n_mels = n_mels
self.n_vecs = n_vecs
self.n_spks = n_spks
self.n_embs = n_embs
# encoder
self.n_enc_channels = n_enc_channels
self.filter_channels = filter_channels
# decoder
self.dec_dim = dec_dim
self.beta_min = beta_min
self.beta_max = beta_max
self.pe_scale = pe_scale
self.pit_emb = torch.nn.Embedding(n_pits, n_embs)
self.spk_emb = torch.nn.Linear(n_spks, n_embs)
self.encoder = TextEncoder(n_vecs,
n_mels,
n_embs,
n_enc_channels,
filter_channels)
self.decoder = Diffusion(n_mels, dec_dim, n_embs, beta_min, beta_max, pe_scale)
def fine_tune(self):
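        # Freeze the pitch and speaker embeddings; the encoder applies
        # its own fine-tuning policy via encoder.fine_tune()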
for p in self.pit_emb.parameters():
p.requires_grad = False
for p in self.spk_emb.parameters():
p.requires_grad = False
self.encoder.fine_tune()
@torch.no_grad()
def forward(self, lengths, vec, pit, spk, n_timesteps, temperature=1.0, stoc=False):
"""
Generates mel-spectrogram from vec. Returns:
1. encoder outputs
2. decoder outputs
Args:
lengths (torch.Tensor): lengths of texts in batch.
vec (torch.Tensor): batch of speech vec
pit (torch.Tensor): batch of speech pit
spk (torch.Tensor): batch of speaker
n_timesteps (int): number of steps to use for reverse diffusion in decoder.
temperature (float, optional): controls variance of terminal distribution.
stoc (bool, optional): flag that adds stochastic term to the decoder sampler.
Usually, does not provide synthesis improvements.
"""
lengths, vec, pit, spk = self.relocate_input([lengths, vec, pit, spk])
# Get pitch embedding
pit = self.pit_emb(f0_to_coarse(pit))
# Get speaker embedding
spk = self.spk_emb(spk)
# Transpose
vec = torch.transpose(vec, 1, -1)
pit = torch.transpose(pit, 1, -1)
# Get encoder_outputs `mu_x`
mu_x, mask_x, _ = self.encoder(lengths, vec, pit, spk)
encoder_outputs = mu_x
        # Sample latent representation from terminal distribution N(mu_x, I / temperature^2)
        z = mu_x + torch.randn_like(mu_x) / temperature
# Generate sample by performing reverse dynamics
decoder_outputs = self.decoder(spk, z, mask_x, mu_x, n_timesteps, stoc)
        # Perturb encoder outputs with unit Gaussian noise, mirroring the noise
        # added to mu_y during training (see compute_loss)
        encoder_outputs = encoder_outputs + torch.randn_like(encoder_outputs)
return encoder_outputs, decoder_outputs
def compute_loss(self, lengths, vec, pit, spk, mel, out_size, skip_diff=False):
"""
Computes 2 losses:
1. prior loss: loss between mel-spectrogram and encoder outputs.
2. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
Args:
lengths (torch.Tensor): lengths of texts in batch.
vec (torch.Tensor): batch of speech vec
pit (torch.Tensor): batch of speech pit
spk (torch.Tensor): batch of speaker
mel (torch.Tensor): batch of corresponding mel-spectrogram
out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
"""
lengths, vec, pit, spk, mel = self.relocate_input([lengths, vec, pit, spk, mel])
# Get pitch embedding
pit = self.pit_emb(f0_to_coarse(pit))
# Get speaker embedding
spk_64 = self.spk_emb(spk)
# Transpose
vec = torch.transpose(vec, 1, -1)
pit = torch.transpose(pit, 1, -1)
# Get encoder_outputs `mu_x`
mu_x, mask_x, spk_preds = self.encoder(lengths, vec, pit, spk_64, training=True)
# Compute loss between aligned encoder outputs and mel-spectrogram
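        # Negative log-likelihood of mel under N(mu_x, I):
        #   -log p(x) = 0.5 * ((x - mu_x)^2 + log(2*pi)) per element,
        # summed over unmasked positions and normalized by frame count * n_mels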
prior_loss = torch.sum(0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * mask_x)
prior_loss = prior_loss / (torch.sum(mask_x) * self.n_mels)
# Mel ssim
mel_loss = SsimLoss(mu_x, mel, mask_x)
# Compute loss of speaker for GRL
        spk_loss = SpeakerLoss(spk, spk_preds,
                               torch.ones(spk_preds.size(0), device=spk.device))
# Compute loss of score-based decoder
        if skip_diff:
            diff_loss = torch.zeros_like(prior_loss)
else:
# Cut a small segment of mel-spectrogram in order to increase batch size
            if out_size is not None:
ids = rand_ids_segments(lengths, out_size)
mel = slice_segments(mel, ids, out_size)
mask_y = slice_segments(mask_x, ids, out_size)
mu_y = slice_segments(mu_x, ids, out_size)
mu_y = mu_y + torch.randn_like(mu_y)
diff_loss, xt = self.decoder.compute_loss(
spk_64, mel, mask_y, mu_y)
return prior_loss, diff_loss, mel_loss, spk_loss
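

if __name__ == "__main__":
    # Minimal smoke-test sketch. The hyperparameter values and tensor shapes
    # below are illustrative assumptions, not values taken from this
    # project's configs.
    model = GradTTS(n_mels=80, n_vecs=256, n_pits=256, n_spks=256, n_embs=64,
                    n_enc_channels=192, filter_channels=768,
                    dec_dim=64, beta_min=0.05, beta_max=20.0, pe_scale=1000)
    lengths = torch.LongTensor([100])     # frames per item in the batch
    vec = torch.randn(1, 100, 256)        # content vectors: (batch, frames, n_vecs)
    pit = torch.rand(1, 100) * 400.0      # F0 contour in Hz, bucketed by f0_to_coarse
    spk = torch.randn(1, 256)             # speaker vector: (batch, n_spks)
    encoder_out, decoder_out = model(lengths, vec, pit, spk, n_timesteps=50)
    print(encoder_out.shape, decoder_out.shape)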