"""
Copyright (c) 2018 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math

import torch
import torch.nn.functional as F

from .abstract_loss_func import AbstractLossClass
from metrics.registry import LOSSFUNC

def focal_loss(input_values, gamma):
    """Computes the focal loss FL = (1 - p)^gamma * CE from per-sample
    cross-entropy values `input_values` and focusing parameter `gamma`."""
    # input_values are per-sample cross-entropy losses, so p = exp(-CE) is
    # the probability the model assigned to the correct class.
    p = torch.exp(-input_values)
    loss = (1 - p) ** gamma * input_values
    return loss.mean()
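

# A minimal sketch (not part of the original module) of how focal_loss
# composes with unreduced cross-entropy; the batch shape and class count
# below are illustrative placeholders.
def _focal_loss_example():
    logits = torch.randn(8, 2)                              # hypothetical batch of 8, 2 classes
    labels = torch.randint(0, 2, (8,))
    ce = F.cross_entropy(logits, labels, reduction='none')  # per-sample CE values
    return focal_loss(ce, gamma=2.0)                        # gamma=2 is a common choice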


@LOSSFUNC.register_module(module_name="am_softmax")
class AMSoftmaxLoss(AbstractLossClass):
    """Computes the AM-Softmax loss with a 'cos' (CosFace-style) or
    'arc' (ArcFace-style) margin."""
    margin_types = ['cos', 'arc']

    def __init__(self, margin_type='cos', gamma=0., m=0.5, s=30, t=1.):
        super().__init__()
        assert margin_type in AMSoftmaxLoss.margin_types
        self.margin_type = margin_type
        assert gamma >= 0
        self.gamma = gamma
        assert m > 0
        self.m = m
        assert s > 0
        self.s = s  # scale applied to the cosine logits
        # Precomputed terms for the 'arc' margin cos(theta + m).
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)  # threshold beyond which theta + m exceeds pi
        assert t >= 1
        self.t = t  # support-vector reweighting factor

    def forward(self, cos_theta, target):
        if self.margin_type == 'cos':
            # CosFace: subtract an additive margin from the target-class cosine.
            phi_theta = cos_theta - self.m
        else:
            # ArcFace: cos(theta + m); clamp keeps the sqrt well-defined under
            # floating-point error.
            sine = torch.sqrt(torch.clamp(1.0 - torch.pow(cos_theta, 2), min=0.0))
            phi_theta = cos_theta * self.cos_m - sine * self.sin_m
            # Fall back to a linear penalty when theta + m would exceed pi.
            phi_theta = torch.where(cos_theta > self.th, phi_theta, cos_theta - self.sin_m * self.m)

        # Boolean mask selecting the target-class entry in each row.
        index = torch.zeros_like(cos_theta, dtype=torch.bool)
        index.scatter_(1, target.view(-1, 1), True)
        output = torch.where(index, phi_theta, cos_theta)

        if self.gamma == 0 and self.t == 1.:
            return F.cross_entropy(self.s * output, target)

        if self.t > 1:
            # Support-vector reweighting: boost non-target logits that already
            # exceed the margin-penalized target logit.
            h_theta = self.t - 1 + self.t * cos_theta
            support_vecs_mask = ~index & torch.lt(
                torch.masked_select(phi_theta, index).view(-1, 1).repeat(1, h_theta.shape[1]) - cos_theta, 0)
            output = torch.where(support_vecs_mask, h_theta, output)
            return F.cross_entropy(self.s * output, target)

        return focal_loss(F.cross_entropy(self.s * output, target, reduction='none'), self.gamma)
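

# A minimal usage sketch (illustrative, not from the original module). The
# loss expects cosine similarities, typically L2-normalized embeddings times
# L2-normalized class weights; the feature dimension and class count below
# are assumptions, and AbstractLossClass is assumed to be an nn.Module.
def _am_softmax_example():
    feats = F.normalize(torch.randn(8, 128), dim=1)    # hypothetical embeddings
    weights = F.normalize(torch.randn(2, 128), dim=1)  # hypothetical class centers
    cos_theta = feats @ weights.t()                    # cosines in [-1, 1]
    labels = torch.randint(0, 2, (8,))
    criterion = AMSoftmaxLoss(margin_type='arc', m=0.5, s=30)
    return criterion(cos_theta, labels)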


@LOSSFUNC.register_module(module_name="am_softmax_ohem")
class AMSoftmax_OHEM(AbstractLossClass):
    """Computes the AM-Softmax loss with a 'cos' or 'arc' margin, combined
    with online hard example mining (OHEM): only the hardest `ratio` fraction
    of the batch contributes to the loss."""
    margin_types = ['cos', 'arc']

    def __init__(self, margin_type='cos', gamma=0., m=0.5, s=30, t=1., ratio=1.):
        super().__init__()
        assert margin_type in AMSoftmax_OHEM.margin_types
        self.margin_type = margin_type
        assert gamma >= 0
        self.gamma = gamma
        assert m > 0
        self.m = m
        assert s > 0
        self.s = s
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        assert t >= 1
        self.t = t
        self.ratio = ratio  # fraction of the batch kept as hard examples

    def get_subidx(self, x, y, ratio):
        """Returns the indices of the `ratio * batch_size` hardest examples,
        ranked by per-sample negative log-likelihood of the log-probabilities `x`."""
        num_inst = x.size(0)
        num_hns = int(ratio * num_inst)
        # Per-instance NLL; detach so the ranking does not affect gradients.
        inst_losses = -x.detach().gather(1, y.view(-1, 1)).squeeze(1)
        _, idxs = inst_losses.topk(num_hns)
        return idxs

    def forward(self, cos_theta, target):
        if self.margin_type == 'cos':
            phi_theta = cos_theta - self.m
        else:
            sine = torch.sqrt(torch.clamp(1.0 - torch.pow(cos_theta, 2), min=0.0))
            phi_theta = cos_theta * self.cos_m - sine * self.sin_m
            phi_theta = torch.where(cos_theta > self.th, phi_theta, cos_theta - self.sin_m * self.m)

        index = torch.zeros_like(cos_theta, dtype=torch.bool)
        index.scatter_(1, target.view(-1, 1), True)
        output = torch.where(index, phi_theta, cos_theta)

        # Rank the batch by per-sample NLL and keep only the hardest examples.
        out = F.log_softmax(output, dim=1)
        idxs = self.get_subidx(out, target, self.ratio)

        output2 = output.index_select(0, idxs)
        target2 = target.index_select(0, idxs)

        if self.gamma == 0 and self.t == 1.:
            return F.cross_entropy(self.s * output2, target2)

        if self.t > 1:
            h_theta = self.t - 1 + self.t * cos_theta
            support_vecs_mask = ~index & torch.lt(
                torch.masked_select(phi_theta, index).view(-1, 1).repeat(1, h_theta.shape[1]) - cos_theta, 0)
            # Restrict the support-vector update to the mined subset so row
            # order and shape match output2 (the original mixed full-batch
            # tensors with the mined subset).
            output2 = torch.where(support_vecs_mask.index_select(0, idxs),
                                  h_theta.index_select(0, idxs), output2)
            return F.cross_entropy(self.s * output2, target2)

        return focal_loss(F.cross_entropy(self.s * output2, target2, reduction='none'), self.gamma)
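

# A minimal sketch of the OHEM variant (illustrative shapes and values, not
# from the original module): with ratio=0.5 only the hardest half of the
# batch, as ranked by per-sample NLL, contributes to the loss.
def _am_softmax_ohem_example():
    feats = F.normalize(torch.randn(8, 128), dim=1)
    weights = F.normalize(torch.randn(2, 128), dim=1)
    cos_theta = feats @ weights.t()
    labels = torch.randint(0, 2, (8,))
    criterion = AMSoftmax_OHEM(margin_type='cos', m=0.5, s=30, ratio=0.5)
    return criterion(cos_theta, labels)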