Spaces:
Runtime error
Runtime error
""" | |
This file contains specific functions for computing losses on the RetinaNet | |
file | |
""" | |
import torch | |
from torch.nn import functional as F | |
from ..utils import concat_box_prediction_layers | |
from maskrcnn_benchmark.layers import smooth_l1_loss | |
from maskrcnn_benchmark.layers import SigmoidFocalLoss | |
from maskrcnn_benchmark.modeling.matcher import Matcher | |
from maskrcnn_benchmark.modeling.utils import cat | |
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou | |
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist | |
from maskrcnn_benchmark.modeling.rpn.loss import RPNLossComputation | |
class RetinaNetLossComputation(RPNLossComputation):
    """
    Computes the RetinaNet classification and box regression losses.

    NOTE(review): ``prepare_targets`` is not defined in this class; it is
    presumably inherited from ``RPNLossComputation`` -- confirm against the
    parent class.
    """

    def __init__(
        self,
        proposal_matcher,
        box_coder,
        generate_labels_func,
        sigmoid_focal_loss,
        bbox_reg_beta=0.11,
        regress_norm=1.0,
    ):
        """
        Arguments:
            proposal_matcher (Matcher)
            box_coder (BoxCoder)
            generate_labels_func: stored on the instance; not called
                directly by this class
            sigmoid_focal_loss: callable used as the classification loss,
                invoked as ``sigmoid_focal_loss(box_cls, labels)``
            bbox_reg_beta: forwarded as ``beta`` to ``smooth_l1_loss``
            regress_norm: factor applied to the number of positives when
                normalizing the regression loss
        """
        # Collaborators supplied by the caller.
        self.proposal_matcher = proposal_matcher
        self.box_coder = box_coder
        self.generate_labels_func = generate_labels_func
        self.box_cls_loss_func = sigmoid_focal_loss

        # Loss hyper-parameters.
        self.bbox_reg_beta = bbox_reg_beta
        self.regress_norm = regress_norm

        # Not read anywhere in this class; presumably consumed by the
        # inherited target-preparation code -- TODO confirm.
        self.copied_fields = ['labels']
        self.discard_cases = ['between_thresholds']

    def __call__(self, anchors, box_cls, box_regression, targets):
        """
        Arguments:
            anchors (list[BoxList]): one entry per image; each entry is
                merged into a single BoxList with ``cat_boxlist``
            box_cls (list[Tensor])
            box_regression (list[Tensor])
            targets (list[BoxList])

        Returns:
            retinanet_cls_loss (Tensor)
            retinanet_regression_loss (Tensor)
        """
        merged_anchors = [cat_boxlist(per_image) for per_image in anchors]
        labels, regression_targets = self.prepare_targets(
            merged_anchors, targets
        )
        # Number of entries returned by prepare_targets (presumably one per
        # image -- TODO confirm); used in the classification normalizer.
        num_entries = len(labels)

        box_cls, box_regression = concat_box_prediction_layers(
            box_cls, box_regression
        )

        flat_labels = torch.cat(labels, dim=0)
        flat_regression_targets = torch.cat(regression_targets, dim=0)

        # Positions whose label is strictly positive are the positives.
        pos_inds = torch.nonzero(flat_labels > 0).squeeze(1)
        num_pos = pos_inds.numel()

        # Regression loss is evaluated on positives only, summed
        # (size_average=False) and then normalized manually; the max()
        # keeps the divisor from dropping below 1.
        regression_loss_sum = smooth_l1_loss(
            box_regression[pos_inds],
            flat_regression_targets[pos_inds],
            beta=self.bbox_reg_beta,
            size_average=False,
        )
        regression_normalizer = max(1, num_pos * self.regress_norm)
        retinanet_regression_loss = regression_loss_sum / regression_normalizer

        # Labels are cast to int before being handed to the classification
        # loss.
        cls_loss_sum = self.box_cls_loss_func(box_cls, flat_labels.int())
        retinanet_cls_loss = cls_loss_sum / (num_pos + num_entries)

        return retinanet_cls_loss, retinanet_regression_loss
def generate_retinanet_labels(matched_targets):
    """
    Return the ``"labels"`` field of ``matched_targets``.

    Arguments:
        matched_targets: object exposing ``get_field(name)``

    Returns:
        whatever ``matched_targets.get_field("labels")`` yields
    """
    return matched_targets.get_field("labels")
def make_retinanet_loss_evaluator(cfg, box_coder):
    """
    Build a ``RetinaNetLossComputation`` from the ``cfg.MODEL.RETINANET``
    settings.

    Arguments:
        cfg: configuration object exposing ``MODEL.RETINANET`` with
            ``FG_IOU_THRESHOLD``, ``BG_IOU_THRESHOLD``, ``LOSS_GAMMA``,
            ``LOSS_ALPHA``, ``BBOX_REG_BETA`` and ``BBOX_REG_WEIGHT``
        box_coder (BoxCoder): passed through to the loss evaluator

    Returns:
        RetinaNetLossComputation
    """
    retinanet_cfg = cfg.MODEL.RETINANET

    anchor_matcher = Matcher(
        retinanet_cfg.FG_IOU_THRESHOLD,
        retinanet_cfg.BG_IOU_THRESHOLD,
        allow_low_quality_matches=True,
    )
    focal_loss = SigmoidFocalLoss(
        retinanet_cfg.LOSS_GAMMA,
        retinanet_cfg.LOSS_ALPHA,
    )

    # BBOX_REG_WEIGHT is what feeds the evaluator's ``regress_norm``.
    return RetinaNetLossComputation(
        anchor_matcher,
        box_coder,
        generate_retinanet_labels,
        focal_loss,
        bbox_reg_beta=retinanet_cfg.BBOX_REG_BETA,
        regress_norm=retinanet_cfg.BBOX_REG_WEIGHT,
    )