# Uploaded by duyv ("Upload 381 files", commit 813828b, verified).
import torch
from torch import nn
from TTS.encoder.models.base_encoder import BaseEncoder
class LSTMWithProjection(nn.Module):
    """Single-layer, batch-first LSTM followed by a bias-free linear projection.

    Every time step of the LSTM output is projected, so the output keeps the
    time axis.

    Args:
        input_size: feature size of each input step.
        hidden_size: hidden size of the LSTM.
        proj_size: output size of the projection.
    """

    def __init__(self, input_size, hidden_size, proj_size):
        super().__init__()
        # Constructor arguments are kept as public attributes.
        self.input_size, self.hidden_size, self.proj_size = input_size, hidden_size, proj_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, proj_size, bias=False)

    def forward(self, x):
        """Run the LSTM over ``x`` (batch-first) and project every output step."""
        # Make sure the LSTM weights sit in one contiguous buffer before the call.
        self.lstm.flatten_parameters()
        step_outputs = self.lstm(x)[0]  # (output, (h_n, c_n)) -> output
        return self.linear(step_outputs)
class LSTMWithoutProjection(nn.Module):
    """Stacked batch-first LSTM whose last hidden state is mapped through Linear + ReLU.

    Args:
        input_dim: feature size of each input step.
        lstm_dim: hidden size of every LSTM layer.
        proj_dim: output size of the linear layer.
        num_lstm_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=lstm_dim,
            num_layers=num_lstm_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ReLU(Linear(h_n)) using the final hidden state of the top LSTM layer."""
        final_hidden = self.lstm(x)[1][0]  # (output, (h_n, c_n)) -> h_n
        top_layer = final_hidden[-1]
        projected = self.linear(top_layer)
        return self.relu(projected)
class LSTMSpeakerEncoder(BaseEncoder):
    """LSTM speaker encoder producing one embedding per input utterance.

    Args:
        input_dim: number of spectrogram channels fed to the LSTM stack.
        proj_dim: size of the output embedding.
        lstm_dim: hidden size of the LSTM layers.
        num_lstm_layers: number of LSTM layers.
        use_lstm_with_projection: if True, stack ``LSTMWithProjection`` layers and
            take the last time step as the embedding; otherwise use a single
            ``LSTMWithoutProjection`` module.
        use_torch_spec: if True, ``forward`` computes the spectrogram on the fly
            from a raw waveform.
        audio_config: passed to ``self.get_torch_mel_spectrogram_class``
            (presumably inherited from ``BaseEncoder``) when ``use_torch_spec`` is True.
    """

    def __init__(
        self,
        input_dim,
        proj_dim=256,
        lstm_dim=768,
        num_lstm_layers=3,
        use_lstm_with_projection=True,
        use_torch_spec=False,
        audio_config=None,
    ):
        super().__init__()
        self.use_lstm_with_projection = use_lstm_with_projection
        self.use_torch_spec = use_torch_spec
        self.audio_config = audio_config
        self.proj_dim = proj_dim
        layers = []
        # Choose the LSTM layer type.
        if use_lstm_with_projection:
            # The first layer reads spectrogram channels; the following layers read
            # the previous layer's projected output (proj_dim features).
            layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))
            for _ in range(num_lstm_layers - 1):
                layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
            self.layers = nn.Sequential(*layers)
        else:
            self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
        self.instancenorm = nn.InstanceNorm1d(input_dim)
        if self.use_torch_spec:
            self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
        else:
            self.torch_spec = None
        self._init_layers()

    def _init_layers(self):
        """Zero every bias and Xavier-normal-initialize every weight in ``self.layers``."""
        for name, param in self.layers.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0.0)
            elif "weight" in name:
                nn.init.xavier_normal_(param)

    def forward(self, x, l2_norm=True):
        """Forward pass of the model.

        Args:
            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform,
                `use_torch_spec` must be `True` to compute the spectrogram on-the-fly.
                The input tensor is not modified.
            l2_norm (bool): Whether to L2-normalize the outputs.

        Returns:
            Tensor: one embedding per batch item, shape :math:`(N, proj\\_dim)`.

        Shapes:
            - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
        """
        # Feature extraction: no gradients, and autocast disabled so it runs in full precision.
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                if self.use_torch_spec:
                    # Out-of-place squeeze so the caller's tensor is left unchanged.
                    x = self.torch_spec(x.squeeze(1))
                # Normalize per channel, then make it batch-first (N, T, D) for the LSTMs.
                x = self.instancenorm(x).transpose(1, 2)
        d = self.layers(x)
        if self.use_lstm_with_projection:
            # The projected output at the final time step is the embedding.
            d = d[:, -1]
        if l2_norm:
            d = torch.nn.functional.normalize(d, p=2, dim=1)
        return d