# Source: nickovchinnikov — initial commit ("Init", 9d61c9b)
import torch
from torch import nn
from torch.nn import Module
from models.helpers import tools
from models.tts.delightful_tts.constants import LEAKY_RELU_SLOPE
from models.tts.delightful_tts.conv_blocks import (
DepthWiseConv1d,
GLUActivation,
PointwiseConv1d,
)
class ConformerConvModule(Module):
    r"""Convolution module of a Conformer block.

    Implements the sequence: LayerNorm -> pointwise conv (channel expansion)
    -> GLU -> depthwise conv -> GroupNorm -> LeakyReLU -> pointwise conv
    (channel projection) -> dropout. Input and output share the same shape,
    so the module can be used as a residual branch.

    Args:
        d_model (int): Number of input (and output) feature channels.
        expansion_factor (int): Multiplier for the hidden channel count,
            default is 2.
        kernel_size (int): Depthwise convolution kernel size, default is 7.
        dropout (float): Dropout probability applied to the output,
            default is 0.1.
        leaky_relu_slope (float): Negative slope of the LeakyReLU
            non-linearity, default is `LEAKY_RELU_SLOPE`.
    """

    def __init__(
        self,
        d_model: int,
        expansion_factor: int = 2,
        kernel_size: int = 7,
        dropout: float = 0.1,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ):
        super().__init__()

        # Hidden width after expansion; the first pointwise conv emits twice
        # this many channels so the GLU gate can halve them back down.
        expanded_channels = d_model * expansion_factor

        self.ln_1 = nn.LayerNorm(d_model)
        self.conv_1 = PointwiseConv1d(
            d_model,
            expanded_channels * 2,
        )
        self.conv_act = GLUActivation()
        self.depthwise = DepthWiseConv1d(
            expanded_channels,
            expanded_channels,
            kernel_size=kernel_size,
            # "same" padding keeps the sequence length unchanged.
            padding=tools.calc_same_padding(kernel_size)[0],
        )
        # GroupNorm with a single group == LayerNorm over the channel dim
        # while operating on (batch, channels, seq) layout.
        self.ln_2 = nn.GroupNorm(
            1,
            expanded_channels,
        )
        self.activation = nn.LeakyReLU(leaky_relu_slope)
        self.conv_2 = PointwiseConv1d(
            expanded_channels,
            d_model,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Run the convolution module.

        Args:
            x (Tensor): Input of shape (batch_size, seq_len, num_features).

        Returns:
            Tensor: Output of shape (batch_size, seq_len, num_features).
        """
        # Normalize in (B, T, C) layout, then move channels first for conv.
        x = self.ln_1(x).transpose(1, 2)
        for layer in (
            self.conv_1,
            self.conv_act,
            self.depthwise,
            self.ln_2,
            self.activation,
            self.conv_2,
        ):
            x = layer(x)
        # Back to (B, T, C) before the final dropout.
        return self.dropout(x.transpose(1, 2))