|
|
|
from data_gen.tts.emotion.params_model import * |
|
from data_gen.tts.emotion.params_data import * |
|
from torch.nn.utils import clip_grad_norm_ |
|
from scipy.optimize import brentq |
|
from torch import nn |
|
import numpy as np |
|
import torch |
|
|
|
|
|
class EmotionEncoder(nn.Module):
    """LSTM encoder mapping a batch of mel-spectrogram frame sequences to embeddings.

    Layer sizes (``mel_n_channels``, ``model_hidden_size``, ``model_num_layers``,
    ``model_embedding_size``) come from the star-imported ``params_model`` /
    ``params_data`` modules.
    """

    def __init__(self, device, loss_device):
        """
        :param device: device on which the LSTM, linear and ReLU layers are placed.
        :param loss_device: device on which the similarity weight/bias parameters and
        the cross-entropy loss module are placed.
        """
        super().__init__()
        self.loss_device = loss_device

        # Network: multi-layer (unidirectional) LSTM whose last-layer final hidden
        # state is projected to the embedding size by a linear layer + ReLU.
        self.lstm = nn.LSTM(input_size=mel_n_channels,
                            hidden_size=model_hidden_size,
                            num_layers=model_num_layers,
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Learnable similarity scale and bias. The tensors are created directly on
        # loss_device *before* being wrapped in nn.Parameter: calling .to() on a
        # Parameter returns a plain, non-leaf Tensor whenever the device changes,
        # which would leave these values unregistered as module parameters and
        # with a permanently-None .grad.
        self.similarity_weight = nn.Parameter(torch.tensor([10.], device=loss_device))
        self.similarity_bias = nn.Parameter(torch.tensor([-5.], device=loss_device))

        # NOTE(review): not used by any method of this class; kept for external callers.
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        """Rescale the similarity parameters' gradients, then clip all gradients.

        Reads ``.grad``, so it is meant to run after ``backward()``. The gradients of
        ``similarity_weight`` and ``similarity_bias`` are multiplied by 0.01 (skipped
        for a parameter whose ``.grad`` is None, e.g. one that took no part in the
        loss), after which the global L2 norm of all parameter gradients is clipped
        to 3.
        """
        # NOTE(review): the 0.01 factor and clip value of 3 look like the GE2E
        # training recipe — confirm before changing.
        for param in (self.similarity_weight, self.similarity_bias):
            if param.grad is not None:
                param.grad *= 0.01

        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes L2-normalised embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: optional initial LSTM state, passed unchanged to ``nn.LSTM`` as a
        ``(h_0, c_0)`` tuple, each of shape (num_layers, batch_size, hidden_size). The LSTM
        defaults it to zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size). Each row has
        unit L2 norm, except that a row the ReLU zeroed out entirely is returned as zeros
        (instead of NaN).
        """
        # Only the final hidden state of the last LSTM layer is used.
        _, (hidden, _) = self.lstm(utterances, hidden_init)

        embeds_raw = self.relu(self.linear(hidden[-1]))

        # F.normalize divides by max(norm, eps): identical to dividing by the norm for
        # any nonzero row, but an all-zero row stays zero instead of becoming 0/0 = NaN.
        return torch.nn.functional.normalize(embeds_raw, p=2, dim=1)

    def inference(self, utterances, hidden_init=None):
        """
        Runs only the LSTM and returns its raw final hidden state.

        Unlike ``forward``, no linear projection, ReLU or normalisation is applied.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: optional initial LSTM state, passed unchanged to ``nn.LSTM`` as a
        ``(h_0, c_0)`` tuple, each of shape (num_layers, batch_size, hidden_size). The LSTM
        defaults it to zeros if None.
        :return: the last LSTM layer's final hidden state, a tensor of shape
        (batch_size, hidden_size)
        """
        _, (hidden, _) = self.lstm(utterances, hidden_init)
        return hidden[-1]