"""
Speaker verification loss

Authors:
  * Haibin Wu 2022
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "softmax",
    "amsoftmax",
]


class softmax(nn.Module):
    """
    The standard softmax loss, wrapped in a unified interface shared by all
    speaker-related softmax losses.

    Args:
        input_size (int): The input feature size
        output_size (int): The output size (number of speakers)
    """

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self._indim = input_size
        self._outdim = output_size

        self.fc = nn.Linear(input_size, output_size)
        self.criterion = nn.CrossEntropyLoss()

    @property
    def input_size(self):
        return self._indim

    @property
    def output_size(self):
        return self._outdim

    def forward(self, x: torch.Tensor, label: torch.LongTensor):
        """
        Args:
            x (torch.Tensor): (batch_size, input_size)
            label (torch.LongTensor): (batch_size, )

        Returns:
            loss (torch.float)
            logit (torch.Tensor): (batch_size, output_size)
        """

        assert x.size()[0] == label.size()[0]
        assert x.size()[1] == self.input_size

        # length-normalize the embedding before the linear classification layer
        x = F.normalize(x, dim=1)
        x = self.fc(x)
        loss = self.criterion(x, label)

        return loss, x


class amsoftmax(nn.Module):
    """
    AMSoftmax

    Args:
        input_size (int): The input feature size
        output_size (int): The output size (number of speakers)
        margin (float): Hyperparameter that sets the additive margin to the decision boundary
        scale (float): Hyperparameter that scales the cosine value
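
    The loss follows the additive-margin softmax formulation: the logit of the
    target class becomes ``scale * (cos(theta_target) - margin)`` while the
    non-target classes keep ``scale * cos(theta_j)``, and standard cross-entropy
    is applied on top of these scaled logits.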
    """

    def __init__(
        self, input_size: int, output_size: int, margin: float = 0.2, scale: float = 30
    ):
        super().__init__()
        self._indim = input_size
        self._outdim = output_size
        self.margin = margin
        self.scale = scale

        self.W = torch.nn.Parameter(
            torch.randn(input_size, output_size), requires_grad=True
        )
        self.ce = nn.CrossEntropyLoss()
        nn.init.xavier_normal_(self.W, gain=1)

    @property
    def input_size(self):
        return self._indim

    @property
    def output_size(self):
        return self._outdim

    def forward(self, x: torch.Tensor, label: torch.LongTensor):
        """
        Args:
            x (torch.Tensor): (batch_size, input_size)
            label (torch.LongTensor): (batch_size, )

        Returns:
            loss (torch.float)
            logit (torch.Tensor): (batch_size, output_size)
        """

        assert x.size()[0] == label.size()[0]
        assert x.size()[1] == self.input_size

        # cosine similarity between length-normalized embeddings and class weights
        x_norm = F.normalize(x, p=2, dim=1)
        w_norm = F.normalize(self.W, p=2, dim=0)
        costh = torch.mm(x_norm, w_norm)

        # subtract the additive margin from the target-class cosine only,
        # then scale the logits before cross-entropy
        delt_costh = torch.zeros_like(costh).scatter_(
            1, label.view(-1, 1).to(costh.device), self.margin
        )
        costh_m = costh - delt_costh
        costh_m_s = self.scale * costh_m
        loss = self.ce(costh_m_s, label)
        loss = self.ce(costh_m_s, label)

        return loss, costh_m_s
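

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # run both losses on random embeddings to check that the loss is a scalar
    # and the returned logits have shape (batch_size, output_size). The batch
    # size, embedding dimension, and speaker count below are arbitrary.
    batch_size, input_size, num_speakers = 4, 192, 10
    embeddings = torch.randn(batch_size, input_size)
    labels = torch.randint(0, num_speakers, (batch_size,))

    softmax_loss_fn = softmax(input_size, num_speakers)
    loss, logits = softmax_loss_fn(embeddings, labels)
    print(f"softmax   loss={loss.item():.4f} logits={tuple(logits.shape)}")

    amsoftmax_loss_fn = amsoftmax(input_size, num_speakers, margin=0.2, scale=30)
    loss, logits = amsoftmax_loss_fn(embeddings, labels)
    print(f"amsoftmax loss={loss.item():.4f} logits={tuple(logits.shape)}")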