Spaces:
Running
Running
import torch | |
from torch import nn | |
from torch.nn import Module | |
from models.config import AcousticModelConfigType | |
from models.tts.delightful_tts.attention import StyleEmbedAttention | |
class STL(Module):
    r"""Style Token Layer (STL).

    Keeps a bank of learnable style-token embeddings and uses an attention
    module to blend them into one style embedding per input query.

    Args:
        model_config (AcousticModelConfigType): Model configuration object.
            ``encoder.n_hidden`` and ``reference_encoder.token_num`` are read from it.

    Attributes:
        token_num: Number of style tokens held in the bank.
        embed (nn.Parameter): Learnable style-token embeddings of shape
            ``[token_num, n_hidden // num_heads]``.
        attention (StyleEmbedAttention): Attention module that attends from the
            query over the style tokens.
    """

    def __init__(
        self,
        model_config: AcousticModelConfigType,
    ):
        super().__init__()

        # A single attention head is used, so every token spans the full hidden width.
        head_count = 1
        hidden_dim = model_config.encoder.n_hidden
        token_dim = hidden_dim // head_count

        self.token_num = model_config.reference_encoder.token_num

        # Allocate the (uninitialised) token bank and register it as a learnable parameter;
        # the values are filled in by the normal initialisation at the end of the constructor.
        token_bank = torch.FloatTensor(self.token_num, token_dim)
        self.embed = nn.Parameter(token_bank)

        # Queries are half the hidden width, keys are one token wide.
        self.attention = StyleEmbedAttention(
            query_dim=hidden_dim // 2,
            key_dim=token_dim,
            num_units=hidden_dim,
            num_heads=head_count,
        )

        # Draw the initial token values from N(0, 0.5^2).
        nn.init.normal_(self.embed, mean=0, std=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Attend over the style tokens using ``x`` as the query.

        Args:
            x (torch.Tensor): Query tensor whose first dimension is the batch size.
                Presumably shaped ``[N, n_hidden // 2]`` so that it matches the
                attention's ``query_dim`` - confirm against the caller.

        Returns:
            torch.Tensor: The output of the attention module applied to the query
            and the style tokens (the style/emotion embedding).
        """
        batch_size = x.size(0)

        # Insert a length-1 sequence axis: [N, ...] -> [N, 1, ...]
        query = x.unsqueeze(1)

        # tanh bounds the token values to (-1, 1); the same bank is then shared
        # (expanded, not copied) across the batch: [N, token_num, n_hidden // num_heads]
        bounded_tokens = torch.tanh(self.embed)
        token_keys = bounded_tokens.unsqueeze(0).expand(batch_size, -1, -1)

        return self.attention(query, token_keys)