# Source: shreyasvaidya — uploaded via huggingface_hub (revision 01bb3bb, verified)
# -*- coding: utf-8 -*-
# @Time : 10/1/21
# @Author : GXYM
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
class PolyMatchingLoss(nn.Module):
    """Rotation-invariant polygon matching loss.

    Each predicted polygon is compared against all `pnum` cyclic rotations of
    the ground-truth vertex order, and the minimum distance is used — so the
    loss does not depend on which vertex the ground-truth contour starts from.
    """

    def __init__(self, pnum, device, loss_type="L1"):
        """
        pnum: number of polygon vertices.
        device: device for the precomputed rotation index table.
        loss_type: "L1" (smooth L1) or "L2" (per-point Euclidean distance).
        """
        super(PolyMatchingLoss, self).__init__()
        self.pnum = pnum
        self.device = device
        self.loss_type = loss_type
        self.smooth_L1 = F.smooth_l1_loss
        # 'reduce=False, size_average=False' is the long-deprecated spelling of
        # reduction='none' (element-wise losses, no reduction applied).
        self.L2_loss = torch.nn.MSELoss(reduction='none')

        # Precompute the cyclic-rotation index table once:
        # row i holds (i, i+1, ..., i+pnum-1) mod pnum.
        batch_size = 1
        pidxall = np.zeros(shape=(batch_size, pnum, pnum), dtype=np.int32)
        for b in range(batch_size):
            for i in range(pnum):
                pidxall[b, i] = (np.arange(pnum) + i) % pnum
        pidxall = torch.from_numpy(np.reshape(pidxall, newshape=(batch_size, -1))).to(device)
        # Expand to (1, pnum*pnum, 2) so one gather pulls both (x, y) coords.
        self.feature_id = pidxall.unsqueeze_(2).long().expand(
            pidxall.size(0), pidxall.size(1), 2).detach()

    def match_loss(self, pred, gt):
        """Minimum distance between `pred` and every cyclic rotation of `gt`.

        pred, gt: (batch, pnum, 2) vertex tensors — assumed; confirm with caller.
        Returns a (batch, 1) tensor of the minimum mean per-vertex distance.
        """
        batch_size = pred.shape[0]
        feature_id = self.feature_id.expand(batch_size, self.feature_id.size(1), 2)
        # (batch, pnum rotations, pnum vertices, 2)
        gt_expand = torch.gather(gt, 1, feature_id).view(batch_size, self.pnum, self.pnum, 2)
        pred_expand = pred.unsqueeze(1)
        if self.loss_type == "L2":
            dis = self.L2_loss(pred_expand, gt_expand)
            # per-rotation mean Euclidean point distance
            dis = dis.sum(3).sqrt().mean(2)
        elif self.loss_type == "L1":
            dis = self.smooth_L1(pred_expand, gt_expand, reduction='none')
            dis = dis.sum(3).mean(2)
        min_dis, _ = torch.min(dis, dim=1, keepdim=True)
        return min_dis

    def forward(self, pred_list, gt):
        """Average match loss of every prediction in `pred_list` against `gt`."""
        # Accumulate with Python's sum() so the running total lives on the same
        # device as the per-prediction losses; the previous CPU-allocated
        # torch.tensor(0.) accumulator would fail for CUDA inputs.
        total = sum(torch.mean(self.match_loss(pred, gt)) for pred in pred_list)
        return total / len(pred_list)
class AttentionLoss(nn.Module):
    """Class-balanced, beta-modulated binary cross entropy for attention maps."""

    def __init__(self, beta=4, gamma=0.5):
        super(AttentionLoss, self).__init__()
        self.beta = beta
        self.gamma = gamma

    def forward(self, pred, gt):
        """pred: per-pixel probabilities in (0, 1); gt: binary target map."""
        pos = torch.sum(gt)
        neg = torch.sum(1 - gt)
        # balance factor: fraction of negative pixels weights the positive term
        alpha = neg / (pos + neg)
        # confidence-dependent modulation: beta ** ((1 - p) ** gamma) on edges,
        # beta ** (p ** gamma) on background
        w_edge = torch.pow(self.beta, torch.pow(1 - pred, self.gamma))
        w_bg = torch.pow(self.beta, torch.pow(pred, self.gamma))
        term_pos = alpha * w_edge * torch.log(pred) * gt
        term_neg = (1 - alpha) * w_bg * torch.log(1 - pred) * (1 - gt)
        return torch.mean(-(term_pos + term_neg))
