# text_to_speech/mtts/models/decoder/fs2_transformer_decoder.py
import torch.nn as nn
from torch import Tensor
from ..layers import FFTBlock, get_sinusoid_encoding_table
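
# Note: `get_sinusoid_encoding_table` is imported from ..layers and is expected
# to return the standard fixed sinusoidal position table of shape
# (n_position, d_hid), as in "Attention Is All You Need":
#   PE[pos, 2i]   = sin(pos / 10000^(2i / d_hid))
#   PE[pos, 2i+1] = cos(pos / 10000^(2i / d_hid))
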
class FS2TransformerDecoder(nn.Module):
""" A decoder that accepts a list of sequences as input and
out a sequence as output. The input and output sequences share the same length
The input sequence is a list of tensors, which may contain text-embedding, speaker-embeddings.
"""
def __init__(
self,
input_dim: int = 256, # must == decoder output dim
n_layers: int = 4,
n_heads: int = 2,
hidden_dim: int = 256,
d_inner: int = 1024,
dropout: float = 0.5,
max_len: int = 2048, # for computing position table
):
        super().__init__()
        self.input_dim = input_dim
        self.max_len = max_len
        d_k = hidden_dim // n_heads  # per-head key dimension
        d_v = hidden_dim // n_heads  # per-head value dimension
        n_position = max_len + 1  # table covers position indices 0..max_len
# self.speaker_fc = nn.Linear(512, 256, bias=False)
        pos_table = get_sinusoid_encoding_table(n_position, input_dim).unsqueeze(0)
        # Fixed (non-trainable) sinusoidal position encodings.
        self.position_enc = nn.Parameter(pos_table, requires_grad=False)
        self.layers = nn.ModuleList([
            FFTBlock(hidden_dim, d_inner, n_heads, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
    def forward(self, input: Tensor, mask: Tensor):
        # input: (batch, seq_len, input_dim); mask: (batch, seq_len) padding mask.
        batch_size, seq_len, input_dim = input.shape
        if input_dim != self.input_dim:
            raise ValueError(
                f'expected input feature dim {self.input_dim}, got {input_dim}')
        if seq_len > self.max_len:
            raise ValueError(
                f'sequence length {seq_len} exceeds max_len {self.max_len}')
        # Broadcast the padding mask to a per-query attention mask: (batch, seq_len, seq_len).
        attn_mask = mask.unsqueeze(1).expand(-1, seq_len, -1)
        # Add position encodings (sliced to the current length) to the inputs.
        pos_embed = self.position_enc[:, :seq_len, :].expand(batch_size, -1, -1)
        output = input + pos_embed
        for layer in self.layers:
            output, _ = layer(output, mask=mask, attn_mask=attn_mask)
return output
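

if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the original module. It assumes
    # this file is run inside the mtts package (so the relative ..layers import
    # resolves) and that `mask` marks padded positions, per FastSpeech 2 convention.
    import torch

    decoder = FS2TransformerDecoder()
    batch_size, seq_len = 2, 50
    # e.g. text embeddings already summed with speaker embeddings
    x = torch.randn(batch_size, seq_len, 256)
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # no padding
    out = decoder(x, mask)
    print(out.shape)  # expected: torch.Size([2, 50, 256])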