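"""GradTTS acoustic model.

A TextEncoder maps content vectors (`vec`), coarse pitch embeddings (`pit`),
and a speaker embedding (`spk`) to a prior mel-spectrogram `mu_x`; a
score-based Diffusion decoder then refines a noisy sample of that prior into
the final mel-spectrogram.
"""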
import math
import torch

from grad.ssim import SSIM
from grad.base import BaseModule
from grad.encoder import TextEncoder
from grad.diffusion import Diffusion
from grad.utils import f0_to_coarse, rand_ids_segments, slice_segments

# Module-level loss functions shared by all GradTTS instances
SpeakerLoss = torch.nn.CosineEmbeddingLoss()
SsimLoss = SSIM()

class GradTTS(BaseModule):
    def __init__(self, n_mels, n_vecs, n_pits, n_spks, n_embs,
                 n_enc_channels, filter_channels,
                 dec_dim, beta_min, beta_max, pe_scale):
        super().__init__()
        # common
        self.n_mels = n_mels
        self.n_vecs = n_vecs
        self.n_spks = n_spks
        self.n_embs = n_embs
        # encoder
        self.n_enc_channels = n_enc_channels
        self.filter_channels = filter_channels
        # decoder
        self.dec_dim = dec_dim
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.pe_scale = pe_scale

        self.pit_emb = torch.nn.Embedding(n_pits, n_embs)
        self.spk_emb = torch.nn.Linear(n_spks, n_embs)
        self.encoder = TextEncoder(n_vecs,
                                   n_mels,
                                   n_embs,
                                   n_enc_channels,
                                   filter_channels)
        self.decoder = Diffusion(n_mels, dec_dim, n_embs, beta_min, beta_max, pe_scale)

    def fine_tune(self):
        # Freeze pitch and speaker embeddings; the encoder applies its own
        # fine-tuning policy
        for p in self.pit_emb.parameters():
            p.requires_grad = False
        for p in self.spk_emb.parameters():
            p.requires_grad = False
        self.encoder.fine_tune()

    @torch.no_grad()
    def forward(self, lengths, vec, pit, spk, n_timesteps, temperature=1.0, stoc=False):
        """
        Generates mel-spectrogram from vec. Returns:
            1. encoder outputs
            2. decoder outputs

        Args:
            lengths (torch.Tensor): lengths of texts in batch.
            vec (torch.Tensor): batch of speech vec
            pit (torch.Tensor): batch of speech pit
            spk (torch.Tensor): batch of speaker
            
            n_timesteps (int): number of steps to use for reverse diffusion in decoder.
            temperature (float, optional): controls variance of terminal distribution.
            stoc (bool, optional): flag that adds stochastic term to the decoder sampler.
                Usually, does not provide synthesis improvements.
        """
        lengths, vec, pit, spk = self.relocate_input([lengths, vec, pit, spk])

        # Get pitch embedding
        pit = self.pit_emb(f0_to_coarse(pit))

        # Get speaker embedding
        spk = self.spk_emb(spk)

        # Transpose
        vec = torch.transpose(vec, 1, -1)
        pit = torch.transpose(pit, 1, -1)

        # Get encoder_outputs `mu_x`
        mu_x, mask_x, _ = self.encoder(lengths, vec, pit, spk)
        encoder_outputs = mu_x

        # Sample latent representation from terminal distribution N(mu_x, I)
        z = mu_x + torch.randn_like(mu_x) / temperature
        # Generate sample by performing reverse dynamics
        decoder_outputs = self.decoder(spk, z, mask_x, mu_x, n_timesteps, stoc)
        # Add unit Gaussian noise to the returned encoder outputs
        encoder_outputs = encoder_outputs + torch.randn_like(encoder_outputs)
        return encoder_outputs, decoder_outputs

    def compute_loss(self, lengths, vec, pit, spk, mel, out_size, skip_diff=False):
        """
        Computes 2 losses:
            1. prior loss: loss between mel-spectrogram and encoder outputs.
            2. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
            
        Args:
            lengths (torch.Tensor): lengths of texts in batch.
            vec (torch.Tensor): batch of speech vec
            pit (torch.Tensor): batch of speech pit
            spk (torch.Tensor): batch of speaker
            mel (torch.Tensor): batch of corresponding mel-spectrogram

            out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
                Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
        """
        lengths, vec, pit, spk, mel = self.relocate_input([lengths, vec, pit, spk, mel])

        # Get pitch embedding
        pit = self.pit_emb(f0_to_coarse(pit))

        # Get speaker embedding
        spk_64 = self.spk_emb(spk)

        # Transpose
        vec = torch.transpose(vec, 1, -1)
        pit = torch.transpose(pit, 1, -1)

        # Get encoder_outputs `mu_x`
        mu_x, mask_x, spk_preds = self.encoder(lengths, vec, pit, spk_64, training=True)

        # Prior loss: negative log-likelihood of mel under N(mu_x, I), i.e.
        # 0.5 * ((mel - mu_x)^2 + log(2*pi)) per dimension, masked to valid frames
        prior_loss = torch.sum(0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * mask_x)
        prior_loss = prior_loss / (torch.sum(mask_x) * self.n_mels)

        # SSIM loss between the encoder prior and the target mel-spectrogram
        mel_loss = SsimLoss(mu_x, mel, mask_x)

        # Speaker loss for GRL: push the encoder's speaker predictions toward
        # the speaker vector (target 1 for CosineEmbeddingLoss means "similar")
        spk_loss = SpeakerLoss(spk, spk_preds,
                               torch.ones(spk_preds.size(0), device=spk.device))

        # Compute loss of score-based decoder
        if skip_diff:
            # Placeholder zero so the returned tuple keeps a fixed structure
            diff_loss = prior_loss.clone()
            diff_loss.fill_(0)
        else:
            # Cut a small segment of the mel-spectrogram to increase batch size
            if out_size is not None:
                ids = rand_ids_segments(lengths, out_size)
                mel = slice_segments(mel, ids, out_size)
                mask_y = slice_segments(mask_x, ids, out_size)
                mu_y = slice_segments(mu_x, ids, out_size)
            else:
                mask_y = mask_x
                mu_y = mu_x
            # Perturb the prior before computing the score-matching loss
            mu_y = mu_y + torch.randn_like(mu_y)

            diff_loss, _ = self.decoder.compute_loss(spk_64, mel, mask_y, mu_y)

        return prior_loss, diff_loss, mel_loss, spk_loss
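

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the training pipeline). All
# hyperparameter values below are illustrative assumptions, not the project's
# real configuration. Shapes follow forward(): vec is (B, T, n_vecs), pit is
# (B, T) F0 values bucketed by f0_to_coarse, spk is a (B, n_spks) vector.
if __name__ == "__main__":
    model = GradTTS(n_mels=80, n_vecs=256, n_pits=256, n_spks=256, n_embs=64,
                    n_enc_channels=192, filter_channels=768,
                    dec_dim=64, beta_min=0.05, beta_max=20.0, pe_scale=1000)
    batch, frames = 2, 128
    lengths = torch.tensor([frames, frames // 2])
    vec = torch.randn(batch, frames, 256)             # content vectors
    pit = 100.0 + 300.0 * torch.rand(batch, frames)   # F0 in Hz (assumed range)
    spk = torch.randn(batch, 256)                     # speaker vector for spk_emb

    # Inference: encoder prior plus 10 reverse-diffusion steps
    enc_out, dec_out = model(lengths, vec, pit, spk, n_timesteps=10)

    # Training losses; out_size must be divisible by 2^(UNet downsamplings)
    mel = torch.randn(batch, 80, frames)
    losses = model.compute_loss(lengths, vec, pit, spk, mel, out_size=32)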