# text_to_speech/mtts/models/encoder/fs2_transformer_encoder.py
from typing import List

import torch.nn as nn
from torch import Tensor

from ..layers import FFTBlock, get_sinusoid_encoding_table


class FS2TransformerEncoder(nn.Module):
    '''FastSpeech2-style Transformer encoder: a weighted sum of token
    embeddings plus sinusoidal positional encodings, fed through a stack
    of FFT blocks.'''

    def __init__(
        self,
        emb_layers: nn.ModuleList,
        embeding_weights: List[float],
        hidden_dim: int = 256,
        n_layers: int = 4,
        n_heads: int = 2,
        d_inner: int = 1024,
        dropout: float = 0.5,
        max_len: int = 1024,
    ):
        super().__init__()
        self.emb_layers = emb_layers
        self.embeding_weights = embeding_weights
        self.hidden_dim = hidden_dim
        d_k = hidden_dim // n_heads
        d_v = hidden_dim // n_heads
        d_model = hidden_dim
        # Cached positional table of shape (1, max_len, hidden_dim), used by
        # forward() at inference time. The original file referenced
        # self.position_enc without defining it; a non-trainable buffer is
        # assumed here.
        self.register_buffer(
            'position_enc',
            get_sinusoid_encoding_table(max_len, hidden_dim).unsqueeze(0),
        )
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            layer = FFTBlock(d_model, d_inner, n_heads, d_k, d_v, dropout=dropout)
            self.layers.append(layer)

    def forward(self, texts: List[Tensor], mask: Tensor):
        if len(self.embeding_weights) != len(texts):
            raise ValueError(f'Input texts has length {len(texts)}, '
                             f'but embedding weights list has length {len(self.embeding_weights)}')
        batch_size = texts[0].shape[0]
        seq_len = texts[0].shape[1]
        # Expand the padding mask to (batch, seq_len, seq_len) for self-attention.
        attn_mask = mask.unsqueeze(1).expand(-1, seq_len, -1)
        # Weighted sum of the per-stream token embeddings.
        text_embed = self.emb_layers[0](texts[0]) * self.embeding_weights[0]
        n_embs = len(self.embeding_weights)
        for i in range(1, n_embs):
            text_embed += self.emb_layers[i](texts[i]) * self.embeding_weights[i]
        if self.training:
            # Recompute the sinusoid table per batch so training is not capped at max_len.
            pos_embed = get_sinusoid_encoding_table(seq_len, self.hidden_dim)
            assert pos_embed.shape[0] == seq_len
            pos_embed = pos_embed[:seq_len, :]
            pos_embed = pos_embed.unsqueeze(0).expand(batch_size, -1, -1)
            pos_embed = pos_embed.to(texts[0].device)
        else:
            # At inference, slice the cached table registered in __init__.
            pos_embed = self.position_enc[:, :seq_len, :]
            pos_embed = pos_embed.expand(batch_size, -1, -1)
        all_embed = text_embed + pos_embed
        for layer in self.layers:
            # Per-layer self-attention weights are discarded.
            all_embed, _ = layer(all_embed, mask=mask, attn_mask=attn_mask)
        return all_embed
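

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds two
# hypothetical embedding streams and runs a forward pass. Vocabulary sizes,
# embedding weights, and the boolean padding-mask convention (True = padded
# position) are assumptions for illustration, not values from the source.
if __name__ == '__main__':
    import torch

    hidden_dim = 256
    emb_layers = nn.ModuleList([
        nn.Embedding(100, hidden_dim),  # hypothetical token vocabulary
        nn.Embedding(10, hidden_dim),   # hypothetical second stream (e.g. tones)
    ])
    encoder = FS2TransformerEncoder(emb_layers,
                                    embeding_weights=[1.0, 0.5],
                                    hidden_dim=hidden_dim)
    encoder.eval()  # exercise the cached-position-table branch

    batch_size, seq_len = 2, 16
    texts = [
        torch.randint(0, 100, (batch_size, seq_len)),
        torch.randint(0, 10, (batch_size, seq_len)),
    ]
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # no padding
    with torch.no_grad():
        out = encoder(texts, mask)
    print(out.shape)  # expected: (batch_size, seq_len, hidden_dim)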