import logging
import math
import json
import numpy as np
from typing import Dict, Union
import torch
from fvcore.nn import giou_loss, smooth_l1_loss
from torch import nn
from torch.nn import functional as F
import fvcore.nn.weight_init as weight_init
import detectron2.utils.comm as comm
from detectron2.config import configurable
from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple
from detectron2.structures import Boxes, Instances
from detectron2.utils.events import get_event_storage
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
from detectron2.modeling.roi_heads.fast_rcnn import _log_classification_stats

from torch.cuda.amp import autocast
from ..utils import load_class_freq, get_fed_loss_inds
from .zero_shot_classifier import ZeroShotClassifier

__all__ = ["DeticFastRCNNOutputLayers"]


class DeticFastRCNNOutputLayers(FastRCNNOutputLayers):
    @configurable
    def __init__(
        self,
        input_shape: ShapeSpec,
        *,
        mult_proposal_score=False,
        cls_score=None,
        sync_caption_batch=False,
        use_sigmoid_ce=False,
        use_fed_loss=False,
        ignore_zero_cats=False,
        fed_loss_num_cat=50,
        dynamic_classifier=False,
        image_label_loss='',
        use_zeroshot_cls=False,
        image_loss_weight=0.1,
        with_softmax_prop=False,
        caption_weight=1.0,
        neg_cap_weight=1.0,
        add_image_box=False,
        debug=False,
        prior_prob=0.01,
        cat_freq_path='',
        fed_loss_freq_weight=0.5,
        softmax_weak_loss=False,
        **kwargs,
    ):
        super().__init__(
            input_shape=input_shape,
            **kwargs,
        )
        self.mult_proposal_score = mult_proposal_score
        self.sync_caption_batch = sync_caption_batch
        self.use_sigmoid_ce = use_sigmoid_ce
        self.use_fed_loss = use_fed_loss
        self.ignore_zero_cats = ignore_zero_cats
        self.fed_loss_num_cat = fed_loss_num_cat
        self.dynamic_classifier = dynamic_classifier
        self.image_label_loss = image_label_loss
        self.use_zeroshot_cls = use_zeroshot_cls
        self.image_loss_weight = image_loss_weight
        self.with_softmax_prop = with_softmax_prop
        self.caption_weight = caption_weight
        self.neg_cap_weight = neg_cap_weight
        self.add_image_box = add_image_box
        self.softmax_weak_loss = softmax_weak_loss
        self.debug = debug

        if softmax_weak_loss:
            assert image_label_loss in ['max_size']

        if self.use_sigmoid_ce:
            bias_value = -math.log((1 - prior_prob) / prior_prob)
            nn.init.constant_(self.cls_score.bias, bias_value)
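        # Note on the bias init above: since sigmoid(-log((1 - p) / p)) == p, setting the
        # classifier bias this way makes every class start with a predicted probability of
        # roughly `prior_prob` (assuming near-zero initial weights), which keeps the sigmoid
        # cross-entropy loss from blowing up on the many negative classes early in training.
        # A quick sanity check of the identity:
        #   p = 0.01; b = -math.log((1 - p) / p); abs(torch.sigmoid(torch.tensor(b)) - p) < 1e-6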

        if self.use_fed_loss or self.ignore_zero_cats:
            freq_weight = load_class_freq(cat_freq_path, fed_loss_freq_weight)
            self.register_buffer('freq_weight', freq_weight)
        else:
            self.freq_weight = None

        if self.use_fed_loss and len(self.freq_weight) < self.num_classes:
            print('Extending federated loss weight')
            self.freq_weight = torch.cat(
                [self.freq_weight,
                 self.freq_weight.new_zeros(
                     self.num_classes - len(self.freq_weight))]
            )

        assert (not self.dynamic_classifier) or (not self.use_fed_loss)
        input_size = input_shape.channels * \
            (input_shape.width or 1) * (input_shape.height or 1)

        if self.use_zeroshot_cls:
            del self.cls_score
            del self.bbox_pred
            assert cls_score is not None
            self.cls_score = cls_score
            self.bbox_pred = nn.Sequential(
                nn.Linear(input_size, input_size),
                nn.ReLU(inplace=True),
                nn.Linear(input_size, 4)
            )
            weight_init.c2_xavier_fill(self.bbox_pred[0])
            nn.init.normal_(self.bbox_pred[-1].weight, std=0.001)
            nn.init.constant_(self.bbox_pred[-1].bias, 0)
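        # In the zero-shot branch above, the base head's linear `cls_score` is replaced by
        # the externally built classifier (a ZeroShotClassifier over text embeddings, see
        # `from_config`), and `bbox_pred` becomes a small class-agnostic MLP that outputs a
        # single set of 4 box deltas shared by all classes, rather than per-class deltas.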

        if self.with_softmax_prop:
            self.prop_score = nn.Sequential(
                nn.Linear(input_size, input_size),
                nn.ReLU(inplace=True),
                nn.Linear(input_size, self.num_classes + 1),
            )
            weight_init.c2_xavier_fill(self.prop_score[0])
            nn.init.normal_(self.prop_score[-1].weight, mean=0, std=0.001)
            nn.init.constant_(self.prop_score[-1].bias, 0)

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = super().from_config(cfg, input_shape)
        ret.update({
            'mult_proposal_score': cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE,
            'sync_caption_batch': cfg.MODEL.SYNC_CAPTION_BATCH,
            'use_sigmoid_ce': cfg.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE,
            'use_fed_loss': cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS,
            'ignore_zero_cats': cfg.MODEL.ROI_BOX_HEAD.IGNORE_ZERO_CATS,
            'fed_loss_num_cat': cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT,
            'dynamic_classifier': cfg.MODEL.DYNAMIC_CLASSIFIER,
            'image_label_loss': cfg.MODEL.ROI_BOX_HEAD.IMAGE_LABEL_LOSS,
            'use_zeroshot_cls': cfg.MODEL.ROI_BOX_HEAD.USE_ZEROSHOT_CLS,
            'image_loss_weight': cfg.MODEL.ROI_BOX_HEAD.IMAGE_LOSS_WEIGHT,
            'with_softmax_prop': cfg.MODEL.ROI_BOX_HEAD.WITH_SOFTMAX_PROP,
            'caption_weight': cfg.MODEL.ROI_BOX_HEAD.CAPTION_WEIGHT,
            'neg_cap_weight': cfg.MODEL.ROI_BOX_HEAD.NEG_CAP_WEIGHT,
            'add_image_box': cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX,
            'debug': cfg.DEBUG or cfg.SAVE_DEBUG or cfg.IS_DEBUG,
            'prior_prob': cfg.MODEL.ROI_BOX_HEAD.PRIOR_PROB,
            'cat_freq_path': cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH,
            'fed_loss_freq_weight': cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT,
            'softmax_weak_loss': cfg.MODEL.ROI_BOX_HEAD.SOFTMAX_WEAK_LOSS,
        })
        if ret['use_zeroshot_cls']:
            ret['cls_score'] = ZeroShotClassifier(cfg, input_shape)
        return ret
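
    # Because `__init__` is decorated with `@configurable`, the head is normally built as
    # `DeticFastRCNNOutputLayers(cfg, input_shape)`, which routes through `from_config`
    # above. A minimal sketch (the channel size below is illustrative, not prescriptive):
    #
    #   head = DeticFastRCNNOutputLayers(cfg, ShapeSpec(channels=1024, height=1, width=1))
    #
    # `cls_score` is created here rather than defaulted in `__init__` because the
    # ZeroShotClassifier itself needs `cfg` and `input_shape` to be constructed.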

    def losses(self, predictions, proposals,
               use_advanced_loss=True,
               classifier_info=(None,None,None)):
        """
        Compute the box-supervised losses (classification + box regression).
        Compared to the base class, this supports sigmoid cross-entropy,
        federated loss, and a dynamically sampled classifier;
        `use_advanced_loss` is accepted but not used here.
        """
        scores, proposal_deltas = predictions
        gt_classes = (
            cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0)
        )
        num_classes = self.num_classes
        if self.dynamic_classifier:
            _, cls_id_map = classifier_info[1]
            gt_classes = cls_id_map[gt_classes]
            num_classes = scores.shape[1] - 1
            assert cls_id_map[self.num_classes] == num_classes
        _log_classification_stats(scores, gt_classes)

        if len(proposals):
            proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)
            assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
            gt_boxes = cat(
                [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals],
                dim=0,
            )
        else:
            proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)

        if self.use_sigmoid_ce:
            loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes)
        else:
            loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes)
        return {
            "loss_cls": loss_cls,
            "loss_box_reg": self.box_reg_loss(
                proposal_boxes, gt_boxes, proposal_deltas, gt_classes,
                num_classes=num_classes)
        }
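
    # When `dynamic_classifier` is on, `classifier_info[1]` carries the sampled class ids
    # plus an id remapping table; ground-truth classes are remapped into the sampled
    # classifier's index space, and the assert checks that the original background index
    # (`self.num_classes`) lands on the new last column.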

    def sigmoid_cross_entropy_loss(self, pred_class_logits, gt_classes):
        if pred_class_logits.numel() == 0:
            return pred_class_logits.new_zeros([1])[0]

        B = pred_class_logits.shape[0]
        C = pred_class_logits.shape[1] - 1

        target = pred_class_logits.new_zeros(B, C + 1)
        target[range(len(gt_classes)), gt_classes] = 1
        target = target[:, :C]

        weight = 1

        if self.use_fed_loss and (self.freq_weight is not None):
            appeared = get_fed_loss_inds(
                gt_classes,
                num_sample_cats=self.fed_loss_num_cat,
                C=C,
                weight=self.freq_weight)
            appeared_mask = appeared.new_zeros(C + 1)
            appeared_mask[appeared] = 1
            appeared_mask = appeared_mask[:C]
            fed_w = appeared_mask.view(1, C).expand(B, C)
            weight = weight * fed_w.float()
        if self.ignore_zero_cats and (self.freq_weight is not None):
            w = (self.freq_weight.view(-1) > 1e-4).float()
            weight = weight * w.view(1, C).expand(B, C)

        cls_loss = F.binary_cross_entropy_with_logits(
            pred_class_logits[:, :-1], target, reduction='none')
        loss = torch.sum(cls_loss * weight) / B
        return loss
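
    # Target construction above: targets get C + 1 columns so that background proposals
    # (gt class == C) can be one-hot encoded too, then the background column is dropped.
    # The effect is that a background proposal contributes only negative supervision for
    # the C foreground classes, e.g. with C = 3 and gt_classes = [1, 3]:
    #   target before slicing -> [[0, 1, 0, 0], [0, 0, 0, 1]]
    #   target[:, :C]         -> [[0, 1, 0],    [0, 0, 0]]
    # The loss is summed over classes and averaged over the number of proposals B.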

    def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes):
        """
        Different from the base implementation: returns a zero loss when there
        are no predictions (instead of the default no-instance handling), and
        supports `ignore_zero_cats` and federated-loss class weighting.
        """
        if pred_class_logits.numel() == 0:
            return pred_class_logits.new_zeros([1])[0]

        if self.ignore_zero_cats and (self.freq_weight is not None):
            zero_weight = torch.cat([
                (self.freq_weight.view(-1) > 1e-4).float(),
                self.freq_weight.new_ones(1)])  # keep the background class
            loss = F.cross_entropy(
                pred_class_logits, gt_classes,
                weight=zero_weight, reduction="mean")
        elif self.use_fed_loss and (self.freq_weight is not None):
            C = pred_class_logits.shape[1] - 1
            appeared = get_fed_loss_inds(
                gt_classes,
                num_sample_cats=self.fed_loss_num_cat,
                C=C,
                weight=self.freq_weight)
            appeared_mask = appeared.new_zeros(C + 1).float()
            appeared_mask[appeared] = 1.
            appeared_mask[C] = 1.
            loss = F.cross_entropy(
                pred_class_logits, gt_classes,
                weight=appeared_mask, reduction="mean")
        else:
            loss = F.cross_entropy(
                pred_class_logits, gt_classes, reduction="mean")
        return loss
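
    # Note: the `weight` argument of F.cross_entropy is a per-class weight applied to each
    # sample according to its ground-truth class. So in the branches above, proposals whose
    # gt class has weight 0 (never-seen categories, or categories not sampled for federated
    # loss) contribute nothing to the weighted mean, while background (index C) always counts.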

    def box_reg_loss(
        self, proposal_boxes, gt_boxes, pred_deltas, gt_classes,
        num_classes=-1):
        """
        Same as the base implementation, but the background index can be
        overridden via `num_classes` (needed with a dynamically sampled classifier).
        """
        num_classes = num_classes if num_classes > 0 else self.num_classes
        box_dim = proposal_boxes.shape[1]
        fg_inds = nonzero_tuple((gt_classes >= 0) & (gt_classes < num_classes))[0]
        if pred_deltas.shape[1] == box_dim:  # class-agnostic regression
            fg_pred_deltas = pred_deltas[fg_inds]
        else:
            fg_pred_deltas = pred_deltas.view(-1, self.num_classes, box_dim)[
                fg_inds, gt_classes[fg_inds]
            ]

        if self.box_reg_loss_type == "smooth_l1":
            gt_pred_deltas = self.box2box_transform.get_deltas(
                proposal_boxes[fg_inds],
                gt_boxes[fg_inds],
            )
            loss_box_reg = smooth_l1_loss(
                fg_pred_deltas, gt_pred_deltas, self.smooth_l1_beta, reduction="sum"
            )
        elif self.box_reg_loss_type == "giou":
            fg_pred_boxes = self.box2box_transform.apply_deltas(
                fg_pred_deltas, proposal_boxes[fg_inds]
            )
            loss_box_reg = giou_loss(fg_pred_boxes, gt_boxes[fg_inds], reduction="sum")
        else:
            raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")
        return loss_box_reg / max(gt_classes.numel(), 1.0)
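
    # The regression loss is summed over foreground proposals only but normalized by the
    # total number of sampled proposals (`gt_classes.numel()`), following the detectron2
    # convention; `max(..., 1.0)` guards against division by zero when there are none.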

    def inference(self, predictions, proposals):
        """
        Same as the base implementation, except that classification scores can
        optionally be fused with the proposal scores (`mult_proposal_score`).
        """
        predictions = (predictions[0], predictions[1])  # keep (scores, deltas); drop prop_score if present
        boxes = self.predict_boxes(predictions, proposals)
        scores = self.predict_probs(predictions, proposals)
        if self.mult_proposal_score:
            proposal_scores = [p.get('objectness_logits') for p in proposals]
            scores = [(s * ps[:, None]) ** 0.5
                      for s, ps in zip(scores, proposal_scores)]
        image_shapes = [x.image_size for x in proposals]
        return fast_rcnn_inference(
            boxes,
            scores,
            image_shapes,
            self.test_score_thresh,
            self.test_nms_thresh,
            self.test_topk_per_image,
        )
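
    # With `mult_proposal_score`, the final score is the elementwise geometric mean
    # sqrt(classification score * proposal score). This assumes the proposals' stored
    # `objectness_logits` already hold scores in [0, 1] (as when the proposal generator,
    # e.g. CenterNet2, writes probabilities there); otherwise the product is not a
    # meaningful probability.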

    def predict_probs(self, predictions, proposals):
        """
        Per-class probabilities; uses sigmoid instead of softmax when the head
        is trained with sigmoid cross-entropy.
        """
        scores = predictions[0]
        num_inst_per_image = [len(p) for p in proposals]
        if self.use_sigmoid_ce:
            probs = scores.sigmoid()
        else:
            probs = F.softmax(scores, dim=-1)
        return probs.split(num_inst_per_image, dim=0)
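
    # Unlike softmax, the sigmoid branch treats classes independently, so per-class scores
    # do not sum to one across classes; downstream `fast_rcnn_inference` drops the last
    # (background) column either way, so both branches share the same post-processing.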

    def image_label_losses(self, predictions, proposals, image_labels,
                           classifier_info=(None,None,None), ann_type='image'):
        '''
        Weak supervision from image-level labels (and/or captions).
        Inputs:
            predictions[0]: per-proposal classification scores, N x (C + 1)
            image_labels: a list of B per-image label lists
        '''
        num_inst_per_image = [len(p) for p in proposals]
        scores = predictions[0]
        scores = scores.split(num_inst_per_image, dim=0)
        if self.with_softmax_prop:
            prop_scores = predictions[2].split(num_inst_per_image, dim=0)
        else:
            prop_scores = [None for _ in num_inst_per_image]
        B = len(scores)
        img_box_count = 0
        select_size_count = 0
        select_x_count = 0
        select_y_count = 0
        max_score_count = 0
        storage = get_event_storage()
        loss = scores[0].new_zeros([1])[0]
        caption_loss = scores[0].new_zeros([1])[0]
        for idx, (score, labels, prop_score, p) in enumerate(zip(
                scores, image_labels, prop_scores, proposals)):
            if score.shape[0] == 0:
                loss += score.new_zeros([1])[0]
                continue
            if 'caption' in ann_type:
                score, caption_loss_img = self._caption_loss(
                    score, classifier_info, idx, B)
                caption_loss += self.caption_weight * caption_loss_img
                if ann_type == 'caption':
                    continue

            if self.debug:
                p.selected = score.new_zeros(
                    (len(p),), dtype=torch.long) - 1
            for i_l, label in enumerate(labels):
                if self.dynamic_classifier:
                    if idx == 0 and i_l == 0 and comm.is_main_process():
                        storage.put_scalar('stats_label', label)
                    label = classifier_info[1][1][label]
                    assert label < score.shape[1]
                if self.image_label_loss in ['wsod', 'wsddn']:
                    loss_i, ind = self._wsddn_loss(score, prop_score, label)
                elif self.image_label_loss == 'max_score':
                    loss_i, ind = self._max_score_loss(score, label)
                elif self.image_label_loss == 'max_size':
                    loss_i, ind = self._max_size_loss(score, label, p)
                elif self.image_label_loss == 'first':
                    loss_i, ind = self._first_loss(score, label)
                elif self.image_label_loss == 'image':
                    loss_i, ind = self._image_loss(score, label)
                elif self.image_label_loss == 'min_loss':
                    loss_i, ind = self._min_loss_loss(score, label)
                else:
                    assert 0, 'Unsupported image_label_loss: {}'.format(
                        self.image_label_loss)
                loss += loss_i / len(labels)
                if isinstance(ind, list):
                    img_box_count = sum(ind) / len(ind)
                    if self.debug:
                        for ind_i in ind:
                            p.selected[ind_i] = label
                else:
                    img_box_count = ind
                    select_size_count = p[ind].proposal_boxes.area() / \
                        (p.image_size[0] * p.image_size[1])
                    max_score_count = score[ind, label].sigmoid()
                    select_x_count = (p.proposal_boxes.tensor[ind, 0] +
                                      p.proposal_boxes.tensor[ind, 2]) / 2 / p.image_size[1]
                    select_y_count = (p.proposal_boxes.tensor[ind, 1] +
                                      p.proposal_boxes.tensor[ind, 3]) / 2 / p.image_size[0]
                    if self.debug:
                        p.selected[ind] = label

        loss = loss / B
        storage.put_scalar('stats_l_image', loss.item())
        if 'caption' in ann_type:
            caption_loss = caption_loss / B
            loss = loss + caption_loss
            storage.put_scalar('stats_l_caption', caption_loss.item())
        if comm.is_main_process():
            storage.put_scalar('pool_stats', img_box_count)
            storage.put_scalar('stats_select_size', select_size_count)
            storage.put_scalar('stats_select_x', select_x_count)
            storage.put_scalar('stats_select_y', select_y_count)
            storage.put_scalar('stats_max_label_score', max_score_count)

        return {
            'image_loss': loss * self.image_loss_weight,
            'loss_cls': score.new_zeros([1])[0],
            'loss_box_reg': score.new_zeros([1])[0]}
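
    # Images in this branch carry no box supervision, so 'loss_cls' and 'loss_box_reg' are
    # returned as zeros; keeping the same keys as `losses` lets the trainer sum the loss
    # dict uniformly regardless of which annotation type a given batch carries.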

    def forward(self, x, classifier_info=(None,None,None)):
        """
        Same as the base forward, but an external classifier (and caption
        embeddings) can be passed in through `classifier_info`.
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        scores = []

        if classifier_info[0] is not None:
            cls_scores = self.cls_score(x, classifier=classifier_info[0])
            scores.append(cls_scores)
        else:
            cls_scores = self.cls_score(x)
            scores.append(cls_scores)

        if classifier_info[2] is not None:
            cap_cls = classifier_info[2]
            if self.sync_caption_batch:
                caption_scores = self.cls_score(x, classifier=cap_cls[:, :-1])
            else:
                caption_scores = self.cls_score(x, classifier=cap_cls)
            scores.append(caption_scores)
        scores = torch.cat(scores, dim=1)

        proposal_deltas = self.bbox_pred(x)
        if self.with_softmax_prop:
            prop_score = self.prop_score(x)
            return scores, proposal_deltas, prop_score
        else:
            return scores, proposal_deltas
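
    # Output shapes: `scores` is N x (C + 1), with extra caption columns appended when
    # caption embeddings are given in `classifier_info[2]`; `proposal_deltas` is N x 4 for
    # the class-agnostic zero-shot regressor (otherwise as in the base head). When
    # `sync_caption_batch` is on, the last column of the caption embeddings stores the
    # source rank and is stripped before scoring (it is checked again in `_caption_loss`).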

    def _caption_loss(self, score, classifier_info, idx, B):
        assert (classifier_info[2] is not None)
        assert self.add_image_box
        cls_and_cap_num = score.shape[1]
        cap_num = classifier_info[2].shape[0]
        score, caption_score = score.split(
            [cls_and_cap_num - cap_num, cap_num], dim=1)

        caption_score = caption_score[-1:]
        caption_target = caption_score.new_zeros(
            caption_score.shape)
        if self.sync_caption_batch:
            rank = comm.get_rank()
            global_idx = B * rank + idx
            assert (classifier_info[2][
                global_idx, -1] - rank) ** 2 < 1e-8, \
                '{} {} {} {} {}'.format(
                    rank, global_idx,
                    classifier_info[2][global_idx, -1],
                    classifier_info[2].shape,
                    classifier_info[2][:, -1])
            caption_target[:, global_idx] = 1.
        else:
            assert caption_score.shape[1] == B
            caption_target[:, idx] = 1.
        caption_loss_img = F.binary_cross_entropy_with_logits(
            caption_score, caption_target, reduction='none')
        if self.sync_caption_batch:
            fg_mask = (caption_target > 0.5).float()
            assert (fg_mask.sum().item() - 1.) ** 2 < 1e-8, '{} {}'.format(
                fg_mask.shape, fg_mask)
            pos_loss = (caption_loss_img * fg_mask).sum()
            neg_loss = (caption_loss_img * (1. - fg_mask)).sum()
            caption_loss_img = pos_loss + self.neg_cap_weight * neg_loss
        else:
            caption_loss_img = caption_loss_img.sum()
        return score, caption_loss_img
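
    # Caption supervision is contrastive: `caption_score[-1:]` keeps only the image-level
    # box (hence the `add_image_box` assert), whose target is 1 for the caption embedding
    # that came from this image and 0 for every other caption in the (possibly cross-GPU)
    # batch. With `sync_caption_batch`, captions are gathered from all ranks, the positive
    # column index becomes `B * rank + idx`, and negatives are scaled by `neg_cap_weight`.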

    def _wsddn_loss(self, score, prop_score, label):
        assert prop_score is not None
        loss = 0
        final_score = score.sigmoid() * \
            F.softmax(prop_score, dim=0)
        img_score = torch.clamp(
            torch.sum(final_score, dim=0),
            min=1e-10, max=1-1e-10)
        target = img_score.new_zeros(img_score.shape)
        target[label] = 1.
        loss += F.binary_cross_entropy(img_score, target)
        ind = final_score[:, label].argmax()
        return loss, ind
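
    # WSDDN-style weak supervision: a detection branch (sigmoid over classes) is multiplied
    # with a proposal-selection branch (softmax over proposals); summing over proposals
    # gives an image-level score trained with BCE against the image label, and the proposal
    # with the highest combined score for that label is reported as the selected box.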

    def _max_score_loss(self, score, label):
        loss = 0
        target = score.new_zeros(score.shape[1])
        target[label] = 1.
        ind = score[:, label].argmax().item()
        loss += F.binary_cross_entropy_with_logits(
            score[ind], target, reduction='sum')
        return loss, ind

    def _min_loss_loss(self, score, label):
        loss = 0
        target = score.new_zeros(score.shape)
        target[:, label] = 1.
        with torch.no_grad():
            x = F.binary_cross_entropy_with_logits(
                score, target, reduction='none').sum(dim=1)
        ind = x.argmin().item()
        loss += F.binary_cross_entropy_with_logits(
            score[ind], target[0], reduction='sum')
        return loss, ind

    def _first_loss(self, score, label):
        loss = 0
        target = score.new_zeros(score.shape[1])
        target[label] = 1.
        ind = 0
        loss += F.binary_cross_entropy_with_logits(
            score[ind], target, reduction='sum')
        return loss, ind

    def _image_loss(self, score, label):
        assert self.add_image_box
        target = score.new_zeros(score.shape[1])
        target[label] = 1.
        ind = score.shape[0] - 1
        loss = F.binary_cross_entropy_with_logits(
            score[ind], target, reduction='sum')
        return loss, ind

    def _max_size_loss(self, score, label, p):
        loss = 0
        target = score.new_zeros(score.shape[1])
        target[label] = 1.
        sizes = p.proposal_boxes.area()
        ind = sizes[:-1].argmax().item() if len(sizes) > 1 else 0
        if self.softmax_weak_loss:
            loss += F.cross_entropy(
                score[ind:ind+1],
                score.new_tensor(label, dtype=torch.long).view(1),
                reduction='sum')
        else:
            loss += F.binary_cross_entropy_with_logits(
                score[ind], target, reduction='sum')
        return loss, ind
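
    # Each helper picks one proposal as the pseudo positive for the image label and trains
    # its class score: `_max_score_loss` uses the highest-scoring proposal, `_min_loss_loss`
    # the one with the lowest current loss, `_first_loss` the first proposal, `_image_loss`
    # the appended whole-image box, and `_max_size_loss` the largest proposal (excluding the
    # last, image-level box, hence `sizes[:-1]`); the max-size variant is the one the Detic
    # paper advocates for image-labeled data.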


def put_label_distribution(storage, hist_name, hist_counts, num_classes):
    """
    Log a histogram of per-class label counts to the event storage, in the
    raw-histogram format consumed by the TensorBoard writer.
    """
    ht_min, ht_max = 0, num_classes
    hist_edges = torch.linspace(
        start=ht_min, end=ht_max, steps=num_classes + 1, dtype=torch.float32)

    hist_params = dict(
        tag=hist_name,
        min=ht_min,
        max=ht_max,
        num=float(hist_counts.sum()),
        sum=float((hist_counts * torch.arange(len(hist_counts))).sum()),
        sum_squares=float(((hist_counts * torch.arange(len(hist_counts))) ** 2).sum()),
        bucket_limits=hist_edges[1:].tolist(),
        bucket_counts=hist_counts.tolist(),
        global_step=storage._iter,
    )
    storage._histograms.append(hist_params)