import torch
from torch import nn
from torch.nn import Module
from models.tts.delightful_tts.constants import LEAKY_RELU_SLOPE
class GLUActivation(Module):
r"""Implements the Gated Linear Unit (GLU) activation function.
The GLU activation splits the input in half across the channel dimension.
One half is passed through a nonlinear activation function (like sigmoid or leaky ReLU),
and the output from this activation function is used as a gate to control the
amplitude of the other half of the input. An element-wise multiplication is then performed
between the gating signal and the other half of the input.
The GLU activation allows the model to dynamically choose which inputs to pass through and
which to suppress, which can help improve model performance on certain tasks.
Args:
slope: Negative slope of the leaky ReLU used as the gating nonlinearity. Defaults to the `LEAKY_RELU_SLOPE` constant (0.3).
Shape:
- Input: (N, 2*C, L), where C is the number of output channels and L is the sequence length.
- Output: (N, C, L)
Examples:
```python
m = GLUActivation(0.3)
input = torch.randn(16, 2*20, 44)
output = m(input)
```
"""
def __init__(self, slope: float = LEAKY_RELU_SLOPE):
super().__init__()
self.lrelu = nn.LeakyReLU(slope)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Defines the computation performed at every call.
Args:
x: The input tensor of shape (batch_size, 2*channels, signal_length)
Returns:
x: The output tensor of shape (batch_size, channels, signal_length)
"""
# Split the input into two equal parts (chunks) along dimension 1
out, gate = x.chunk(2, dim=1)
# Perform element-wise multiplication of the first half (out)
# with the result of applying LeakyReLU on the second half (gate)
return out * self.lrelu(gate)
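
# Minimal usage sketch (illustrative only, not part of the original module):
# it feeds a random tensor with 2*C channels through the activation and checks
# that the channel dimension is halved. The shape values below are arbitrary
# example numbers, not constants used elsewhere in this repository.
if __name__ == "__main__":
    glu = GLUActivation(slope=0.3)
    sample = torch.randn(16, 2 * 20, 44)  # (batch, 2 * channels, length)
    result = glu(sample)
    assert result.shape == (16, 20, 44)  # channels halved by the gating split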