|
import torch as t |
|
import torch.nn as nn |
|
|
|
|
|
|
|
class SaliencyLoss(nn.Module):
    """Batch-averaged saliency metric, selected by name at call time.

    Each sample pair ``(preds[i], labels[i])`` is scored independently by
    one of the module-level metric functions and the per-sample scores
    are averaged over the batch.
    """

    def __init__(self):
        super(SaliencyLoss, self).__init__()

    def forward(self, preds, labels, loss_type='cc'):
        """Score every sample in the batch and return the mean.

        Args:
            preds: Predicted maps, indexed along the first dimension.
            labels: Ground-truth maps; ``labels.shape[0]`` determines how
                many samples are scored.
            loss_type: Which metric to apply: ``'cc'`` (loss_CC),
                ``'kldiv'`` (loss_KLdiv), ``'sim'`` (loss_similarity) or
                ``'nss'`` (loss_NSS).

        Returns:
            The per-sample scores stacked and averaged over dim 0 with
            ``keepdim=True``.

        Raises:
            ValueError: If ``loss_type`` is not one of the names above.
        """
        # Resolve the metric once, up front. An unrecognised name used to
        # fall through every branch and only fail later, inside t.stack on
        # an empty list, with an error that never mentioned loss_type.
        if loss_type == 'cc':
            metric = loss_CC
        elif loss_type == 'kldiv':
            metric = loss_KLdiv
        elif loss_type == 'sim':
            metric = loss_similarity
        elif loss_type == 'nss':
            metric = loss_NSS
        else:
            raise ValueError(
                "Unknown loss_type {!r}; expected one of "
                "'cc', 'kldiv', 'sim', 'nss'".format(loss_type)
            )

        # NOTE(review): the metric's raw value is returned as-is (no sign
        # flip here) -- confirm the caller negates similarity-style scores
        # if it minimises this output.
        losses = [
            metric(preds[i], labels[i]) for i in range(labels.shape[0])
        ]
        return t.stack(losses).mean(dim=0, keepdim=True)
|
|
|
|
|
def loss_KLdiv(pred_map, gt_map):
    """Return the KL divergence of ``gt_map`` from ``pred_map``.

    Each map is first divided by its own sum. A small constant is added
    to the denominator and to the argument of the logarithm.
    """
    eps = 2.2204e-16

    pred_dist = pred_map / pred_map.sum()
    gt_dist = gt_map / gt_map.sum()

    # eps appears twice: once so the ratio never divides by exactly zero,
    # once so the log never sees exactly zero.
    log_ratio = t.log(eps + gt_dist / (pred_dist + eps))
    return (gt_dist * log_ratio).sum()
|
|
|
|
|
def loss_CC(pred_map, gt_map):
    """Return the linear correlation coefficient between the two maps.

    Both maps are mean-centred; the result is the sum of their
    element-wise product divided by the square root of the product of
    their summed squares.
    """
    gt_centered = gt_map - gt_map.mean()
    pred_centered = pred_map - pred_map.mean()

    covariance = (gt_centered * pred_centered).sum()
    gt_energy = (gt_centered * gt_centered).sum()
    pred_energy = (pred_centered * pred_centered).sum()

    return covariance / t.sqrt(gt_energy * pred_energy)
|
|
|
|
|
def loss_similarity(pred_map, gt_map):
    """Return the histogram-intersection similarity of the two maps.

    Each map is min-max rescaled, then divided by its sum; the score is
    the sum of the element-wise minimum of the two results.
    """

    def _to_distribution(saliency):
        # Min-max rescale first, then normalise so the entries sum to one.
        rescaled = (saliency - t.min(saliency)) / (t.max(saliency) - t.min(saliency))
        return rescaled / t.sum(rescaled)

    gt_dist = _to_distribution(gt_map)
    pred_dist = _to_distribution(pred_map)

    overlap = t.min(gt_dist, pred_dist)
    return overlap.sum()
|
|
|
|
|
def loss_NSS(pred_map, fix_map):
    """Return the mean standardised prediction at fixated locations.

    The ground truth here is a fixation map: entries greater than zero
    mark the locations that contribute to the score.
    """
    standardized = (pred_map - pred_map.mean()) / pred_map.std()
    fixated = fix_map > 0
    # masked_select pulls out only the standardised values at fixations.
    return standardized.masked_select(fixated).mean()
|
|