|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.autograd import Variable |
|
from utils.box_utils import match, log_sum_exp |
|
from data import cfg_mnet |
|
GPU = cfg_mnet['gpu_train'] |
|
|
|
class MultiBoxLoss(nn.Module):
    """SSD weighted multi-task loss (localization + confidence + landmarks).

    Compute targets:
        1) Produce confidence target indices by matching ground-truth boxes
           with (default) 'priorboxes' whose jaccard index exceeds the
           threshold parameter (default threshold: 0.5).
        2) Produce localization targets by 'encoding' variance into offsets
           of ground-truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative
           examples that comes with using a large number of default bounding
           boxes (default negative:positive ratio 3:1).

    Objective loss:
        L(x,c,l,g) = (Lconf(x, c) + alpha * Lloc(x,l,g)) / N
    where Lconf is the cross-entropy loss and Lloc is the SmoothL1 loss,
    weighted by alpha, which is set to 1 by cross-validation.

    Args:
        c: class confidences
        l: predicted boxes
        g: ground truth boxes
        N: number of matched default boxes

    See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes              # number of confidence classes
        self.threshold = overlap_thresh             # IoU threshold for a positive match
        self.background_label = bkg_label           # index of the background class
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos                 # hard-negative : positive ratio
        self.neg_overlap = neg_overlap
        self.variance = [0.1, 0.2]                  # encoding variances for (cx,cy) and (w,h)

    def forward(self, predictions, priors, targets):
        """Multibox loss.

        Args:
            predictions (tuple): (loc_data, conf_data, landm_data) from the net.
                conf shape:  torch.Size(batch_size, num_priors, num_classes)
                loc shape:   torch.Size(batch_size, num_priors, 4)
                landm shape: torch.Size(batch_size, num_priors, 10)
            priors (tensor): prior boxes, shape torch.Size(num_priors, 4).
            targets (list[tensor]): ground-truth boxes/landmarks/labels per
                image; each tensor is [num_objs, 15] laid out as
                4 box coords, 10 landmark coords, 1 label.

        Returns:
            (loss_l, loss_c, loss_landm): localization, confidence, and
            landmark losses, each normalized by its positive count.
        """
        loc_data, conf_data, landm_data = predictions
        num = loc_data.size(0)
        num_priors = priors.size(0)

        # Match each prior to a ground-truth box for every image in the batch;
        # `match` fills loc_t / conf_t / landm_t in place at index `idx`.
        loc_t = torch.Tensor(num, num_priors, 4)
        landm_t = torch.Tensor(num, num_priors, 10)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :4].data
            labels = targets[idx][:, -1].data
            landms = targets[idx][:, 4:14].data
            defaults = priors.data
            match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
        if GPU:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
            landm_t = landm_t.cuda()

        # Landmark loss: only priors whose matched label is > 0 (faces with
        # annotated landmarks; label -1 marks faces without landmarks).
        # FIX: previously this compared against torch.tensor(0).cuda(), which
        # crashed CPU-only training even when GPU (cfg 'gpu_train') was False.
        # Comparing against a plain 0 works on any device.
        pos1 = conf_t > 0
        num_pos_landm = pos1.long().sum(1, keepdim=True)
        N1 = max(num_pos_landm.data.sum().float(), 1)
        pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
        landm_p = landm_data[pos_idx1].view(-1, 10)
        landm_t = landm_t[pos_idx1].view(-1, 10)
        loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')

        # Localization loss: all matched priors (label != 0, including -1).
        pos = conf_t != 0
        conf_t[pos] = 1  # collapse every matched label to the single "face" class

        # Smooth L1 over the positive priors' encoded offsets.
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # Per-prior softmax loss used only to rank candidates for mining.
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Hard negative mining: zero out positives, then keep the
        # negpos_ratio * num_pos highest-loss negatives per image.
        loss_c[pos.view(-1, 1)] = 0  # positives are never selected as hard negatives
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)  # rank of each prior in the sorted loss order
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence loss over positives + mined hard negatives.
        # FIX: use logical-or (|) instead of mask addition + .gt(0); adding
        # boolean masks is deprecated/erroring on modern PyTorch, while `|`
        # is equivalent for both bool and legacy uint8 masks.
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[pos_idx | neg_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[pos | neg]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Normalize: box/conf losses by #positives, landmark loss by
        # #landmark-positives; max(..., 1) guards against division by zero.
        N = max(num_pos.data.sum().float(), 1)
        loss_l /= N
        loss_c /= N
        loss_landm /= N1

        return loss_l, loss_c, loss_landm
|
|