class GeoCrossEntropyLoss(nn.Module):
    """Cross entropy over polygon-vertex logits with a geometry-aware soft
    target: each candidate point is weighted by a Gaussian kernel on its
    distance to the gathered target point."""

    def __init__(self):
        super(GeoCrossEntropyLoss, self).__init__()

    def forward(self, output, target, poly):
        # output: logits, softmaxed over dim 1 — assumed (batch, candidates, ...);
        # TODO confirm the exact layout against the caller.
        output = torch.nn.functional.softmax(output, dim=1)
        # clamp before log to avoid log(0) = -inf
        output = torch.log(torch.clamp(output, min=1e-4))
        # reshape flat coordinates into 4 groups (presumably 4 polygon sides):
        # (batch, 4, points_per_group, 2)
        poly = poly.view(poly.size(0), 4, poly.size(1) // 4, 2)
        # broadcast target indices so gather selects one (x, y) point per group
        target = target[..., None, None].expand(poly.size(0), poly.size(1), 1, poly.size(3))
        target_poly = torch.gather(poly, 2, target)
        # kernel bandwidth from the squared distance between the first two
        # points of each group
        sigma = (poly[:, :, 0] - poly[:, :, 1]).pow(2).sum(2, keepdim=True)
        # Gaussian soft-assignment of every point to the target point
        kernel = torch.exp(-(poly - target_poly).pow(2).sum(3) / (sigma / 3))
        # negative log-likelihood weighted by the soft assignment
        loss = -(output * kernel.transpose(2, 1)).sum(1).mean()
        return loss
class AELoss(nn.Module):
def __init__(self):
super(AELoss, self).__init__()
def forward(self, ae, ind, ind_mask):
"""
ae: [b, 1, h, w]
ind: [b, max_objs, max_parts]
ind_mask: [b, max_objs, max_parts]
obj_mask: [b, max_objs]
"""
# first index
b, _, h, w = ae.shape
b, max_objs, max_parts = ind.shape
obj_mask = torch.sum(ind_mask, dim=2) != 0
ae = ae.view(b, h * w, 1)
seed_ind = ind.view(b, max_objs * max_parts, 1)
tag = ae.gather(1, seed_ind).view(b, max_objs, max_parts)
# compute the mean
tag_mean = tag * ind_mask
tag_mean = tag_mean.sum(2) / (ind_mask.sum(2) + 1e-4)
# pull ae of the same object to their mean
pull_dist = (tag - tag_mean.unsqueeze(2)).pow(2) * ind_mask
obj_num = obj_mask.sum(dim=1).float()
pull = (pull_dist.sum(dim=(1, 2)) / (obj_num + 1e-4)).sum()
pull /= b
# push away the mean of different objects
push_dist = torch.abs(tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2))
push_dist = 1 - push_dist
push_dist = nn.functional.relu(push_dist, inplace=True)
obj_mask = (obj_mask.unsqueeze(1) + obj_mask.unsqueeze(2)) == 2
push_dist = push_dist * obj_mask.float()
push = ((push_dist.sum(dim=(1, 2)) - obj_num) / (obj_num * (obj_num - 1) + 1e-4)).sum()
push /= b
return pull, push
def smooth_l1_loss(inputs, target, sigma=9.0):
    """Smooth-L1 (Huber-style) regression loss with transition point 1/sigma.

    For |d| < 1/sigma the loss is quadratic (0.5 * sigma * d^2); otherwise it
    is linear (|d| - 0.5/sigma). Returns the scalar mean, or 0.0 for empty
    input. Best-effort: any exception is logged and 0.0 is returned so a bad
    batch does not kill training.
    """
    try:
        diff = torch.abs(inputs - target)
        quadratic = (diff < 1.0 / sigma).float()  # {0,1} branch mask
        # (1.0 - quadratic) replaces the original torch.abs(torch.tensor(1.0)
        # - quadratic): the mask is already in {0, 1}, so the extra tensor
        # allocation and abs were redundant.
        loss = quadratic * 0.5 * diff ** 2 * sigma \
            + (1.0 - quadratic) * (diff - 0.5 / sigma)
        loss = torch.mean(loss) if loss.numel() > 0 else torch.tensor(0.0)
    except Exception as e:
        # deliberate best-effort fallback kept from the original code
        print('RPN_REGR_Loss Exception:', e)
        loss = torch.tensor(0.0)
    return loss
def _neg_loss(pred, gt):
''' Modified focal loss. Exactly the same as CornerNet.
Runs faster and costs a little bit more memory
Arguments:
pred (batch x c x h x w)
gt_regr (batch x c x h x w)
'''
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
neg_weights = torch.pow(1 - gt, 4)
loss = 0
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
num_pos = pos_inds.float().sum()
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
if num_pos == 0:
loss = loss - neg_loss
else:
loss = loss - (pos_loss + neg_loss) / num_pos
return loss
class FocalLoss(nn.Module):
    """nn.Module wrapper around the CornerNet-style modified focal loss."""

    def __init__(self):
        super(FocalLoss, self).__init__()
        # keep the functional implementation as an attribute, as callers may
        # reference it directly
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        """Delegate to the stored focal-loss function."""
        return self.neg_loss(out, target)