import torch
from torch import nn
from torch.nn import Module

from models.tts.delightful_tts.constants import LEAKY_RELU_SLOPE


class GLUActivation(Module):
r"""Implements the Gated Linear Unit (GLU) activation function. | |
The GLU activation splits the input in half across the channel dimension. | |
One half is passed through a nonlinear activation function (like sigmoid or leaky ReLU), | |
and the output from this activation function is used as a gate to control the | |
amplitude of the other half of the input. An element-wise multiplication is then performed | |
between the gating signal and the other half of the input. | |
The GLU activation allows the model to dynamically choose which inputs to pass through and | |
what information to suppress, which can help improving the model performance on certain tasks. | |
    Args:
        slope (float): Negative slope of the leaky ReLU used as the gate nonlinearity.
            Defaults to the constant `LEAKY_RELU_SLOPE` (0.3).

    Shape:
        - Input: (N, 2*C, L), where 2*C is the number of input channels.
        - Output: (N, C, L)

    Examples:
        ```python
        m = GLUActivation(0.3)
        input = torch.randn(16, 2 * 20, 44)
        output = m(input)
        ```
    """

    def __init__(self, slope: float = LEAKY_RELU_SLOPE):
        super().__init__()
        self.lrelu = nn.LeakyReLU(slope)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Args:
            x: The input tensor of shape (batch_size, 2*channels, signal_length).

        Returns:
            The output tensor of shape (batch_size, channels, signal_length).
        """
        # Split the input into two equal halves along the channel dimension
        out, gate = x.chunk(2, dim=1)
        # Gate the first half with the leaky-ReLU-activated second half (element-wise)
        return out * self.lrelu(gate)
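

# A minimal usage sketch, not part of the original module: it only exercises the
# shapes documented above, checking that a (batch, 2*channels, length) input comes
# out with the channel dimension halved. The names `glu`, `dummy`, and `gated`
# are illustrative.
if __name__ == "__main__":
    glu = GLUActivation(slope=0.3)
    dummy = torch.randn(16, 2 * 20, 44)  # (batch_size, 2*channels, signal_length)
    gated = glu(dummy)
    # Batch and length are preserved; channels are halved from 40 to 20.
    assert gated.shape == (16, 20, 44)
    print(gated.shape)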