Spaces:
Build error
Build error
import torch | |
from torch import nn | |
from TTS.tts.layers.glow_tts.decoder import Decoder as GlowDecoder | |
from TTS.tts.utils.helpers import sequence_mask | |
class Decoder(nn.Module): | |
"""Uses glow decoder with some modifications. | |
:: | |
Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze | |
Args: | |
in_channels (int): channels of input tensor. | |
hidden_channels (int): hidden decoder channels. | |
kernel_size (int): Coupling block kernel size. (Wavenet filter kernel size.) | |
dilation_rate (int): rate to increase dilation by each layer in a decoder block. | |
num_flow_blocks (int): number of decoder blocks. | |
num_coupling_layers (int): number coupling layers. (number of wavenet layers.) | |
dropout_p (float): wavenet dropout rate. | |
sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer. | |
""" | |
def __init__( | |
self, | |
in_channels, | |
hidden_channels, | |
kernel_size, | |
dilation_rate, | |
num_flow_blocks, | |
num_coupling_layers, | |
dropout_p=0.0, | |
num_splits=4, | |
num_squeeze=2, | |
sigmoid_scale=False, | |
c_in_channels=0, | |
): | |
super().__init__() | |
self.glow_decoder = GlowDecoder( | |
in_channels, | |
hidden_channels, | |
kernel_size, | |
dilation_rate, | |
num_flow_blocks, | |
num_coupling_layers, | |
dropout_p, | |
num_splits, | |
num_squeeze, | |
sigmoid_scale, | |
c_in_channels, | |
) | |
self.n_sqz = num_squeeze | |
def forward(self, x, x_len, g=None, reverse=False): | |
""" | |
Input shapes: | |
- x: :math:`[B, C, T]` | |
- x_len :math:`[B]` | |
- g: :math:`[B, C]` | |
Output shapes: | |
- x: :math:`[B, C, T]` | |
- x_len :math:`[B]` | |
- logget_tot :math:`[B]` | |
""" | |
x, x_len, x_max_len = self.preprocess(x, x_len, x_len.max()) | |
x_mask = torch.unsqueeze(sequence_mask(x_len, x_max_len), 1).to(x.dtype) | |
x, logdet_tot = self.glow_decoder(x, x_mask, g, reverse) | |
return x, x_len, logdet_tot | |
def preprocess(self, y, y_lengths, y_max_length): | |
if y_max_length is not None: | |
y_max_length = torch.div(y_max_length, self.n_sqz, rounding_mode="floor") * self.n_sqz | |
y = y[:, :, :y_max_length] | |
y_lengths = torch.div(y_lengths, self.n_sqz, rounding_mode="floor") * self.n_sqz | |
return y, y_lengths, y_max_length | |
def store_inverse(self): | |
self.glow_decoder.store_inverse() | |