import copy import math from typing import Optional, List import torch from torch import nn, Tensor import torch.nn.functional as F from detectron2.modeling.poolers import ROIPooler, cat from detectron2.structures import Boxes, pairwise_iou from detectron2.layers import Conv2d, ConvTranspose2d, ShapeSpec, get_norm from detectron2.modeling.matcher import Matcher from .rec_stage import REC_STAGE _DEFAULT_SCALE_CLAMP = math.log(100000.0 / 16) def _get_src_permutation_idx(indices): # permute predictions following indices batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) src_idx = torch.cat([src for (src, _) in indices]) return batch_idx, src_idx def _get_tgt_permutation_idx(indices): # permute targets following indices batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) tgt_idx = torch.cat([tgt for (_, tgt) in indices]) return batch_idx, tgt_idx class DynamicHead(nn.Module): def __init__(self, cfg, roi_input_shape): super().__init__() # Build RoI. box_pooler = self._init_box_pooler(cfg, roi_input_shape) self.box_pooler = box_pooler box_pooler_rec = self._init_box_pooler_rec(cfg, roi_input_shape) self.box_pooler_rec = box_pooler_rec # Build heads. num_classes = cfg.MODEL.SWINTS.NUM_CLASSES self.hidden_dim = cfg.MODEL.SWINTS.HIDDEN_DIM dim_feedforward = cfg.MODEL.SWINTS.DIM_FEEDFORWARD nhead = cfg.MODEL.SWINTS.NHEADS dropout = cfg.MODEL.SWINTS.DROPOUT activation = cfg.MODEL.SWINTS.ACTIVATION self.train_num_proposal = cfg.MODEL.SWINTS.NUM_PROPOSALS self.num_heads = cfg.MODEL.SWINTS.NUM_HEADS rcnn_head = RCNNHead(cfg, self.hidden_dim, num_classes, dim_feedforward, nhead, dropout, activation) self.head_series = _get_clones(rcnn_head, self.num_heads) self.return_intermediate = cfg.MODEL.SWINTS.DEEP_SUPERVISION self.cfg =cfg # Build recognition heads self.rec_stage = REC_STAGE(cfg, self.hidden_dim, num_classes, dim_feedforward, nhead, dropout, activation) self.cnn = nn.Sequential( nn.Conv2d(self.hidden_dim, self.hidden_dim,3,1,1), nn.BatchNorm2d(self.hidden_dim), nn.ReLU(True), nn.Conv2d(self.hidden_dim, self.hidden_dim,3,1,1), nn.BatchNorm2d(self.hidden_dim), nn.ReLU(True), ) #DC self.conv = nn.ModuleList([ nn.Sequential( nn.Conv2d(self.hidden_dim, self.hidden_dim,3,1,2,2), nn.BatchNorm2d(self.hidden_dim), nn.ReLU(True), nn.Conv2d(self.hidden_dim, self.hidden_dim,3,1,4,4), nn.BatchNorm2d(self.hidden_dim), nn.ReLU(True), nn.Conv2d(self.hidden_dim, self.hidden_dim,3,1,1), nn.BatchNorm2d(self.hidden_dim), nn.ReLU(True),) for i in range(4) ] ) # Init parameters. self.num_classes = num_classes prior_prob = cfg.MODEL.SWINTS.PRIOR_PROB self.bias_value = -math.log((1 - prior_prob) / prior_prob) self._reset_parameters() def _reset_parameters(self): # init all parameters. for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) # initialize the bias for focal loss. if p.shape[-1] == self.num_classes: nn.init.constant_(p, self.bias_value) @staticmethod def _init_box_pooler(cfg, input_shape): in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features) sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE # If StandardROIHeads is applied on multiple feature maps (as in FPN), # then we share the same predictors and therefore the channel counts must be the same in_channels = [input_shape[f].channels for f in in_features] # Check all channel counts are equal assert len(set(in_channels)) == 1, in_channels box_pooler = ROIPooler( output_size=pooler_resolution, scales=pooler_scales, sampling_ratio=sampling_ratio, pooler_type=pooler_type, ) return box_pooler @staticmethod def _init_box_pooler_rec(cfg, input_shape): in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES pooler_resolution = cfg.MODEL.REC_HEAD.POOLER_RESOLUTION pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features) sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE # If StandardROIHeads is applied on multiple feature maps (as in FPN), # then we share the same predictors and therefore the channel counts must be the same in_channels = [input_shape[f].channels for f in in_features] # Check all channel counts are equal assert len(set(in_channels)) == 1, in_channels box_pooler = ROIPooler( output_size=pooler_resolution, scales= pooler_scales, sampling_ratio=sampling_ratio, pooler_type=pooler_type, ) return box_pooler def extra_rec_feat(self, matcher, mask_encoding, targets, N, bboxes, class_logits, pred_bboxes, mask_logits, proposal_features, features): gt_masks = list() gt_boxes = list() proposal_boxes_pred = list() masks_pred = list() pred_mask = mask_logits.detach() N, nr_boxes = bboxes.shape[:2] if targets: output = {'pred_logits': class_logits, 'pred_boxes': pred_bboxes, 'pred_masks': mask_logits} indices = matcher(output, targets, mask_encoding) idx = _get_src_permutation_idx(indices) target_rec = torch.cat([t['rec'][i] for t, (_, i) in zip(targets, indices)], dim=0) target_rec = target_rec.repeat(2,1) else: idx = None scores = torch.sigmoid(class_logits) labels = torch.arange(2, device=bboxes.device).\ unsqueeze(0).repeat(self.train_num_proposal, 1).flatten(0, 1) inter_class_logits = [] inter_pred_bboxes = [] inter_pred_masks = [] inter_pred_label = [] for b in range(N): if targets: gt_boxes.append(Boxes(targets[b]['boxes_xyxy'][indices[b][1]])) gt_masks.append(targets[b]['gt_masks'][indices[b][1]]) proposal_boxes_pred.append(Boxes(bboxes[b][indices[b][0]])) tmp_mask = mask_encoding.decoder(pred_mask[b]).view(-1,28,28) tmp_mask = tmp_mask[indices[b][0]] tmp_mask2 = torch.full_like(tmp_mask,0).cuda() tmp_mask2[tmp_mask>0.4]=1 masks_pred.append(tmp_mask2) else: # post_processing num_proposals = self.cfg.MODEL.SWINTS.TEST_NUM_PROPOSALS scores_per_image, topk_indices = scores[b].flatten(0, 1).topk(num_proposals, sorted=False) labels_per_image = labels[topk_indices] box_pred_per_image = bboxes[b].view(-1, 1, 4).repeat(1, 2, 1).view(-1, 4) box_pred_per_image = box_pred_per_image[topk_indices] mask_pred_per_image = mask_logits.view(-1, self.cfg.MODEL.SWINTS.MASK_DIM) mask_pred_per_image = mask_encoding.decoder(mask_pred_per_image, is_train=False) mask_pred_per_image = mask_pred_per_image.view(-1, 1, 28, 28) n, c, w, h = mask_pred_per_image.size() mask_pred_per_image = torch.repeat_interleave(mask_pred_per_image,2,1).view(-1, c, w, h) mask_pred_per_image = mask_pred_per_image[topk_indices] proposal_features = proposal_features[b].view(-1, 1, self.hidden_dim).repeat(1, 2, 1).view(-1, self.hidden_dim) proposal_features = proposal_features[topk_indices] proposal_boxes_pred.append(Boxes(box_pred_per_image)) gt_masks.append(mask_pred_per_image) inter_class_logits.append(scores_per_image) inter_pred_bboxes.append(box_pred_per_image) inter_pred_masks.append(mask_pred_per_image) inter_pred_label.append(labels_per_image) # get recognition roi region if targets: gt_roi_features = self.box_pooler_rec(features, gt_boxes) pred_roi_features = self.box_pooler_rec(features, proposal_boxes_pred) masks_pred = torch.cat(masks_pred).cuda() gt_masks = torch.cat(gt_masks).cuda() rec_map = torch.cat((gt_roi_features,pred_roi_features),0) gt_masks = torch.cat((gt_masks,masks_pred),0) else: rec_map = self.box_pooler_rec(features, proposal_boxes_pred) gt_masks = torch.cat(gt_masks).cuda() nr_boxes = rec_map.shape[0] if targets: rec_map = rec_map[:self.cfg.MODEL.REC_HEAD.BATCH_SIZE] else: gt_masks_b = torch.full_like(gt_masks,0).cuda() gt_masks_b[gt_masks>0.4]=1 gt_masks_b = gt_masks_b.squeeze() gt_masks = gt_masks_b del gt_masks_b if targets: return proposal_features, gt_masks[:self.cfg.MODEL.REC_HEAD.BATCH_SIZE], idx, rec_map, target_rec[:self.cfg.MODEL.REC_HEAD.BATCH_SIZE] else: return inter_class_logits, inter_pred_bboxes, inter_pred_masks, inter_pred_label, proposal_features, gt_masks, idx, rec_map, nr_boxes def forward(self, features, init_bboxes, init_features, targets = None, mask_encoding = None, matcher=None): inter_class_logits = [] inter_pred_bboxes = [] inter_pred_masks = [] inter_pred_label = [] bs = len(features[0]) bboxes = init_bboxes proposal_features = init_features.clone() for i_idx in range(len(features)): features[i_idx] = self.conv[i_idx](features[i_idx]) + features[i_idx] for i, rcnn_head in enumerate(self.head_series): class_logits, pred_bboxes, proposal_features, mask_logits = rcnn_head(features, bboxes, proposal_features, self.box_pooler) if self.return_intermediate: inter_class_logits.append(class_logits) inter_pred_bboxes.append(pred_bboxes) inter_pred_masks.append(mask_logits) bboxes = pred_bboxes.detach() # extract recognition feature. N, nr_boxes = bboxes.shape[:2] if targets: proposal_features, gt_masks, idx, rec_map, target_rec = \ self.extra_rec_feat(matcher, mask_encoding, targets, N, bboxes, class_logits, pred_bboxes, mask_logits, proposal_features, features) else: inter_class_logits, inter_pred_bboxes, inter_pred_masks, inter_pred_label, proposal_features, gt_masks, idx, rec_map, nr_boxes = \ self.extra_rec_feat(matcher, mask_encoding, targets, N, bboxes, class_logits, pred_bboxes, mask_logits, proposal_features, features) rec_map = self.cnn(rec_map) rec_proposal_features = proposal_features.clone() if targets: rec_result = self.rec_stage(rec_map, rec_proposal_features, gt_masks, N, nr_boxes, idx, target_rec) else: rec_result = self.rec_stage(rec_map, rec_proposal_features, gt_masks, N, nr_boxes) rec_result = torch.tensor(rec_result) if self.return_intermediate: return torch.stack(inter_class_logits), torch.stack(inter_pred_bboxes), torch.stack(inter_pred_masks), rec_result return class_logits[None], pred_bboxes[None], mask_logits[None] class RCNNHead(nn.Module): def __init__(self, cfg, d_model, num_classes, dim_feedforward=2048, nhead=8, dropout=0.1, activation="relu", scale_clamp: float = _DEFAULT_SCALE_CLAMP, bbox_weights=(2.0, 2.0, 1.0, 1.0)): super().__init__() self.d_model = d_model # dynamic. self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.inst_interact = DynamicConv(cfg) self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.dropout3 = nn.Dropout(dropout) self.activation = nn.ELU(inplace=True) # cls. num_cls = cfg.MODEL.SWINTS.NUM_CLS cls_module = list() for _ in range(num_cls): cls_module.append(nn.Linear(d_model, d_model, False)) cls_module.append(nn.LayerNorm(d_model)) cls_module.append(nn.ELU(inplace=True)) self.cls_module = nn.ModuleList(cls_module) # reg. num_reg = cfg.MODEL.SWINTS.NUM_REG reg_module = list() for _ in range(num_reg): reg_module.append(nn.Linear(d_model, d_model, False)) reg_module.append(nn.LayerNorm(d_model)) reg_module.append(nn.ELU(inplace=True)) self.reg_module = nn.ModuleList(reg_module) # mask. num_mask = cfg.MODEL.SWINTS.NUM_MASK mask_module = list() for _ in range(num_mask): mask_module.append(nn.Linear(d_model, d_model, False)) mask_module.append(nn.LayerNorm(d_model)) mask_module.append(nn.ELU(inplace=True)) self.mask_module = nn.ModuleList(mask_module) self.mask_logits = nn.Linear(d_model, cfg.MODEL.SWINTS.MASK_DIM) # pred. self.class_logits = nn.Linear(d_model, num_classes) self.bboxes_delta = nn.Linear(d_model, 4) self.scale_clamp = scale_clamp self.bbox_weights = bbox_weights def forward(self, features, bboxes, pro_features, pooler): """ :param bboxes: (N, nr_boxes, 4) :param pro_features: (N, nr_boxes, d_model) """ N, nr_boxes = bboxes.shape[:2] # roi_feature. proposal_boxes = list() for b in range(N): proposal_boxes.append(Boxes(bboxes[b])) roi_features = pooler(features, proposal_boxes) roi_features = roi_features.view(N * nr_boxes, self.d_model, -1).permute(2, 0, 1) # self_att. pro_features = pro_features.view(N, nr_boxes, self.d_model).permute(1, 0, 2) pro_features2 = self.self_attn(pro_features, pro_features, value=pro_features)[0] pro_features = pro_features + self.dropout1(pro_features2) del pro_features2 pro_features = self.norm1(pro_features) # inst_interact. pro_features = pro_features.view(nr_boxes, N, self.d_model).permute(1, 0, 2).reshape(1, N * nr_boxes, self.d_model) pro_features2 = self.inst_interact(pro_features, roi_features) pro_features = pro_features + self.dropout2(pro_features2) del pro_features2 obj_features = self.norm2(pro_features) # obj_feature. obj_features2 = self.linear2(self.dropout(self.activation(self.linear1(obj_features)))) obj_features = obj_features + self.dropout3(obj_features2) del obj_features2 obj_features = self.norm3(obj_features) fc_feature = obj_features.transpose(0, 1).reshape(N * nr_boxes, -1) cls_feature = fc_feature.clone() reg_feature = fc_feature.clone() mask_feature = fc_feature.clone() del fc_feature for mask_layer in self.mask_module: mask_feature = mask_layer(mask_feature) mask_logits = self.mask_logits(mask_feature) del mask_feature for cls_layer in self.cls_module: cls_feature = cls_layer(cls_feature) for reg_layer in self.reg_module: reg_feature = reg_layer(reg_feature) class_logits = self.class_logits(cls_feature) bboxes_deltas = self.bboxes_delta(reg_feature) del cls_feature del reg_feature pred_bboxes = self.apply_deltas(bboxes_deltas, bboxes.view(-1, 4)) return class_logits.view(N, nr_boxes, -1), pred_bboxes.view(N, nr_boxes, -1), obj_features, mask_logits.view(N, nr_boxes, -1) def apply_deltas(self, deltas, boxes): """ Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`. Args: deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1. deltas[i] represents k potentially different class-specific box transformations for the single box boxes[i]. boxes (Tensor): boxes to transform, of shape (N, 4) """ boxes = boxes.to(deltas.dtype) widths = boxes[:, 2] - boxes[:, 0] heights = boxes[:, 3] - boxes[:, 1] ctr_x = boxes[:, 0] + 0.5 * widths ctr_y = boxes[:, 1] + 0.5 * heights wx, wy, ww, wh = self.bbox_weights dx = deltas[:, 0::4] / wx dy = deltas[:, 1::4] / wy dw = deltas[:, 2::4] / ww dh = deltas[:, 3::4] / wh # Prevent sending too large values into torch.exp() dw = torch.clamp(dw, max=self.scale_clamp) dh = torch.clamp(dh, max=self.scale_clamp) pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] pred_w = torch.exp(dw) * widths[:, None] pred_h = torch.exp(dh) * heights[:, None] pred_boxes = torch.zeros_like(deltas) pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w # x1 pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h # y1 pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w # x2 pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h # y2 return pred_boxes class DynamicConv(nn.Module): def __init__(self, cfg): super().__init__() self.hidden_dim = cfg.MODEL.SWINTS.HIDDEN_DIM self.dim_dynamic = cfg.MODEL.SWINTS.DIM_DYNAMIC self.num_dynamic = cfg.MODEL.SWINTS.NUM_DYNAMIC self.num_params = self.hidden_dim * self.dim_dynamic self.dynamic_layer = nn.Linear(self.hidden_dim, self.num_dynamic * self.num_params) self.norm1 = nn.LayerNorm(self.dim_dynamic) self.norm2 = nn.LayerNorm(self.hidden_dim) self.activation = nn.ELU(inplace=True) pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION num_output = self.hidden_dim * pooler_resolution ** 2 self.out_layer = nn.Linear(num_output, self.hidden_dim) self.norm3 = nn.LayerNorm(self.hidden_dim) def forward(self, pro_features, roi_features): ''' pro_features: (1, N * nr_boxes, self.d_model) roi_features: (49, N * nr_boxes, self.d_model) ''' features = roi_features.permute(1, 0, 2) parameters = self.dynamic_layer(pro_features).permute(1, 0, 2) param1 = parameters[:, :, :self.num_params].view(-1, self.hidden_dim, self.dim_dynamic) param2 = parameters[:, :, self.num_params:].view(-1, self.dim_dynamic, self.hidden_dim) del parameters features = torch.bmm(features, param1) del param1 features = self.norm1(features) features = self.activation(features) features = torch.bmm(features, param2) del param2 features = self.norm2(features) features = self.activation(features) features = features.flatten(1) features = self.out_layer(features) features = self.norm3(features) features = self.activation(features) return features def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)])