# -*- coding: utf-8 -*-
# @Time    : 10/1/21
# @Author  : GXYM
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F


class PolyMatchingLoss(nn.Module):
    """Polygon matching loss that is invariant to the ground-truth start point.

    The prediction is compared against every cyclic shift of the ground-truth
    point sequence and the smallest per-sample distance is kept.

    Args:
        pnum: number of points per polygon.
        device: device on which the cyclic-shift index table is created.
        loss_type: ``"L1"`` (smooth L1) or ``"L2"`` (Euclidean distance).
    """

    def __init__(self, pnum, device, loss_type="L1"):
        super(PolyMatchingLoss, self).__init__()
        self.pnum = pnum
        self.device = device
        self.loss_type = loss_type
        self.smooth_L1 = F.smooth_l1_loss
        # `reduce` / `size_average` are deprecated; reduction='none' is the
        # equivalent of reduce=False.
        self.L2_loss = torch.nn.MSELoss(reduction='none')

        # Row i of `cyclic` is (0..pnum-1) shifted by i, modulo pnum.
        shifts = np.arange(pnum).reshape(pnum, 1)
        offsets = np.arange(pnum).reshape(1, pnum)
        cyclic = (shifts + offsets) % pnum

        # Flattened to (1, pnum * pnum, 2) so it can index both coordinates
        # of every point with a single torch.gather call.
        index = torch.from_numpy(cyclic.reshape(1, -1)).to(device).long()
        self.feature_id = index.unsqueeze(2).expand(1, pnum * pnum, 2).detach()

    def match_loss(self, pred, gt):
        """Return the minimum distance over all cyclic shifts of ``gt``.

        Args:
            pred: predicted points; the reshape below implies
                (batch, pnum, 2) - TODO confirm against callers.
            gt: ground-truth points, same layout as ``pred``.

        Returns:
            Tensor of shape (batch, 1) holding the smallest shifted distance.

        Raises:
            ValueError: if ``loss_type`` is neither ``"L1"`` nor ``"L2"``.
        """
        batch_size = pred.shape[0]
        # Keep the index on the same device as the tensor it indexes.
        feature_id = self.feature_id.to(gt.device).expand(
            batch_size, self.feature_id.size(1), 2)
        gt_expand = torch.gather(gt, 1, feature_id).view(
            batch_size, self.pnum, self.pnum, 2)
        # Expand explicitly so the element-wise losses see equal shapes
        # (avoids the broadcasting size-mismatch warning).
        pred_expand = pred.unsqueeze(1).expand_as(gt_expand)

        if self.loss_type == "L2":
            dis = self.L2_loss(pred_expand, gt_expand)
            # NOTE(review): sqrt has an infinite gradient at exactly 0.
            dis = dis.sum(3).sqrt().mean(2)
        elif self.loss_type == "L1":
            dis = self.smooth_L1(pred_expand, gt_expand, reduction='none')
            dis = dis.sum(3).mean(2)
        else:
            raise ValueError(
                "Unsupported loss_type {!r}; expected 'L1' or 'L2'".format(
                    self.loss_type))

        return torch.min(dis, dim=1, keepdim=True)[0]

    def forward(self, pred_list, gt):
        """Average the matching loss over a list of predictions.

        Args:
            pred_list: sequence of prediction tensors, each compared to ``gt``.
            gt: ground-truth points shared by all predictions.

        Returns:
            Scalar tensor; zero when ``pred_list`` is empty.
        """
        if len(pred_list) == 0:
            # Avoid 0 / 0 = NaN when there is nothing to score.
            return gt.new_zeros(())
        # Stacking keeps the result on the predictions' device instead of
        # accumulating into a CPU scalar.
        per_pred = [torch.mean(self.match_loss(pred, gt)) for pred in pred_list]
        return torch.stack(per_pred).mean()


