File size: 1,895 Bytes
3be9ff2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math


class ArcNet(nn.Module):
    """Additive angular margin (ArcFace-style) classification head.

    Input features and the class weight columns are both L2-normalised, so
    ``input @ weight`` gives the cosine between every feature and every class
    column. For each sample's target class the logit ``cos(theta)`` is replaced
    by ``cos(theta + margin)``; all logits are then multiplied by ``scale``.

    Args:
        feature_dim: Size of the feature dimension of ``input``.
        class_dim: Number of classes (columns of ``weight``).
        margin: Additive angular margin ``m`` (radians).
        scale: Multiplier applied to the final logits.
        easy_margin: If True, the margin is applied only where
            ``cos(theta) > 0`` and the plain cosine is kept elsewhere.
            Otherwise it is applied where ``cos(theta) > -cos(m)`` and
            ``cos(theta) - m * sin(m)`` is used elsewhere.
    """

    def __init__(self,
                 feature_dim,
                 class_dim,
                 margin=0.2,
                 scale=30.0,
                 easy_margin=False):
        super().__init__()
        self.feature_dim = feature_dim
        self.class_dim = class_dim
        self.margin = margin
        self.scale = scale
        self.easy_margin = easy_margin
        # Shape (feature_dim, class_dim): one column per class. float32 matches
        # the former torch.FloatTensor constructor regardless of default dtype.
        self.weight = Parameter(
            torch.empty(feature_dim, class_dim, dtype=torch.float32))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        """Compute scaled margin logits.

        Args:
            input: Feature tensor; assumed 2-D ``(batch, feature_dim)`` since
                it is normalised along dim 1 and multiplied by ``weight``.
            label: Integer class indices, shape ``(batch,)`` or ``(batch, 1)``.

        Returns:
            Tensor of shape ``(batch, class_dim)``.
        """
        # F.normalize divides by max(norm, eps): an all-zero row maps to zeros
        # instead of the 0/0 = NaN produced by a bare sqrt/divide.
        input = F.normalize(input, p=2, dim=1)
        weight = F.normalize(self.weight, p=2, dim=0)

        cos = torch.matmul(input, weight)
        # The +1e-6 keeps sqrt's argument (and gradient) finite when |cos| ~ 1.
        sin = torch.sqrt(1.0 - torch.square(cos) + 1e-6)
        cos_m = math.cos(self.margin)
        sin_m = math.sin(self.margin)
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        phi = cos * cos_m - sin * sin_m

        if self.easy_margin:
            phi = self._paddle_where_more_than(cos, 0, phi, cos)
        else:
            # th = cos(pi - m) = -cos(m). Past it, theta + m would exceed pi, so
            # the linear fallback cos(theta) - m*sin(m) is used instead.
            th = -cos_m
            mm = sin_m * self.margin
            phi = self._paddle_where_more_than(cos, th, phi, cos - mm)

        # Flatten (batch, 1) or (batch,) labels to (batch,) before one-hot, so
        # the result is always (batch, class_dim), even when class_dim == 1.
        is_target = F.one_hot(label.reshape(-1), self.class_dim).bool()
        output = torch.where(is_target, phi, cos)
        output = output * self.scale
        return output

    def _paddle_where_more_than(self, target, limit, x, y):
        """Select ``x`` where ``target > limit`` and ``y`` elsewhere (element-wise).

        torch.where is used instead of ``mask*x + (1-mask)*y`` so a NaN or inf
        in the branch that is not selected cannot leak into the result.
        """
        return torch.where(target > limit, x, y)