Spaces:
Runtime error
Runtime error
""" | |
Author: Siyuan Li | |
Licensed: Apache-2.0 License | |
""" | |
import torch | |
import torch.nn as nn | |
from mmdet.models import weight_reduce_loss | |
from mmdet.registry import MODELS | |
def multi_pos_cross_entropy(
    pred, label, weight=None, reduction="mean", avg_factor=None, pos_normalize=True,
):
    """Cross-entropy style loss for rows that may have several positive targets.

    Rows of ``label`` whose entries sum to zero (no positive target) are
    dropped from ``pred``, ``label`` and ``weight`` before anything else is
    computed. For each remaining row the log-probability of every entry is
    computed against a (optionally re-weighted) sum of exponentials, and the
    loss is the negative label-weighted mean of those log-probabilities.

    Args:
        pred: Score tensor reduced along dim 1 (presumably shape ``(N, C)``
            with one row per sample -- confirm against callers).
        label: Target tensor laid out like ``pred``; non-zero entries mark
            positives (presumably binary 0/1 -- confirm against callers).
        weight: Optional per-row weights, filtered with the same row mask as
            ``pred``. ``None`` means unweighted.
        reduction: Passed through to ``weight_reduce_loss``.
        avg_factor: Passed through to ``weight_reduce_loss``.
        pos_normalize: If true, the exponential of each positive entry is
            scaled by ``label / label.sum(1)`` inside the denominator while
            entries with a zero label keep their full exponential. If false,
            the denominator is the plain sum of exponentials.

    Returns:
        The result of ``weight_reduce_loss`` applied to the per-row losses.
    """
    # Keep only the rows that have at least one positive target.
    valid_mask = label.sum(1) != 0
    pred = pred[valid_mask]
    label = label[valid_mask]
    # ``weight`` is optional: only filter (and cast) it when it was provided.
    # Indexing ``None`` here used to raise a TypeError on the default path.
    if weight is not None:
        weight = weight[valid_mask].float()

    if min(pred.shape) != 0:
        # Shift each row by its detached maximum so exp() cannot overflow;
        # the shift cancels in ``logits - log(sum(exp(logits)))`` below.
        logits_max, _ = torch.max(pred, dim=1, keepdim=True)
        logits = pred - logits_max.detach()
    else:
        # Nothing to reduce over when a dimension is empty.
        logits = pred

    pos_count = label.sum(1)
    exp_logits = torch.exp(logits)
    if pos_normalize:
        # Positives contribute proportionally to their share of the row's
        # label mass; entries with a zero label contribute in full.
        pos_norm = torch.div(label, pos_count.reshape(-1, 1))
        exp_logits = exp_logits * pos_norm + exp_logits * torch.logical_not(label)

    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
    mean_log_prob_pos = (label * log_prob).sum(1) / pos_count
    loss = -mean_log_prob_pos

    # Apply the optional per-row weights and do the reduction.
    return weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor
    )
class UnbiasedContrastLoss(nn.Module):
    """Loss module wrapping ``multi_pos_cross_entropy`` with a scalar weight."""

    def __init__(self, reduction="mean", loss_weight=1.0):
        """Store the default reduction mode and the loss scaling factor.

        Args:
            reduction: Reduction used when ``forward`` gets no override.
            loss_weight: Factor the computed loss is multiplied by.
        """
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(
        self,
        cls_score,
        label,
        weight=None,
        avg_factor=None,
        reduction_override=None,
        **kwargs
    ):
        """Compute the scaled multi-positive cross-entropy loss.

        Args:
            cls_score: Score tensor; must have the same size as ``label``.
            label: Target tensor; must have the same size as ``cls_score``.
            weight: Optional weights passed on to the loss function.
            avg_factor: Optional averaging factor passed on to the loss
                function.
            reduction_override: One of ``None``, ``"none"``, ``"mean"`` or
                ``"sum"``; a truthy value replaces ``self.reduction`` for
                this call.
            **kwargs: Extra keyword arguments passed on to
                ``multi_pos_cross_entropy``.

        Returns:
            ``self.loss_weight`` times the value returned by
            ``multi_pos_cross_entropy``.
        """
        assert cls_score.size() == label.size()
        assert reduction_override in (None, "none", "mean", "sum")

        # A missing override falls back to the mode chosen at construction.
        active_reduction = reduction_override or self.reduction
        raw_loss = multi_pos_cross_entropy(
            cls_score,
            label,
            weight,
            reduction=active_reduction,
            avg_factor=avg_factor,
            **kwargs
        )
        return self.loss_weight * raw_loss