|
import mmcv |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from ..builder import LOSSES |
|
from .utils import weight_reduce_loss |
|
|
|
|
|
@mmcv.jit(derivate=True, coderize=True)
def varifocal_loss(pred,
                   target,
                   weight=None,
                   alpha=0.75,
                   gamma=2.0,
                   iou_weighted=True,
                   reduction='mean',
                   avg_factor=None):
    """Compute `Varifocal Loss <https://arxiv.org/abs/2008.13367>`_.

    The element-wise binary cross-entropy (computed from logits) is scaled
    by a per-element weight before reduction:

    * where ``target > 0`` the weight is ``target`` when ``iou_weighted``
      is True, otherwise ``1``;
    * where ``target <= 0`` the weight is
      ``alpha * |sigmoid(pred) - target| ** gamma``.

    Args:
        pred (torch.Tensor): Raw logits with shape (N, C), C is the number
            of classes. Must have the same size as ``target``.
        target (torch.Tensor): The learning target of the iou-aware
            classification score with shape (N, C). It is cast to the
            dtype of ``pred`` before use.
        weight (torch.Tensor, optional): The weight of loss for each
            prediction, forwarded to ``weight_reduce_loss``.
            Defaults to None.
        alpha (float, optional): A balance factor for the negative part of
            Varifocal Loss, which is different from the alpha of Focal Loss.
            Defaults to 0.75.
        gamma (float, optional): Exponent of the modulating factor applied
            to the negative part. Defaults to 2.0.
        iou_weighted (bool, optional): Whether to weight the loss of the
            positive elements by their target (iou) value. Defaults to True.
        reduction (str, optional): The method used to reduce the loss into
            a scalar, forwarded to ``weight_reduce_loss``. Options are
            "none", "mean" and "sum". Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss, forwarded to ``weight_reduce_loss``. Defaults to None.

    Returns:
        torch.Tensor: The loss after ``weight_reduce_loss``.
    """
    # pred and target have to line up element by element.
    assert pred.size() == target.size()
    target = target.type_as(pred)
    prob = pred.sigmoid()

    is_pos = (target > 0.0).float()
    is_neg = (target <= 0.0).float()

    # Negative part: scaled by alpha and by how far the predicted
    # probability is from the target, raised to gamma.
    neg_weight = alpha * (prob - target).abs().pow(gamma) * is_neg
    # Positive part: either the target value itself or a flat weight of 1.
    pos_weight = target * is_pos if iou_weighted else is_pos
    focal_weight = pos_weight + neg_weight

    bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
    return weight_reduce_loss(
        bce * focal_weight, weight, reduction, avg_factor)
|
|
|
|
|
@LOSSES.register_module()
class VarifocalLoss(nn.Module):
    """``nn.Module`` wrapper around :func:`varifocal_loss`.

    Stores the loss hyper-parameters and multiplies the result of
    :func:`varifocal_loss` by ``loss_weight``.
    """

    def __init__(self,
                 use_sigmoid=True,
                 alpha=0.75,
                 gamma=2.0,
                 iou_weighted=True,
                 reduction='mean',
                 loss_weight=1.0):
        """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction is used
                with sigmoid or softmax. Only ``True`` (sigmoid) is
                supported; any other value fails an assertion.
                Defaults to True.
            alpha (float, optional): A balance factor for the negative part
                of Varifocal Loss, which is different from the alpha of
                Focal Loss. Must be non-negative. Defaults to 0.75.
            gamma (float, optional): The gamma for calculating the
                modulating factor. Defaults to 2.0.
            iou_weighted (bool, optional): Whether to weight the loss of the
                positive examples with the iou target. Defaults to True.
            reduction (str, optional): The method used to reduce the loss
                into a scalar. Defaults to 'mean'. Options are "none",
                "mean" and "sum".
            loss_weight (float, optional): Multiplier applied to the
                computed loss. Defaults to 1.0.
        """
        super().__init__()
        assert use_sigmoid is True, \
            'Only sigmoid varifocal loss supported now.'
        assert alpha >= 0.0
        # Hyper-parameters forwarded to varifocal_loss on every call.
        self.use_sigmoid = use_sigmoid
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        # Default reduction; forward() may override it per call.
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted Varifocal Loss.

        Args:
            pred (torch.Tensor): The prediction logits.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum". A falsy value keeps
                ``self.reduction``.

        Returns:
            torch.Tensor: The calculated loss, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction
        if not self.use_sigmoid:
            # Unreachable via __init__ (it asserts use_sigmoid is True), but
            # kept in case the attribute is changed after construction.
            raise NotImplementedError
        loss = varifocal_loss(
            pred,
            target,
            weight,
            alpha=self.alpha,
            gamma=self.gamma,
            iou_weighted=self.iou_weighted,
            reduction=reduction,
            avg_factor=avg_factor)
        return self.loss_weight * loss
|
|