import inspect
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn

from detectron2.config import configurable
from detectron2.layers import ShapeSpec, nonzero_tuple
from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
from detectron2.utils.events import get_event_storage
from detectron2.utils.registry import Registry

from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient
from detectron2.modeling.roi_heads.box_head import build_box_head

from .detic_fast_rcnn import DeticFastRCNNOutputLayers
from ..debug import debug_second_stage

from torch.cuda.amp import autocast


@ROI_HEADS_REGISTRY.register()
class CustomRes5ROIHeads(Res5ROIHeads):
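    '''
    Res5 ROI heads with Detic extensions: besides standard box supervision,
    they support image-label (weakly supervised) training via
    `DeticFastRCNNOutputLayers`, can attach pooled per-box features to
    proposals, and can dump debug visualizations of the second stage.
    '''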
    @configurable
    def __init__(self, **kwargs):
        cfg = kwargs.pop('cfg')
        super().__init__(**kwargs)
        # res5 output channels: res2 channels doubled once per stage (res3-res5).
        stage_channel_factor = 2 ** 3
        out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor

        self.with_image_labels = cfg.WITH_IMAGE_LABELS
        self.ws_num_props = cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS
        self.add_image_box = cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX
        self.add_feature_to_prop = cfg.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP
        self.image_box_size = cfg.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE
        # Replace the default box predictor with Detic's output layers, which
        # accept an optional external classifier via `classifier_info`.
        self.box_predictor = DeticFastRCNNOutputLayers(
            cfg, ShapeSpec(channels=out_channels, height=1, width=1)
        )

        self.save_debug = cfg.SAVE_DEBUG
        self.save_debug_path = cfg.SAVE_DEBUG_PATH
        if self.save_debug:
            # Only needed to denormalize images for visualization.
            self.debug_show_name = cfg.DEBUG_SHOW_NAME
            self.vis_thresh = cfg.VIS_THRESH
            self.pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(
                torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1)
            self.pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(
                torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1)
            self.bgr = (cfg.INPUT.FORMAT == 'BGR')

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = super().from_config(cfg, input_shape)
        # Pass the full cfg through so __init__ can read Detic-specific keys.
        ret['cfg'] = cfg
        return ret

    def forward(self, images, features, proposals, targets=None,
                ann_type='box', classifier_info=(None, None, None)):
        '''
        Enable debug visualization and image-label supervision.
        `classifier_info` is shared across the batch. `ann_type` selects the
        supervision: 'box' uses the standard detection losses; anything else
        falls back to image-label losses over the top proposals.
        '''
        if not self.save_debug:
            del images

        if self.training:
            if ann_type in ['box']:
                proposals = self.label_and_sample_proposals(
                    proposals, targets)
            else:
                # Weak supervision: no box targets, keep the top proposals.
                proposals = self.get_top_proposals(proposals)

        proposal_boxes = [x.proposal_boxes for x in proposals]
        box_features = self._shared_roi_transform(
            [features[f] for f in self.in_features], proposal_boxes
        )
        predictions = self.box_predictor(
            box_features.mean(dim=[2, 3]),
            classifier_info=classifier_info)

        if self.add_feature_to_prop:
            # Attach the pooled per-box feature to each proposal.
            feats_per_image = box_features.mean(dim=[2, 3]).split(
                [len(p) for p in proposals], dim=0)
            for feat, p in zip(feats_per_image, proposals):
                p.feat = feat

        if self.training:
            del features
            if ann_type != 'box':
                image_labels = [x._pos_category_ids for x in targets]
                losses = self.box_predictor.image_label_losses(
                    predictions, proposals, image_labels,
                    classifier_info=classifier_info,
                    ann_type=ann_type)
            else:
                losses = self.box_predictor.losses(
                    (predictions[0], predictions[1]), proposals)
                if self.with_image_labels:
                    # Keep the loss dict keys consistent across ann_types.
                    assert 'image_loss' not in losses
                    losses['image_loss'] = predictions[0].new_zeros([1])[0]
            if self.save_debug:
                denormalizer = lambda x: x * self.pixel_std + self.pixel_mean
                if ann_type != 'box':
                    image_labels = [x._pos_category_ids for x in targets]
                else:
                    image_labels = [[] for x in targets]
                debug_second_stage(
                    [denormalizer(x.clone()) for x in images],
                    targets, proposals=proposals,
                    save_debug=self.save_debug,
                    debug_show_name=self.debug_show_name,
                    vis_thresh=self.vis_thresh,
                    image_labels=image_labels,
                    save_debug_path=self.save_debug_path,
                    bgr=self.bgr)
            return proposals, losses
        else:
            pred_instances, _ = self.box_predictor.inference(predictions, proposals)
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            if self.save_debug:
                denormalizer = lambda x: x * self.pixel_std + self.pixel_mean
                debug_second_stage(
                    [denormalizer(x.clone()) for x in images],
                    pred_instances, proposals=proposals,
                    save_debug=self.save_debug,
                    debug_show_name=self.debug_show_name,
                    vis_thresh=self.vis_thresh,
                    save_debug_path=self.save_debug_path,
                    bgr=self.bgr)
            return pred_instances, {}

    def get_top_proposals(self, proposals):
        '''
        Clip to the image, keep the top `ws_num_props` proposals per image,
        detach them from the RPN graph, and optionally append an image box.
        '''
        for i in range(len(proposals)):
            proposals[i].proposal_boxes.clip(proposals[i].image_size)
        proposals = [p[:self.ws_num_props] for p in proposals]
        for i, p in enumerate(proposals):
            p.proposal_boxes.tensor = p.proposal_boxes.tensor.detach()
            if self.add_image_box:
                proposals[i] = self._add_image_box(p)
        return proposals

    def _add_image_box(self, p, use_score=False):
        '''Append a single whole-image (or centered, shrunken) box to `p`.'''
        image_box = Instances(p.image_size)
        n = 1
        h, w = p.image_size
        if self.image_box_size < 1.0:
            # A centered box covering a fraction `f` of each image side.
            f = self.image_box_size
            image_box.proposal_boxes = Boxes(
                p.proposal_boxes.tensor.new_tensor(
                    [w * (1. - f) / 2.,
                     h * (1. - f) / 2.,
                     w * (1. - (1. - f) / 2.),
                     h * (1. - (1. - f) / 2.)]
                ).view(n, 4))
        else:
            image_box.proposal_boxes = Boxes(
                p.proposal_boxes.tensor.new_tensor(
                    [0, 0, w, h]).view(n, 4))
        if use_score:
            image_box.scores = p.objectness_logits.new_ones(n)
            image_box.pred_classes = p.objectness_logits.new_zeros(
                n, dtype=torch.long)
            image_box.objectness_logits = p.objectness_logits.new_ones(n)
        else:
            image_box.objectness_logits = p.objectness_logits.new_ones(n)
        return Instances.cat([p, image_box])
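

# A minimal usage sketch (illustrative, not part of the original module):
# the head is instantiated through Detectron2's registry, assuming a
# Detic-style cfg that defines the extra keys read above (WITH_IMAGE_LABELS,
# SAVE_DEBUG, MODEL.ROI_BOX_HEAD.WS_NUM_PROPS, ...).
#
#   cfg.MODEL.ROI_HEADS.NAME = "CustomRes5ROIHeads"
#   from detectron2.modeling import build_model
#   model = build_model(cfg)  # ROI_HEADS_REGISTRY resolves this class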