import torch
from torch import nn
from torch.nn import Module

from models.tts.delightful_tts.conv_blocks import Conv1dGLU

from .conformer_conv_module import ConformerConvModule
from .conformer_multi_headed_self_attention import ConformerMultiHeadedSelfAttention
from .feed_forward import FeedForward


class ConformerBlock(Module):
    r"""ConformerBlock class represents a block in the Conformer model architecture.

    The block includes a pointwise convolution followed by a Gated Linear Unit (`GLU`)
    activation layer (`Conv1dGLU`), a Conformer self-attention layer
    (`ConformerMultiHeadedSelfAttention`), and an optional feed-forward layer
    (`FeedForward`).

    Args:
        d_model (int): The number of expected features in the input.
        n_head (int): The number of heads for the multi-headed attention mechanism.
        kernel_size_conv_mod (int): The size of the convolving kernel for the convolution module.
        embedding_dim (int): The dimension of the embeddings.
        dropout (float): The dropout probability.
        with_ff (bool): If True, uses a FeedForward layer inside the ConformerBlock.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        kernel_size_conv_mod: int,
        embedding_dim: int,
        dropout: float,
        with_ff: bool,
    ):
        super().__init__()
        self.with_ff = with_ff

        # Pointwise convolution with GLU activation, conditioned on the embeddings.
        self.conditioning = Conv1dGLU(
            d_model=d_model,
            kernel_size=kernel_size_conv_mod,
            padding=kernel_size_conv_mod // 2,
            embedding_dim=embedding_dim,
        )
        if self.with_ff:
            self.ff = FeedForward(
                d_model=d_model,
                dropout=dropout,
                kernel_size=3,
            )
        self.conformer_conv_1 = ConformerConvModule(
            d_model,
            kernel_size=kernel_size_conv_mod,
            dropout=dropout,
        )
        self.ln = nn.LayerNorm(
            d_model,
        )
        self.slf_attn = ConformerMultiHeadedSelfAttention(
            d_model=d_model,
            num_heads=n_head,
            dropout_p=dropout,
        )
        self.conformer_conv_2 = ConformerConvModule(
            d_model,
            kernel_size=kernel_size_conv_mod,
            dropout=dropout,
        )

    def forward(
        self,
        x: torch.Tensor,
        embeddings: torch.Tensor,
        mask: torch.Tensor,
        slf_attn_mask: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        r"""Forward pass of the Conformer block.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, num_features).
            embeddings (Tensor): Embeddings tensor.
            mask (Tensor): The padding mask tensor.
            slf_attn_mask (Tensor): The mask for the self-attention layer.
            encoding (Tensor): The positional encoding tensor.

        Returns:
            Tensor: The output tensor of shape (batch_size, seq_len, num_features).
        """
        # Condition the input on the embeddings via the Conv1dGLU block.
        x = self.conditioning(x, embeddings=embeddings)
        if self.with_ff:
            x = self.ff(x) + x
        x = self.conformer_conv_1(x) + x
        res = x
        x = self.ln(x)
        # Multi-headed self-attention with a residual connection around it.
        x, _ = self.slf_attn(
            query=x,
            key=x,
            value=x,
            mask=slf_attn_mask,
            encoding=encoding,
        )
        x = x + res
        # Zero out padded positions before the second convolution module.
        x = x.masked_fill(mask.unsqueeze(-1), 0)
        return self.conformer_conv_2(x) + x
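

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). The tensor shapes
# below are assumptions inferred from the forward() docstring and from how
# `mask` and `slf_attn_mask` are consumed above; the exact shapes expected for
# `embeddings` and `encoding` depend on the Conv1dGLU and
# ConformerMultiHeadedSelfAttention implementations in this repository and may
# differ.
#
#     block = ConformerBlock(
#         d_model=384,
#         n_head=4,
#         kernel_size_conv_mod=7,
#         embedding_dim=384,
#         dropout=0.1,
#         with_ff=True,
#     )
#     x = torch.randn(2, 100, 384)                  # (batch, seq_len, d_model)
#     embeddings = torch.randn(2, 100, 384)         # conditioning embeddings (assumed shape)
#     mask = torch.zeros(2, 100, dtype=torch.bool)  # padding mask, True marks padded frames
#     slf_attn_mask = torch.zeros(2, 100, 100, dtype=torch.bool)  # attention mask (assumed shape)
#     encoding = torch.randn(1, 100, 384)           # positional encoding (assumed shape)
#     out = block(x, embeddings, mask, slf_attn_mask, encoding)   # -> (2, 100, 384)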