# Disabled alternative for PolyMatchingLoss.forward, kept for reference:
# los = []
# for pred in pred_list:
#     los.append(self.match_loss(pred, gt))
#
# los_b = torch.tensor(0.)
# loss_c = torch.tensor(0.)
# for i, _ in enumerate(los):
#     los_b += torch.mean(los[i])
#     loss_c += (torch.mean(torch.clamp(los[i] - los[i - 1], min=0.0)) if i > 0 else torch.tensor(0.))
# loss = los_b / torch.tensor(len(los)) + 0.5*loss_c / torch.tensor(len(los)-1)
#
# return loss


class AttentionLoss(nn.Module):
    """Class-balanced cross entropy with a prediction-dependent weight.

    Args:
        beta: base of the exponential weighting term.
        gamma: exponent applied to the prediction inside the weighting term.
        eps: predictions are clamped to ``[eps, 1 - eps]`` so that
            ``log(0) * 0`` cannot produce NaN.
    """

    def __init__(self, beta=4, gamma=0.5, eps=1e-6):
        super(AttentionLoss, self).__init__()
        self.beta = beta
        self.gamma = gamma
        self.eps = eps

    def forward(self, pred, gt):
        """Return the mean weighted cross entropy between ``pred`` and ``gt``.

        Args:
            pred: predicted probabilities (expected in [0, 1]).
            gt: target map of the same shape; presumably binary - the code
                treats ``gt`` as the positive weight and ``1 - gt`` as the
                negative weight.
        """
        # Without the clamp, pred == 0 or pred == 1 gives -inf * 0 = NaN.
        pred = torch.clamp(pred, min=self.eps, max=1.0 - self.eps)

        num_pos = torch.sum(gt)
        num_neg = torch.sum(1 - gt)
        # Fraction of negatives; used to up-weight the positive class.
        alpha = num_neg / (num_pos + num_neg)

        edge_beta = torch.pow(self.beta, torch.pow(1 - pred, self.gamma))
        bg_beta = torch.pow(self.beta, torch.pow(pred, self.gamma))

        pos_term = alpha * edge_beta * torch.log(pred) * gt
        neg_term = (1 - alpha) * bg_beta * torch.log(1 - pred) * (1 - gt)
        return torch.mean(-pos_term - neg_term)


