lmzjms's picture
Upload 1162 files
0b32ad6 verified
"""
Speaker verification loss
Authors:
* Haibin Wu 2022
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# Names exported by ``from <module> import *``: the two loss classes below.
__all__ = ["softmax", "amsoftmax"]
class softmax(nn.Module):
    """
    Plain softmax (cross-entropy) loss exposing the same interface as the
    other speaker-related softmax losses in this module.

    The input embeddings are L2-normalized along the feature dimension,
    projected to class logits by a linear layer, and scored with cross entropy.

    Args:
        input_size (int): The input feature size
        output_size (int): The output feature size (width of the logits)
    """

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self._indim, self._outdim = input_size, output_size
        # Linear classifier head mapping embeddings to logits.
        self.fc = nn.Linear(input_size, output_size)
        # NOTE: the attribute name (including its spelling) is kept as-is,
        # since code outside this class may access it.
        self.criertion = nn.CrossEntropyLoss()

    @property
    def input_size(self):
        """Feature size expected on the input embeddings."""
        return self._indim

    @property
    def output_size(self):
        """Feature size of the produced logits."""
        return self._outdim

    def forward(self, x: torch.Tensor, label: torch.LongTensor):
        """
        Args:
            x (torch.Tensor): (batch_size, input_size)
            label (torch.LongTensor): (batch_size, )

        Returns:
            loss (torch.float)
            logit (torch.Tensor): (batch_size, output_size)
        """
        # Shape sanity checks: one label per row, and the expected feature size.
        assert x.shape[0] == label.shape[0]
        assert x.shape[1] == self.input_size

        # Unit-length embeddings go through the linear head to give the logits.
        logits = self.fc(F.normalize(x, dim=1))
        return self.criertion(logits, label), logits
class amsoftmax(nn.Module):
    """
    AMSoftmax (additive-margin softmax) loss.

    Input embeddings and class weight columns are both L2-normalized, so their
    matrix product is a cosine similarity per class. ``margin`` is subtracted
    from the cosine of each sample's target class, the result is multiplied by
    ``scale``, and cross entropy is computed on these scaled logits.

    Args:
        input_size (int): The input feature size
        output_size (int): The output feature size
        margin (float): Hyperparameter denoting the margin to the decision boundary
        scale (float): Hyperparameter that scales the cosine value
    """

    def __init__(
        self, input_size: int, output_size: int, margin: float = 0.2, scale: float = 30
    ):
        super().__init__()
        self._indim = input_size
        self._outdim = output_size
        self.margin = margin
        self.scale = scale
        # One weight column per output class. The randn values are overwritten
        # by the Xavier initialization below; randn is kept so construction
        # draws from the random generator exactly as before.
        self.W = torch.nn.Parameter(
            torch.randn(input_size, output_size), requires_grad=True
        )
        self.ce = nn.CrossEntropyLoss()
        nn.init.xavier_normal_(self.W, gain=1)

    @property
    def input_size(self):
        """Feature size expected on the input embeddings."""
        return self._indim

    @property
    def output_size(self):
        """Feature size of the produced logits."""
        return self._outdim

    def forward(self, x: torch.Tensor, label: torch.LongTensor):
        """
        Args:
            x (torch.Tensor): (batch_size, input_size)
            label (torch.LongTensor): (batch_size, )

        Returns:
            loss (torch.float)
            logit (torch.Tensor): (batch_size, output_size), the scaled,
                margin-adjusted cosine logits fed to cross entropy
        """
        assert x.size()[0] == label.size()[0]
        assert x.size()[1] == self.input_size

        # Cosine similarity between each embedding (row of x) and each class
        # weight (column of W). eps matches the previous 1e-12 norm clamp.
        x_norm = F.normalize(x, p=2, dim=1, eps=1e-12)
        w_norm = F.normalize(self.W, p=2, dim=0, eps=1e-12)
        costh = torch.mm(x_norm, w_norm)

        # Margin mask: `margin` at each sample's target class, zero elsewhere.
        # Built directly on the device/dtype of the logits, so there is no
        # host round trip and no reliance on the default CUDA device.
        delt_costh = torch.zeros_like(costh).scatter_(
            1, label.view(-1, 1), self.margin
        )

        costh_m_s = self.scale * (costh - delt_costh)
        loss = self.ce(costh_m_s, label)
        return loss, costh_m_s