Spaces:
Runtime error
Runtime error
File size: 5,301 Bytes
4ea50ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import math
import torch
import torch.nn.functional as F
from torch import nn
from .inference import make_retinanet_postprocessor
from .loss import make_retinanet_loss_evaluator
from ..anchor_generator import make_anchor_generator_retinanet
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
class RetinaNetHead(torch.nn.Module):
    """
    Adds a RetinaNet head with classification and regression heads.

    Two parallel towers of 3x3 conv + ReLU layers are applied to every
    feature level; each tower is followed by a 3x3 prediction conv
    (class logits and box regression deltas respectively). The same
    weights are shared across all feature levels.
    """

    def __init__(self, cfg, in_channels):
        """
        Arguments:
            cfg: config object; the fields read here are
                cfg.MODEL.RETINANET.NUM_CLASSES, ASPECT_RATIOS,
                SCALES_PER_OCTAVE, NUM_CONVS and PRIOR_PROB
            in_channels (int): number of channels of the input feature
        """
        super(RetinaNetHead, self).__init__()
        # TODO: Implement the sigmoid version first.
        # NOTE(review): the "- 1" presumably drops a background class that is
        # counted in NUM_CLASSES -- confirm against the config definition.
        num_classes = cfg.MODEL.RETINANET.NUM_CLASSES - 1
        # Anchors per spatial location: one per (aspect ratio, scale) pair.
        num_anchors = (
            len(cfg.MODEL.RETINANET.ASPECT_RATIOS)
            * cfg.MODEL.RETINANET.SCALES_PER_OCTAVE
        )

        cls_tower = []
        bbox_tower = []
        # The two towers are built in lockstep (cls conv, then bbox conv) so
        # the order in which layers are created stays the same as before.
        for _ in range(cfg.MODEL.RETINANET.NUM_CONVS):
            cls_tower.append(
                nn.Conv2d(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1
                )
            )
            cls_tower.append(nn.ReLU())
            bbox_tower.append(
                nn.Conv2d(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1
                )
            )
            bbox_tower.append(nn.ReLU())

        self.add_module('cls_tower', nn.Sequential(*cls_tower))
        self.add_module('bbox_tower', nn.Sequential(*bbox_tower))
        self.cls_logits = nn.Conv2d(
            in_channels, num_anchors * num_classes, kernel_size=3, stride=1,
            padding=1
        )
        self.bbox_pred = nn.Conv2d(
            in_channels, num_anchors * 4, kernel_size=3, stride=1,
            padding=1
        )

        # Initialization: every conv gets N(0, 0.01) weights and zero bias.
        for modules in [self.cls_tower, self.bbox_tower, self.cls_logits,
                        self.bbox_pred]:
            for layer in modules.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.normal_(layer.weight, std=0.01)
                    torch.nn.init.constant_(layer.bias, 0)

        # retinanet_bias_init: choose the classification bias b such that
        # sigmoid(b) == prior_prob, i.e. b = -log((1 - p) / p).
        prior_prob = cfg.MODEL.RETINANET.PRIOR_PROB
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        torch.nn.init.constant_(self.cls_logits.bias, bias_value)

    def forward(self, x):
        """
        Arguments:
            x (iterable[Tensor]): feature maps, one per feature level

        Returns:
            logits (list[Tensor]): per-level class logits with
                num_anchors * num_classes channels
            bbox_reg (list[Tensor]): per-level box regression outputs with
                num_anchors * 4 channels
        """
        logits = []
        bbox_reg = []
        for feature in x:
            logits.append(self.cls_logits(self.cls_tower(feature)))
            bbox_reg.append(self.bbox_pred(self.bbox_tower(feature)))
        return logits, bbox_reg
class RetinaNetModule(torch.nn.Module):
    """
    Module for RetinaNet computation. Takes feature maps from the backbone and
    produces RetinaNet outputs and losses. Only tested with FPN so far.
    """

    def __init__(self, cfg, in_channels):
        """
        Arguments:
            cfg: config object; a clone is stored on the module and the
                original is passed to the anchor generator, head,
                post-processor and loss evaluator factories
            in_channels (int): number of channels of the input features
        """
        super(RetinaNetModule, self).__init__()

        self.cfg = cfg.clone()

        anchor_generator = make_anchor_generator_retinanet(cfg)
        head = RetinaNetHead(cfg, in_channels)
        # The same box coder is shared by the test-time post-processor and
        # the loss evaluator.
        box_coder = BoxCoder(weights=(10., 10., 5., 5.))

        box_selector_test = make_retinanet_postprocessor(cfg, box_coder, is_train=False)

        loss_evaluator = make_retinanet_loss_evaluator(cfg, box_coder)

        self.anchor_generator = anchor_generator
        self.head = head
        self.box_selector_test = box_selector_test
        self.loss_evaluator = loss_evaluator

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (list[Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                corresponds to a different feature level
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            A 2-tuple whose contents depend on the mode:
            training: (anchors, losses) -- the anchors produced by the
                anchor generator and a dict with the keys
                "loss_retina_cls" and "loss_retina_reg".
            testing: (boxes, {}) -- the output of the test-time box
                selector and an empty dict.
        """
        box_cls, box_regression = self.head(features)
        anchors = self.anchor_generator(images, features)

        if self.training:
            return self._forward_train(anchors, box_cls, box_regression, targets)
        else:
            return self._forward_test(anchors, box_cls, box_regression)

    def _forward_train(self, anchors, box_cls, box_regression, targets):
        """Compute the classification and regression losses against targets."""
        loss_box_cls, loss_box_reg = self.loss_evaluator(
            anchors, box_cls, box_regression, targets
        )
        losses = {
            "loss_retina_cls": loss_box_cls,
            "loss_retina_reg": loss_box_reg,
        }
        # Note: the anchors, not decoded predictions, are returned here.
        return anchors, losses

    def _forward_test(self, anchors, box_cls, box_regression):
        """Run the test-time box selector; no losses are computed."""
        boxes = self.box_selector_test(anchors, box_cls, box_regression)
        return boxes, {}
def build_retinanet(cfg, in_channels):
    """Construct a RetinaNetModule from ``cfg`` and ``in_channels``."""
    return RetinaNetModule(cfg, in_channels)
|