import torch
import numpy as np
from torch import nn
from scipy.optimize import brentq
from sklearn.metrics import roc_curve
from scipy.interpolate import interp1d
from torch.nn.parameter import Parameter
from torch.nn.utils.clip_grad import clip_grad_norm_
from .hparams import hparams as hp


class SpeakerEncoder(nn.Module):
    """LSTM speaker encoder trained with the GE2E softmax loss.

    The LSTM, the linear projection and the ReLU are placed on ``device``; the
    similarity scaling parameters and the loss function are placed on
    ``loss_device``.
    """

    def __init__(self, device, loss_device):
        super().__init__()
        self.loss_device = loss_device

        # Network definition
        self.lstm = nn.LSTM(
            input_size=hp.mel_n_channels,
            hidden_size=hp.model_hidden_size,
            num_layers=hp.model_num_layers,
            batch_first=True,
        ).to(device)
        self.linear = nn.Linear(
            in_features=hp.model_hidden_size, out_features=hp.model_embedding_size
        ).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values).
        # The tensors are created directly on loss_device: calling .to() on an
        # existing Parameter returns a plain, non-leaf tensor whenever the
        # device actually changes, which would then not be registered as a
        # module parameter and would never receive a .grad.
        self.similarity_weight = Parameter(torch.tensor([10.0], device=loss_device))
        self.similarity_bias = Parameter(torch.tensor([-5.0], device=loss_device))

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        """Scale the similarity-parameter gradients, then clip all gradients.

        Intended to run after ``backward()`` and before the optimizer step.
        The gradients of ``similarity_weight`` and ``similarity_bias`` are
        multiplied by 0.01, then the global L2 norm of all parameter gradients
        is clipped to 3.
        """
        # Gradient scale. Parameters without a gradient yet are skipped instead
        # of failing on ``None *= 0.01`` (clip_grad_norm_ ignores them as well).
        for param in (self.similarity_weight, self.similarity_bias):
            if param.grad is not None:
                param.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial LSTM state, passed straight to nn.LSTM, i.e. a tuple
        (h_0, c_0) of tensors of shape (num_layers, batch_size, hidden_size). Defaults to zeros
        if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        # Run the LSTM; only the final hidden state is needed (per-step outputs
        # and the final cell state are discarded).
        _, (hidden, _) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it; the epsilon avoids dividing by zero when the ReLU
        # output is all zeros.
        embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)

        return embeds

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Inclusive centroids (1 per speaker), L2-normalized.
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl / (
            torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5
        )

        # Exclusive centroids (1 per utterance): mean of the same speaker's
        # other utterances, L2-normalized. Needs utterances_per_speaker >= 2,
        # otherwise this divides by zero and yields inf/nan.
        centroids_excl = torch.sum(embeds, dim=1, keepdim=True) - embeds
        centroids_excl = centroids_excl / (utterances_per_speaker - 1)
        centroids_excl = centroids_excl / (
            torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5
        )

        # Similarity matrix. The cosine similarity of already L2-normed vectors
        # is simply their dot product.
        # sim_incl[i, u, j] = <embeds[i, u], centroid of speaker j>, used when
        # utterance (i, u) is compared against a different speaker j.
        sim_incl = torch.einsum("iud,jd->iuj", embeds, centroids_incl[:, 0])
        # sim_excl[j, u] = <embeds[j, u], centroid of speaker j without (j, u)>,
        # used when an utterance is compared against its own speaker.
        sim_excl = (embeds * centroids_excl).sum(dim=2)

        # Take the exclusive similarity on the speaker diagonal (i == j) and
        # the inclusive similarity everywhere else.
        same_speaker = torch.eye(
            speakers_per_batch, dtype=torch.bool, device=embeds.device
        )
        sim_matrix = torch.where(
            same_speaker.unsqueeze(1), sim_excl.unsqueeze(2), sim_incl
        )

        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix

    def loss(self, embeds):
        """
        Computes the softmax loss according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss: every utterance is one row, classified against all speakers.
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape(
            (speakers_per_batch * utterances_per_speaker, speakers_per_batch)
        )
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated)
        with torch.no_grad():
            # One-hot labels: row k has a 1 in the column of utterance k's speaker.
            labels = np.eye(speakers_per_batch, dtype=np.int32)[ground_truth]
            preds = sim_matrix.detach().cpu().numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten())
            tpr_at = interp1d(fpr, tpr)
            # EER is the operating point where the false positive rate equals
            # the false negative rate (1 - tpr).
            eer = brentq(lambda x: 1.0 - x - tpr_at(x), 0.0, 1.0)

        return loss, eer