Spaces:
Sleeping
Sleeping
File size: 2,341 Bytes
9b2107c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch
from torch import nn
from ..generic.normalization import LayerNorm
class DurationPredictor(nn.Module):
    """Glow-TTS duration prediction model.

    ::

        [2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs

    Args:
        in_channels (int): Number of channels of the input tensor, not counting
            the optional language-embedding channels (see ``language_emb_dim``).
        hidden_channels (int): Number of hidden channels of the network.
        kernel_size (int): Kernel size for the conv layers. Must be odd so that
            ``padding=kernel_size // 2`` keeps the time length unchanged.
        dropout_p (float): Dropout rate used after each conv layer.
        cond_channels (int, optional): Channel count of the conditioning tensor
            ``g``. When non-zero, a 1x1 conv projecting ``g`` to the input width
            is created and its output is added to ``x`` in :meth:`forward`.
            Defaults to None (no conditioning layer).
        language_emb_dim (int, optional): Size of the language embedding. When
            non-zero, the expected input width becomes
            ``in_channels + language_emb_dim`` (this enlarged value is what
            ``self.in_channels`` stores), and a 1x1 conv projecting ``lang_emb``
            to that width is created. Defaults to None.

    Raises:
        ValueError: If ``kernel_size`` is even.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dropout_p,
        cond_channels=None,
        language_emb_dim=None,
    ):
        super().__init__()
        # With padding = kernel_size // 2, an even kernel produces T + 1 output
        # frames, which no longer lines up with x_mask in forward(). Fail here
        # with a clear message instead of a broadcasting error later.
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}.")
        # The language embedding channels are part of the input width, so every
        # layer that consumes the input (conv_1, cond, cond_lang) uses the
        # enlarged value.
        if language_emb_dim:
            in_channels += language_emb_dim
        # class arguments (``filter_channels`` holds ``hidden_channels``)
        self.in_channels = in_channels
        self.filter_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dropout_p = dropout_p
        # layers -- attribute names become state_dict keys and construction
        # order drives parameter initialization, so keep both stable.
        self.drop = nn.Dropout(dropout_p)
        self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(hidden_channels)
        self.conv_2 = nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(hidden_channels)
        # output layer: one value per time step
        self.proj = nn.Conv1d(hidden_channels, 1, 1)
        # Optional conditioning projections; only created when the matching
        # size is non-zero (None and 0 both mean "disabled").
        if cond_channels:
            self.cond = nn.Conv1d(cond_channels, in_channels, 1)
        if language_emb_dim:
            self.cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1)

    def forward(self, x, x_mask, g=None, lang_emb=None):
        """Predict one duration value per input time step.

        ``x_mask`` is multiplied into the activations before every conv and
        into the final output, so masked positions of the result are zero.

        Passing ``g`` (or ``lang_emb``) to a module built without
        ``cond_channels`` (or ``language_emb_dim``) raises ``AttributeError``,
        because the corresponding ``self.cond`` / ``self.cond_lang`` layer was
        never created.

        Shapes:
            - x: :math:`[B, C, T]` with ``C == self.in_channels``
            - x_mask: :math:`[B, 1, T]`
            - g: :math:`[B, C_{cond}, 1]`
            - lang_emb: :math:`[B, C_{lang}, 1]` (presumably; anything whose
              1x1-conv projection broadcasts against ``x`` works)
            - output: :math:`[B, 1, T]`
        """
        if g is not None:
            x = x + self.cond(g)
        if lang_emb is not None:
            x = x + self.cond_lang(lang_emb)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask
|