Spaces:
Build error
Build error
"""
@Date: 2021/08/12
@description: Loss modules for object heatmap training (ObjectLoss, HeatmapLoss).
"""
import torch | |
import torch.nn as nn | |
from loss.grad_loss import GradLoss | |
class ObjectLoss(nn.Module):
    """Holder for the object loss terms; the combination is not implemented yet.

    The constructor instantiates a heatmap loss and a smooth-L1 loss, but
    ``forward`` currently ignores both and returns the constant ``0``.
    """

    def __init__(self):
        super(ObjectLoss, self).__init__()
        # Heatmap term. The original author noted FocalLoss(reduction='mean')
        # as an alternative here.
        self.heat_map_loss = HeatmapLoss(reduction='mean')
        # Smooth-L1 term (PyTorch defaults).
        self.l1_loss = nn.SmoothL1Loss()

    def forward(self, gt, dt):
        """Return the loss for a ground-truth / prediction pair.

        Args:
            gt: Presumably the ground-truth targets; unused for now.
            dt: Presumably the predicted (detected) values; unused for now.

        Returns:
            The integer ``0`` until the loss is implemented.
        """
        # TODO: combine self.heat_map_loss and self.l1_loss into a real loss.
        return 0
class HeatmapLoss(nn.Module):
    """Focal-style element-wise loss between a target and a predicted heatmap.

    Elements whose target equals exactly ``1.0`` are treated as centers and
    contribute ``-(1 - p) ** alpha * log(p)``. Every other element contributes
    ``-(1 - t) ** beta * p ** alpha * log(1 - p)``, where ``p`` is the
    prediction and ``t`` the target value.

    Args:
        weight: Accepted for signature compatibility; currently unused.
        alpha: Exponent applied to the prediction-dependent modulating factor.
        beta: Exponent applied to ``(1 - target)`` for non-center elements.
        reduction: ``'mean'`` returns the total loss divided by the batch size
            (size of dimension 0), ``'sum'`` returns the total loss, and any
            other value (e.g. ``'none'``) returns the unreduced loss tensor.
    """

    # Small constant keeping the logarithm arguments away from exactly zero.
    _EPS = 1e-14

    def __init__(self, weight=None, alpha=2, beta=4, reduction='mean'):
        super(HeatmapLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.reduction = reduction

    def forward(self, targets, inputs):
        """Compute the loss.

        Args:
            targets: Target heatmap tensor; values equal to 1.0 mark centers.
            inputs: Predicted heatmap tensor, broadcastable against
                ``targets``. Presumably probabilities in [0, 1] (the code
                takes ``log(inputs)`` and ``log(1 - inputs)``) -- confirm
                against the caller.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'``; otherwise the
            element-wise loss tensor.
        """
        is_center = (targets == 1.0).float()
        is_other = 1.0 - is_center

        center_loss = (
            -is_center
            * (1.0 - inputs) ** self.alpha
            * torch.log(inputs + self._EPS)
        )
        other_loss = (
            -is_other
            * (1 - targets) ** self.beta
            * inputs ** self.alpha
            * torch.log(1.0 - inputs + self._EPS)
        )
        loss = center_loss + other_loss

        if self.reduction == 'mean':
            # "Mean" here is per batch element, not per heatmap element.
            return torch.sum(loss) / loss.size(0)
        if self.reduction == 'sum':
            # Previously this branch also divided by the batch size, making
            # 'sum' indistinguishable from 'mean'.
            return torch.sum(loss)
        return loss