import torch
import torch.nn as nn


class SelfAttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf

    A single linear layer scores every frame; the scores are softmax-normalised
    over the time axis and used as weights for a weighted sum of the frames.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, 1)

    def forward(self, batch_rep):
        """
        input:
            batch_rep : size (N, T, H), N: batch size, T: sequence length, H: Hidden dimension

        attention_weight:
            att_w : size (N, T, 1)

        return:
            utter_rep: size (N, H)
        """
        # Normalise explicitly over the time axis. Calling softmax without `dim`
        # relies on a deprecated implicit choice and emits a warning each call.
        # NOTE(review): no padding mask is applied, so padded frames (if any)
        # receive attention weight too — confirm inputs are unpadded.
        att_w = nn.functional.softmax(self.W(batch_rep).squeeze(-1), dim=-1).unsqueeze(-1)
        utter_rep = torch.sum(batch_rep * att_w, dim=1)
        return utter_rep


class Model(nn.Module):
    """Score (N, T, H) frame features with a mean head and an optional judge-conditioned head.

    Args:
        input_dim: feature size H of the input.
        clipping: if True, squash the mean score into (1, 5) via ``tanh(x) * 2 + 3``.
            Only the mean-net score is clipped; the judge bias term is not.
        attention_pooling: if True, pool over time with ``SelfAttentionPooling`` before
            the linear head; otherwise apply the linear head per frame and average
            the frame scores over time.
        num_judges: number of rows in the judge embedding table.
        **kwargs: accepted and ignored.
    """

    def __init__(self, input_dim, clipping=False, attention_pooling=False, num_judges=5000, **kwargs):
        super().__init__()
        self.mean_net_linear = nn.Linear(input_dim, 1)
        self.mean_net_clipping = clipping
        self.mean_net_pooling = SelfAttentionPooling(input_dim) if attention_pooling else None
        self.bias_net_linear = nn.Linear(input_dim, 1)
        self.bias_net_pooling = SelfAttentionPooling(input_dim) if attention_pooling else None
        # The attribute name is misspelled ("embbeding") but kept: renaming it would
        # change the state_dict key and break loading previously saved weights.
        self.judge_embbeding = nn.Embedding(num_embeddings=num_judges, embedding_dim=input_dim)

    @staticmethod
    def _score(features, linear, pooling):
        """Reduce (N, T, H) features to one score per utterance, shape (N, 1)."""
        if pooling is not None:
            return linear(pooling(features))
        # Per-frame scores (N, T, 1) averaged over time -> (N, 1). Keeping the
        # trailing dim makes both branches agree, so the caller's squeeze(-1)
        # always yields (N,), even when N == 1.
        return linear(features).mean(dim=1)

    def forward(self, features, judge_ids=None):
        """
        input:
            features : size (N, T, H)
            judge_ids : optional LongTensor of judge indices; assumed size (N,) — TODO confirm

        return:
            segment_score of size (N,) when judge_ids is None; otherwise the tuple
            (segment_score, bias_score), both of size (N,), where bias_score is the
            judge-conditioned score plus segment_score.
        """
        segment_score = self._score(features, self.mean_net_linear, self.mean_net_pooling)
        if self.mean_net_clipping:
            segment_score = torch.tanh(segment_score) * 2 + 3

        if judge_ids is None:
            return segment_score.squeeze(-1)

        # Add each utterance's judge embedding to every frame. Broadcasting over the
        # time axis gives the same sum as stacking T copies, without the copies.
        judge_features = self.judge_embbeding(judge_ids).unsqueeze(1)
        bias_features = features + judge_features
        bias_score = self._score(bias_features, self.bias_net_linear, self.bias_net_pooling)
        bias_score = bias_score + segment_score
        return segment_score.squeeze(-1), bias_score.squeeze(-1)