File size: 3,475 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
from torch import nn
from torch.nn import Module

from models.tts.delightful_tts.conv_blocks import Conv1dGLU

from .conformer_conv_module import ConformerConvModule
from .conformer_multi_headed_self_attention import ConformerMultiHeadedSelfAttention
from .feed_forward import FeedForward


class ConformerBlock(Module):
    r"""A single Conformer block with embedding conditioning.

    The block applies, in order:

    1. ``conditioning``: a ``Conv1dGLU`` (kernel size ``kernel_size_conv_mod``,
       padding ``kernel_size_conv_mod // 2``) that receives ``embeddings``.
    2. ``ff``: an optional ``FeedForward`` layer with a residual connection
       (present only when ``with_ff`` is True).
    3. ``conformer_conv_1``: a ``ConformerConvModule`` with a residual connection.
    4. ``slf_attn``: ``ConformerMultiHeadedSelfAttention`` applied to the
       LayerNorm-ed input, with a residual connection to the un-normalized
       input. Afterwards, positions flagged by ``mask`` are zeroed.
    5. ``conformer_conv_2``: a second ``ConformerConvModule`` with a residual
       connection.

    Args:
        d_model (int): The number of expected features in the input.
        n_head (int): The number of heads for the multi-headed attention mechanism.
        kernel_size_conv_mod (int): Kernel size shared by the ``Conv1dGLU``
            conditioning layer and both ``ConformerConvModule``s.
        embedding_dim (int): The dimension of the conditioning embeddings.
        dropout (float): Dropout probability passed to the feed-forward,
            convolution and attention sublayers.
        with_ff (bool): If True, a ``FeedForward`` sublayer is built and applied.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        kernel_size_conv_mod: int,
        embedding_dim: int,
        dropout: float,
        with_ff: bool,
    ):
        super().__init__()
        self.with_ff = with_ff
        self.conditioning = Conv1dGLU(
            d_model=d_model,
            kernel_size=kernel_size_conv_mod,
            padding=kernel_size_conv_mod // 2,
            embedding_dim=embedding_dim,
        )
        if self.with_ff:
            # The feed-forward sublayer is always built with kernel_size=3;
            # it is not configurable from this block.
            self.ff = FeedForward(
                d_model=d_model,
                dropout=dropout,
                kernel_size=3,
            )
        self.conformer_conv_1 = ConformerConvModule(
            d_model,
            kernel_size=kernel_size_conv_mod,
            dropout=dropout,
        )
        # Applied only to the self-attention input (pre-norm); see ``forward``.
        self.ln = nn.LayerNorm(
            d_model,
        )
        self.slf_attn = ConformerMultiHeadedSelfAttention(
            d_model=d_model,
            num_heads=n_head,
            dropout_p=dropout,
        )
        self.conformer_conv_2 = ConformerConvModule(
            d_model,
            kernel_size=kernel_size_conv_mod,
            dropout=dropout,
        )

    def forward(
        self,
        x: torch.Tensor,
        embeddings: torch.Tensor,
        mask: torch.Tensor,
        slf_attn_mask: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        r"""Forward pass of the Conformer block.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, num_features).
            embeddings (Tensor): Conditioning embeddings passed to ``Conv1dGLU``
                (expected shape is defined by that layer; not checked here).
            mask (Tensor): Boolean mask. After ``unsqueeze(-1)`` it must
                broadcast against ``x``, e.g. (batch_size, seq_len). Positions
                where it is True are set to zero after the attention sublayer.
            slf_attn_mask (Tensor): Mask forwarded unchanged to the
                self-attention layer.
            encoding (Tensor): Positional encoding forwarded unchanged to the
                self-attention layer.

        Returns:
            Tensor: The output tensor of shape (batch_size, seq_len, num_features).
        """
        # Invoke the submodule (not ``.forward``) so registered hooks and the rest
        # of ``nn.Module.__call__`` run, exactly as for every other sublayer here.
        x = self.conditioning(x, embeddings=embeddings)
        if self.with_ff:
            x = self.ff(x) + x
        x = self.conformer_conv_1(x) + x

        # Pre-norm attention: normalize, attend, then add back the un-normalized input.
        res = x
        x = self.ln(x)
        # The second output of ``slf_attn`` is not used by this block.
        x, _ = self.slf_attn(
            query=x,
            key=x,
            value=x,
            mask=slf_attn_mask,
            encoding=encoding,
        )
        x = x + res
        # Zero every feature at positions where ``mask`` is True.
        x = x.masked_fill(mask.unsqueeze(-1), 0)
        return self.conformer_conv_2(x) + x