lmzjms's picture
Upload 1162 files
0b32ad6 verified
import torch
import torch.nn as nn
class SelfAttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf

    Each time step is scored by a single linear layer. The scores are
    normalised with a softmax over the time axis, and the module returns the
    attention-weighted sum of the input frames.

    NOTE: no padding mask is applied, so padded frames (if any) also receive
    attention weight.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Produces one scalar attention logit per frame.
        self.W = nn.Linear(input_dim, 1)

    def forward(self, batch_rep):
        """
        input:
            batch_rep : size (N, T, H), N: batch size, T: sequence length, H: Hidden dimension
                        (an unbatched (T, H) input, or extra leading dims, also work:
                        time is always the second-to-last axis)
        attention_weight:
            att_w : size (N, T, 1)
        return:
            utter_rep: size (N, H)
        """
        # `dim` is explicit: the implicit choice is deprecated (it warns) and
        # only lands on the time axis when the logits happen to be 2-D.
        logits = self.W(batch_rep).squeeze(-1)                          # (N, T)
        att_w = nn.functional.softmax(logits, dim=-1).unsqueeze(-1)     # (N, T, 1)
        # Weighted sum over time (second-to-last axis) -> (N, H).
        utter_rep = torch.sum(batch_rep * att_w, dim=-2)
        return utter_rep
class Model(nn.Module):
    """
    Utterance-level score predictor with an optional per-judge bias branch.

    Two heads read the same frame-level ``features``:
      * the "mean net" predicts a judge-independent score;
      * the "bias net" adds a learned judge embedding to every frame, predicts
        its own score, and that score is returned summed with the mean-net
        score.

    Args:
        input_dim: hidden size H of the incoming features.
        clipping: if True, squash the mean-net score into the range 1..5 via
            ``tanh(s) * 2 + 3`` (the bias-net output itself is not squashed).
        attention_pooling: if True, pool frames with SelfAttentionPooling
            before each linear head; otherwise each linear head scores every
            frame and the frame scores are averaged over time.
        num_judges: number of rows in the judge-embedding table; judge ids
            must lie in ``[0, num_judges)``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, input_dim, clipping=False, attention_pooling=False, num_judges=5000, **kwargs):
        super().__init__()
        self.mean_net_linear = nn.Linear(input_dim, 1)
        self.mean_net_clipping = clipping
        self.mean_net_pooling = SelfAttentionPooling(input_dim) if attention_pooling else None
        self.bias_net_linear = nn.Linear(input_dim, 1)
        self.bias_net_pooling = SelfAttentionPooling(input_dim) if attention_pooling else None
        # NOTE: the misspelled attribute name is intentional to keep: it is
        # part of the state_dict keys, so renaming it breaks saved checkpoints.
        self.judge_embbeding = nn.Embedding(num_embeddings=num_judges, embedding_dim=input_dim)

    @staticmethod
    def _score(features, linear, pooling):
        """Reduce (N, T, H) features to one score per utterance, shape (N,)."""
        if pooling is not None:
            # Pool to (N, H), project to (N, 1), drop the singleton -> (N,).
            return linear(pooling(features)).squeeze(-1)
        # Per-frame scores (N, T) averaged over time -> (N,).
        return linear(features).squeeze(-1).mean(dim=-1)

    def forward(self, features, judge_ids=None):
        """
        Args:
            features: frame features, expected shape (N, T, H).
            judge_ids: optional integer tensor of judge indices, expected
                shape (N,) (one judge per utterance).

        Returns:
            ``segment_score`` of shape (N,) when ``judge_ids`` is None,
            otherwise the tuple ``(segment_score, bias_score)``, both (N,),
            where ``bias_score`` already includes ``segment_score``.
            Scores are never squeezed further, so a batch of size 1 still
            yields shape (1,) rather than a 0-dim tensor.
        """
        segment_score = self._score(features, self.mean_net_linear, self.mean_net_pooling)
        if self.mean_net_clipping:
            segment_score = torch.tanh(segment_score) * 2 + 3

        if judge_ids is None:
            return segment_score

        # (N, H) -> (N, 1, H): broadcasts the judge embedding across all T
        # frames without materialising T copies.
        judge_features = self.judge_embbeding(judge_ids).unsqueeze(1)
        bias_features = features + judge_features
        bias_score = self._score(bias_features, self.bias_net_linear, self.bias_net_pooling)
        return segment_score, bias_score + segment_score