import torch
from torch import nn
from torch.nn import Module

from models.tts.delightful_tts.constants import LEAKY_RELU_SLOPE


class FeedForward(Module):
    r"""Creates a feed-forward neural network.

    The network consists of layer normalization, two 1D convolutions with a
    LeakyReLU activation between them, and dropout.

    Args:
        d_model (int): The number of expected features in the input.
        kernel_size (int): The size of the convolving kernel for the first conv1d layer.
        dropout (float): The dropout probability.
        expansion_factor (int, optional): The expansion factor for the hidden layer size in the feed-forward network. Defaults to 4.
        leaky_relu_slope (float, optional): Controls the angle of the negative slope of the LeakyReLU activation. Defaults to `LEAKY_RELU_SLOPE`.
    """

    def __init__(
        self,
        d_model: int,
        kernel_size: int,
        dropout: float,
        expansion_factor: int = 4,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(d_model)
        # Expand d_model -> d_model * expansion_factor; padding of kernel_size // 2
        # keeps seq_len unchanged for odd kernel sizes
        self.conv_1 = nn.Conv1d(
            d_model,
            d_model * expansion_factor,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        self.act = nn.LeakyReLU(leaky_relu_slope)
        # Pointwise convolution projects back down to d_model
        self.conv_2 = nn.Conv1d(d_model * expansion_factor, d_model, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass of the feed-forward neural network.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, num_features).

        Returns:
            Tensor: Output tensor of shape (batch_size, seq_len, num_features).
        """
        # Apply layer normalization over the feature dimension
        x = self.ln(x)

        # Conv1d expects (batch_size, channels, seq_len), so permute before and
        # after the first convolution, then apply activation and dropout
        x = x.permute((0, 2, 1))
        x = self.conv_1(x)
        x = x.permute((0, 2, 1))
        x = self.act(x)
        x = self.dropout(x)

        # Second (pointwise) convolution, again permuting around it, then dropout
        x = x.permute((0, 2, 1))
        x = self.conv_2(x)
        x = x.permute((0, 2, 1))
        x = self.dropout(x)

        # Scale the output by 0.5 (this helps with training stability)
        return 0.5 * x
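

# A minimal usage sketch (illustrative only, not part of the original module).
# The d_model, kernel_size, and dropout values below are hypothetical; shapes
# follow the forward docstring: (batch_size, seq_len, num_features).
if __name__ == "__main__":
    ff = FeedForward(d_model=384, kernel_size=3, dropout=0.1)
    ff.eval()  # disable dropout for a deterministic shape check
    x = torch.randn(2, 100, 384)  # (batch_size, seq_len, d_model)
    with torch.no_grad():
        y = ff(x)
    print(y.shape)  # torch.Size([2, 100, 384])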