Spaces:
Runtime error
Runtime error
File size: 5,535 Bytes
4d0eb62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmpretrain.registry import MODELS
from .utils import convert_to_one_hot, weight_reduce_loss
def asymmetric_loss(pred,
                    target,
                    weight=None,
                    gamma_pos=1.0,
                    gamma_neg=4.0,
                    clip=0.05,
                    reduction='mean',
                    avg_factor=None,
                    use_sigmoid=True,
                    eps=1e-8):
    r"""Compute the asymmetric loss (ASL) between ``pred`` and ``target``.

    See the `paper <https://arxiv.org/abs/2009.14119>`__ for the
    derivation. Positive and negative entries are focused with separate
    exponents, and the negative-class probability can be shifted by a
    margin before the logarithm is taken.

    Args:
        pred (torch.Tensor): Raw predictions with shape (N, \*).
        target (torch.Tensor): Ground-truth labels with exactly the same
            shape as ``pred``. It is cast to the dtype of ``pred``.
        weight (torch.Tensor, optional): Sample-wise loss weight. Must be
            1-D; it is cast to float and reshaped to (N, 1) when ``pred``
            has more than one dimension. Defaults to None.
        gamma_pos (float): Focusing exponent applied where ``target`` is 1.
            Defaults to 1.0.
        gamma_neg (float): Focusing exponent applied where ``target`` is 0.
            Usually ``gamma_neg > gamma_pos``. Defaults to 4.0.
        clip (float, optional): Probability margin added to the
            negative-class probability ``1 - p`` (result capped at 1).
            Ignored when falsy or not positive. Defaults to 0.05.
        reduction (str): Reduction method forwarded to
            ``weight_reduce_loss``. Options are "none", "mean" and "sum".
            Defaults to 'mean'.
        avg_factor (int, optional): Average factor forwarded to
            ``weight_reduce_loss``. Defaults to None.
        use_sigmoid (bool): If True, probabilities come from a sigmoid;
            otherwise from a softmax over the last dimension.
            Defaults to True.
        eps (float): Lower bound applied to the argument of the logarithm.
            Defaults to 1e-8.

    Returns:
        torch.Tensor: The loss as returned by ``weight_reduce_loss``.

    Raises:
        AssertionError: If ``pred`` and ``target`` differ in shape, or if
            ``weight`` is given and is not 1-D.
    """
    assert pred.shape == target.shape, \
        'pred and target should be in the same shape.'

    # Turn raw scores into per-class probabilities.
    probs = (
        pred.sigmoid() if use_sigmoid else nn.functional.softmax(pred, dim=-1))

    target = target.type_as(pred)
    neg_mask = 1 - target

    # Probability assigned to the negative outcome, optionally shifted by
    # the margin so that very easy negatives contribute no loss at all.
    neg_probs = 1 - probs
    if clip and clip > 0:
        neg_probs = (neg_probs + clip).clamp(max=1)

    # Probability of the "correct" outcome for every entry.
    pt = neg_probs * neg_mask + probs * target

    # Entry-wise focusing exponent: gamma_pos on positives, gamma_neg on
    # negatives.
    focusing = gamma_pos * target + gamma_neg * neg_mask
    modulator = (1 - pt).pow(focusing)

    # Clamp before the log so that pt == 0 cannot produce -inf.
    neg_log_pt = -torch.log(pt.clamp(min=eps))
    loss = neg_log_pt * modulator

    if weight is not None:
        assert weight.dim() == 1
        weight = weight.float()
        if pred.dim() > 1:
            # Make the per-sample weight broadcast across the class axis.
            weight = weight.reshape(-1, 1)

    return weight_reduce_loss(loss, weight, reduction, avg_factor)
@MODELS.register_module()
class AsymmetricLoss(nn.Module):
    """Module wrapper around :func:`asymmetric_loss`.

    The constructor stores the loss hyper-parameters; ``forward`` passes
    them to ``asymmetric_loss`` and scales the result by ``loss_weight``.

    Args:
        gamma_pos (float): Positive focusing parameter. Defaults to 0.0.
        gamma_neg (float): Negative focusing parameter. Usually
            ``gamma_neg > gamma_pos``. Defaults to 4.0.
        clip (float, optional): Probability margin. Defaults to 0.05.
        reduction (str): Default method used to reduce the loss; can be
            overridden per call in ``forward``. Defaults to 'mean'.
        loss_weight (float): Multiplier applied to the computed loss.
            Defaults to 1.0.
        use_sigmoid (bool): Whether the prediction uses sigmoid instead of
            softmax. Defaults to True.
        eps (float): The minimum value of the argument of the logarithm.
            Defaults to 1e-8.
    """

    def __init__(self,
                 gamma_pos=0.0,
                 gamma_neg=4.0,
                 clip=0.05,
                 reduction='mean',
                 loss_weight=1.0,
                 use_sigmoid=True,
                 eps=1e-8):
        super().__init__()
        # Focusing / margin hyper-parameters handed to asymmetric_loss.
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.use_sigmoid = use_sigmoid
        self.eps = eps
        # Reduction used when forward() receives no override.
        self.reduction = reduction
        # Scalar factor applied to the final loss.
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        r"""Compute the weighted asymmetric loss.

        Args:
            pred (torch.Tensor): The prediction with shape (N, \*).
            target (torch.Tensor): The ground truth label with shape
                (N, \*), (N, ) or (N, 1). The last two forms are expanded
                with ``convert_to_one_hot`` using ``pred.shape[-1]`` as
                the number of classes.
            weight (torch.Tensor, optional): Sample-wise loss weight. It is
                passed unchanged to ``asymmetric_loss``, which asserts
                that it is 1-D. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): Reduction to use for this
                call instead of ``self.reduction``. Options are "none",
                "mean" and "sum". Defaults to None.

        Returns:
            torch.Tensor: Loss multiplied by ``self.loss_weight``.

        Raises:
            AssertionError: If ``reduction_override`` is not one of the
                accepted values.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # Every accepted non-None override is a non-empty string, so a
        # truthiness fallback selects the stored default only for None.
        reduction = reduction_override or self.reduction

        # Labels given as a vector or a single column are expanded to the
        # same layout as the prediction.
        is_single_column = target.dim() == 2 and target.shape[1] == 1
        if target.dim() == 1 or is_single_column:
            target = convert_to_one_hot(target.view(-1, 1), pred.shape[-1])

        unscaled_loss = asymmetric_loss(
            pred,
            target,
            weight,
            gamma_pos=self.gamma_pos,
            gamma_neg=self.gamma_neg,
            clip=self.clip,
            reduction=reduction,
            avg_factor=avg_factor,
            use_sigmoid=self.use_sigmoid,
            eps=self.eps)
        return self.loss_weight * unscaled_loss
|