class GeoCrossEntropyLoss(nn.Module):
    """Cross entropy against a Gaussian-like soft target built from polygon points."""

    def __init__(self):
        super(GeoCrossEntropyLoss, self).__init__()

    def forward(self, output, target, poly):
        """Return the soft-target cross entropy.

        Args:
            output: logits; softmax is taken over dim 1. The final product
                implies a (batch, poly.size(1) // 4, 4) layout - TODO confirm.
            target: integer indices, expanded below to select one point per
                group of ``poly``; presumably (batch, 4) - TODO confirm.
            poly: point tensor reshaped to (batch, 4, poly.size(1) // 4, 2).
        """
        output = torch.nn.functional.softmax(output, dim=1)
        # Clamp before the log so vanishing probabilities stay finite.
        output = torch.log(torch.clamp(output, min=1e-4))

        poly = poly.view(poly.size(0), 4, poly.size(1) // 4, 2)
        target = target[..., None, None].expand(
            poly.size(0), poly.size(1), 1, poly.size(3))
        target_poly = torch.gather(poly, 2, target)

        # Squared distance between the first two points of each group sets
        # the kernel bandwidth.
        # NOTE(review): if those two points coincide, sigma is 0 and the
        # division below yields inf/NaN.
        sigma = (poly[:, :, 0] - poly[:, :, 1]).pow(2).sum(2, keepdim=True)
        kernel = torch.exp(-(poly - target_poly).pow(2).sum(3) / (sigma / 3))

        loss = -(output * kernel.transpose(2, 1)).sum(1).mean()
        return loss


class AELoss(nn.Module):
    """Associative-embedding pull/push loss."""

    def __init__(self):
        super(AELoss, self).__init__()

    def forward(self, ae, ind, ind_mask):
        """
        ae: [b, 1, h, w]
        ind: [b, max_objs, max_parts]
        ind_mask: [b, max_objs, max_parts]
        obj_mask: [b, max_objs]

        Returns:
            (pull, push): scalar tensors, each averaged over the batch.
        """
        b, _, h, w = ae.shape
        b, max_objs, max_parts = ind.shape
        # An object counts as present when at least one of its parts is valid.
        obj_mask = torch.sum(ind_mask, dim=2) != 0

        # Look up the embedding value at every part location.
        ae = ae.view(b, h * w, 1)
        seed_ind = ind.view(b, max_objs * max_parts, 1)
        tag = ae.gather(1, seed_ind).view(b, max_objs, max_parts)

        # Mean embedding per object over its valid parts.
        tag_mean = (tag * ind_mask).sum(2) / (ind_mask.sum(2) + 1e-4)

        # Pull embeddings of the same object towards their mean.
        pull_dist = (tag - tag_mean.unsqueeze(2)).pow(2) * ind_mask
        obj_num = obj_mask.sum(dim=1).float()
        pull = (pull_dist.sum(dim=(1, 2)) / (obj_num + 1e-4)).sum() / b

        # Push the means of different objects apart (margin 1).
        push_dist = torch.abs(tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2))
        push_dist = nn.functional.relu(1 - push_dist)
        # Logical AND selects pairs where both objects exist. Adding two bool
        # tensors and comparing with 2 never matches, because bool + bool
        # stays bool.
        pair_mask = obj_mask.unsqueeze(1) & obj_mask.unsqueeze(2)
        push_dist = push_dist * pair_mask.float()
        # Subtract obj_num to remove the diagonal (each object vs. itself).
        push = ((push_dist.sum(dim=(1, 2)) - obj_num)
                / (obj_num * (obj_num - 1) + 1e-4)).sum() / b
        return pull, push


def smooth_l1_loss(inputs, target, sigma=9.0):
    """Return the mean smooth-L1 loss with transition point ``1 / sigma``.

    Quadratic (``0.5 * sigma * d**2``) below the transition and linear
    (``d - 0.5 / sigma``) above it. Empty inputs yield zero. Any exception
    is reported on stdout and a zero loss is returned (deliberate
    best-effort behaviour).
    """
    try:
        diff = torch.abs(inputs - target)
        quadratic = (diff < 1.0 / sigma).float()
        loss = quadratic * 0.5 * diff ** 2 * sigma \
            + (1.0 - quadratic) * (diff - 0.5 / sigma)
        # sum() of an empty tensor is 0 on the input's device and dtype,
        # whereas mean() would be NaN.
        loss = loss.mean() if loss.numel() > 0 else loss.sum()
    except Exception as e:
        print('RPN_REGR_Loss Exception:', e)
        loss = torch.tensor(0.0)
    return loss


def _neg_loss(pred, gt, eps=1e-6):
    ''' Modified focal loss. Exactly the same as CornerNet.
        Runs faster and costs a little bit more memory
      Arguments:
        pred (batch x c x h x w)
        gt_regr (batch x c x h x w)
        eps: pred is clamped to [eps, 1 - eps] so that a saturated
             prediction cannot produce -inf * 0 = NaN
    '''
    pred = torch.clamp(pred, min=eps, max=1.0 - eps)

    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()

    # Negatives are down-weighted where gt is close to 1.
    neg_weights = torch.pow(1 - gt, 4)

    pos_loss = (torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds).sum()
    neg_loss = (torch.log(1 - pred) * torch.pow(pred, 2)
                * neg_weights * neg_inds).sum()

    num_pos = pos_inds.sum()
    if num_pos == 0:
        # No positives: nothing to normalise by, use the negative term only.
        return -neg_loss
    return -(pos_loss + neg_loss) / num_pos


class FocalLoss(nn.Module):
    '''nn.Module wrapper for focal loss'''

    def __init__(self):
        super(FocalLoss, self).__init__()
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        return self.neg_loss(out, target)