Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class Linear(nn.Module):
    """Fully connected layer whose weight is Xavier-uniform initialized.

    Args:
        in_features (int): number of channels in the input tensor.
        out_features (int): number of channels in the output tensor.
        bias (bool, optional): enable/disable bias in the layer. Defaults to True.
        init_gain (str, optional): name of the nonlinearity used after this layer; it is
            handed to ``calculate_gain`` to scale the weight initialization.
            Defaults to 'linear'.
    """

    def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
        super().__init__()
        self.linear_layer = nn.Linear(in_features, out_features, bias=bias)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        """Overwrite the weight matrix with Xavier-uniform values scaled for `init_gain`."""
        gain = nn.init.calculate_gain(init_gain)
        nn.init.xavier_uniform_(self.linear_layer.weight, gain=gain)

    def forward(self, x):
        """Project `x` through the wrapped linear layer and return the result."""
        projection = self.linear_layer
        return projection(x)
class LinearBN(nn.Module):
    """Linear projection followed by 1D Batch Normalization.

    x -> linear -> BN -> o

    Args:
        in_features (int): number of channels in the input tensor.
        out_features (int): number of channels in the output tensor.
        bias (bool, optional): enable/disable bias in the linear layer. Defaults to True.
        init_gain (str, optional): nonlinearity name used to compute the gain for the
            weight initialization. Defaults to 'linear'.
    """

    def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
        super().__init__()
        self.linear_layer = nn.Linear(in_features, out_features, bias=bias)
        self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        """Overwrite the linear weight with Xavier-uniform values scaled for `init_gain`."""
        gain = nn.init.calculate_gain(init_gain)
        nn.init.xavier_uniform_(self.linear_layer.weight, gain=gain)

    def forward(self, x):
        """Run the linear layer and normalize its output over the channel axis.

        Shapes:
            x: [T, B, C] or [B, C]
        """
        projected = self.linear_layer(x)
        if projected.dim() != 3:
            # Non-sequence input is passed to BatchNorm1d as-is.
            return self.batch_normalization(projected)
        # BatchNorm1d wants channels on dim 1: [T, B, C] -> [B, C, T], then restore.
        normalized = self.batch_normalization(projected.permute(1, 2, 0))
        return normalized.permute(2, 0, 1)
class Prenet(nn.Module):
    """Tacotron specific Prenet with an optional Batch Normalization.

    Note:
        Prenet with BN improves the model performance significantly especially
        if it is enabled after learning a diagonal attention alignment with the original
        prenet. However, if the target dataset is high quality then it also works from
        the start. It is also suggested to disable dropout if BN is in use.

        prenet_type == "original"
            x -> [linear -> ReLU -> Dropout]xN -> o
        prenet_type == "bn"
            x -> [linear -> BN -> ReLU -> Dropout]xN -> o

    Args:
        in_features (int): number of channels in the input tensor.
        prenet_type (str, optional): prenet type "original" or "bn". Defaults to "original".
        prenet_dropout (bool, optional): enable/disable dropout (fixed p=0.5) after each
            prenet block. Defaults to True.
        dropout_at_inference (bool, optional): keep dropout active in eval mode as well.
            It leads to a better quality for some models. Defaults to False.
        out_features (sequence of int, optional): output channels for each prenet block.
            Its length also defines the number of prenet blocks. Defaults to [256, 256].
        bias (bool, optional): enable/disable bias in prenet linear layers. Defaults to True.

    Raises:
        ValueError: if `prenet_type` is neither "original" nor "bn".
    """

    def __init__(
        self,
        in_features,
        prenet_type="original",
        prenet_dropout=True,
        dropout_at_inference=False,
        out_features=None,
        bias=True,
    ):
        super().__init__()
        # Fail at construction time; otherwise `linear_layers` would be missing and
        # the error would only surface as an AttributeError inside `forward`.
        if prenet_type not in ("original", "bn"):
            raise ValueError(f"Unknown prenet_type '{prenet_type}'; expected 'original' or 'bn'.")
        # `None` stands in for the default so no mutable list is shared between
        # instances; copying to a list also lets callers pass tuples.
        out_features = [256, 256] if out_features is None else list(out_features)

        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        self.dropout_at_inference = dropout_at_inference

        # Each block consumes the output width of the previous one.
        in_sizes = [in_features] + out_features[:-1]
        if prenet_type == "bn":
            self.linear_layers = nn.ModuleList(
                [LinearBN(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_sizes, out_features)]
            )
        else:
            self.linear_layers = nn.ModuleList(
                [Linear(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_sizes, out_features)]
            )

    def forward(self, x):
        """Pass `x` through every prenet block and return the final activation."""
        # Dropout stays on in eval mode when `dropout_at_inference` is set.
        apply_dropout = self.training or self.dropout_at_inference
        for linear in self.linear_layers:
            x = F.relu(linear(x))
            if self.prenet_dropout:
                x = F.dropout(x, p=0.5, training=apply_dropout)
        return x