import copy
import warnings
from typing import List, Optional, Tuple, Union, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models import BaseDetector, TwoStageDetector, StandardRoIHead, SinePositionalEncoding, FCNMaskHead, \
    BaseRoIHead
from mmdet.models.task_modules import SamplingResult
from mmdet.models.utils import multi_apply, unpack_gt_instances, empty_instances
from mmdet.structures import SampleList, DetDataSample
from mmdet.structures.bbox import bbox2roi
from mmdet.structures.mask import mask_target
from mmdet.utils import InstanceList, reduce_mean, OptMultiConfig
from mmpl.registry import MODELS, TASK_UTILS
from mmengine.model import BaseModel, BaseModule
from einops import rearrange, repeat
from mmpl.utils import ConfigType, OptConfigType
from mmdet.models.dense_heads import Mask2FormerHead
from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead


@MODELS.register_module()
class SAMInstanceHead(Mask2FormerHead):

    def __init__(
            self,
            num_things_classes: int = 1,
            num_stuff_classes: int = 0,
            prompt_neck: ConfigType = ...,
            with_iou: bool = False,
            with_multiscale: bool = False,
            with_sincos: bool = False,
            with_res_imgfeat: bool = False,
            loss_cls: ConfigType = dict(
                type='CrossEntropyLoss',
                use_sigmoid=False,
                loss_weight=2.0,
                reduction='mean',
                class_weight=[1.0] * 133 + [0.1]),
            loss_mask: ConfigType = dict(
                type='CrossEntropyLoss',
                use_sigmoid=True,
                reduction='mean',
                loss_weight=5.0),
            loss_dice: ConfigType = dict(
                type='DiceLoss',
                use_sigmoid=True,
                activate=True,
                reduction='mean',
                naive_dice=True,
                eps=1.0,
                loss_weight=5.0),
            train_cfg: OptConfigType = None,
            test_cfg: OptConfigType = None,
            init_cfg: OptMultiConfig = None,
            norm_cfg=dict(type='BN', requires_grad=True),
            act_cfg=dict(type='ReLU', inplace=True),
            **kwargs
    ):
        # Bypass the Mask2Former/MaskFormer/AnchorFreeHead initializers and
        # initialize from BaseDenseHead directly: the pixel decoder and
        # transformer decoder of Mask2Former are replaced by the prompt neck
        # built below.
        super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.with_iou = with_iou
        self.with_multiscale = with_multiscale
        self.with_sincos = with_sincos
        self.with_res_imgfeat = with_res_imgfeat
        self.prompt_neck = MODELS.build(prompt_neck)
        self.num_queries = self.prompt_neck.num_queries
        self.per_query_point = self.prompt_neck.per_query_point
        out_channels = self.prompt_neck.out_channels

        self.cls_embed = nn.Sequential(
            nn.Linear(out_channels, out_channels // 2),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels // 2, self.num_classes + 1)
        )

        if self.with_sincos:
            # with sin/cos folding in forward(), the point embedding is
            # predicted with twice the channels and folded back afterwards
            self.point_emb = nn.Sequential(
                nn.Linear(out_channels, out_channels),
                nn.ReLU(inplace=True),
                nn.Linear(out_channels, out_channels),
                nn.ReLU(inplace=True),
                nn.Linear(out_channels, self.per_query_point * out_channels * 2)
            )
        else:
            self.point_emb = nn.Sequential(
                nn.Linear(out_channels, out_channels),
                nn.ReLU(inplace=True),
                nn.Linear(out_channels, out_channels),
                nn.ReLU(inplace=True),
                nn.Linear(out_channels, self.per_query_point * out_channels)
            )

        if self.with_res_imgfeat:
            self.res_imgfeat = nn.Sequential(
                nn.UpsamplingBilinear2d(scale_factor=2),
                ConvModule(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg
                )
            )

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            self.sampler = TASK_UTILS.build(
                self.train_cfg['sampler'], default_args=dict(context=self))
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        self.class_weight = loss_cls.class_weight
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)
    def forward(self, x: List[Tensor], batch_data_samples: SampleList, sam
                ) -> Tuple[List[Tensor]]:
        """Forward function.

        Args:
            x (list[Tensor]): Multi-scale features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            sam: The SAM model whose prompt encoder and mask decoder are
                driven by the predicted point embeddings.

        Returns:
            tuple[list[Tensor]]: A tuple contains two elements.

                - cls_pred_list (list[Tensor]): Classification logits
                  for each decoder layer. Each is a 3D-tensor with shape
                  (batch_size, num_queries, cls_out_channels).
                  Note `cls_out_channels` should include background.
                - mask_pred_list (list[Tensor]): Mask logits for each
                  decoder layer. Each with shape (batch_size, num_queries,
                  h, w).
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        batch_size = len(batch_img_metas)

        decoder_out, query_feat_list, res_img_feat = self.prompt_neck(x)

        if self.with_multiscale:
            cls_pred_list = [self.cls_embed(query_feat) for query_feat in query_feat_list]
        else:
            # shape (batch_size, num_queries, c)
            cls_pred_list = [self.cls_embed(decoder_out)]

        # shape (batch_size, num_queries, per_query_point * c)
        point_emb = self.point_emb(decoder_out)
        # shape (batch_size, num_queries, per_query_point, c)
        point_emb = point_emb.view(batch_size, self.num_queries, self.per_query_point, -1)

        img_seg_feat = x[0]
        point_emb = rearrange(point_emb, 'b n p c -> (b n) p c')
        if self.with_sincos:
            # fold the doubled channels back to c: sin of the even channels
            # plus the odd channels
            point_emb = torch.sin(point_emb[..., ::2]) + point_emb[..., 1::2]

        nomask_dense_embeddings = sam.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            point_emb.shape[0], -1, *img_seg_feat.shape[-2:]
        )
        img_embeddings = torch.repeat_interleave(img_seg_feat, self.num_queries, dim=0)
        img_pe = sam.prompt_encoder.get_dense_pe()
        img_pe = repeat(img_pe, 'b c h w -> (b n) c h w', n=img_embeddings.shape[0])

        if self.with_res_imgfeat:
            res_img_feat = self.res_imgfeat(res_img_feat)
            res_img_feat = torch.repeat_interleave(res_img_feat, self.num_queries, dim=0)
        else:
            res_img_feat = None

        low_res_masks, iou_predictions = sam.mask_decoder.forward_batch(
            image_embeddings=img_embeddings,
            image_pe=img_pe,
            sparse_prompt_embeddings=point_emb,
            dense_prompt_embeddings=nomask_dense_embeddings,
            multimask_output=False,
            res_img_feat=res_img_feat,
        )
        mask_pred = rearrange(low_res_masks.squeeze(1), '(b n) h w -> b n h w', b=batch_size)

        # optional
        # if self.with_iou:
        #     iou_predictions = iou_predictions.view(batch_size, self.num_queries, -1)
        #     cls_pred = cls_pred * iou_predictions

        if self.with_multiscale:
            mask_pred_list = [mask_pred] * len(cls_pred_list)
        else:
            mask_pred_list = [mask_pred]
        return cls_pred_list, mask_pred_list

    def predict(self, x: Tuple[Tensor],
                batch_data_samples: SampleList,
                sam
                ) -> Tuple[Tensor]:
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            tuple[Tensor]: A tuple contains two tensors.

                - mask_cls_results (Tensor): Mask classification logits,
                  shape (batch_size, num_queries, cls_out_channels).
                  Note `cls_out_channels` should include background.
                - mask_pred_results (Tensor): Mask logits, shape
                  (batch_size, num_queries, h, w).
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        all_cls_scores, all_mask_preds = self(x, batch_data_samples, sam)
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        # upsample masks
        img_shape = batch_img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(img_shape[0], img_shape[1]),
            mode='bilinear',
            align_corners=False)

        return mask_cls_results, mask_pred_results
    def loss(
            self,
            x: Tuple[Tensor],
            batch_data_samples: SampleList,
            sam,
    ) -> Dict[str, Tensor]:
        """Perform forward propagation and loss calculation of the panoptic
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        batch_img_metas = []
        batch_gt_instances = []
        batch_gt_semantic_segs = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)
            if 'gt_sem_seg' in data_sample:
                batch_gt_semantic_segs.append(data_sample.gt_sem_seg)
            else:
                batch_gt_semantic_segs.append(None)

        # forward
        all_cls_scores, all_mask_preds = self(x, batch_data_samples, sam)

        # preprocess ground truth
        batch_gt_instances = self.preprocess_gt(batch_gt_instances,
                                                batch_gt_semantic_segs)

        # loss
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses
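
# Illustrative sketch, not used by the model: when `with_sincos=True`,
# `SAMInstanceHead.forward` predicts point embeddings with doubled channels
# and folds them back by taking sin() of the even channels and adding the odd
# channels, mirroring the sin/cos structure of SAM's positional point
# embeddings. All shapes below are assumptions for demonstration only.
def _demo_sincos_fold():
    bn, p, c = 4, 5, 256  # hypothetical (batch * queries, points, channels)
    emb = torch.randn(bn, p, 2 * c)  # MLP output with doubled channels
    folded = torch.sin(emb[..., ::2]) + emb[..., 1::2]
    assert folded.shape == (bn, p, c)
    return folded
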

@MODELS.register_module()
class SAMAnchorInstanceHead(TwoStageDetector):

    def __init__(
            self,
            sam_head=True,
            neck: OptConfigType = None,
            rpn_head: OptConfigType = None,
            roi_head: OptConfigType = None,
            train_cfg: OptConfigType = None,
            test_cfg: OptConfigType = None,
            **kwargs
    ):
        super(TwoStageDetector, self).__init__()
        self.neck = MODELS.build(neck)
        self.sam_head = sam_head

        if rpn_head is not None:
            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
            rpn_head_ = rpn_head.copy()
            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
            rpn_head_num_classes = rpn_head_.get('num_classes', None)
            if rpn_head_num_classes is None:
                rpn_head_.update(num_classes=1)
            else:
                if rpn_head_num_classes != 1:
                    warnings.warn(
                        'The `num_classes` should be 1 in RPN, but get '
                        f'{rpn_head_num_classes}, please set '
                        'rpn_head.num_classes = 1 in your config file.')
                    rpn_head_.update(num_classes=1)
            self.rpn_head = MODELS.build(rpn_head_)

        if roi_head is not None:
            # update train and test cfg here for now
            # TODO: refactor assigner & sampler
            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
            roi_head.update(train_cfg=rcnn_train_cfg)
            roi_head.update(test_cfg=test_cfg.rcnn)
            self.roi_head = MODELS.build(roi_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, x):
        x = self.neck(x)
        return x

    def loss(self, batch_inputs,
             batch_data_samples: SampleList,
             sam
             ) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        x = self.extract_feat(batch_inputs)
        img_seg_feat = batch_inputs[0]

        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            rpn_data_samples = copy.deepcopy(batch_data_samples)
            # set cat_id of gt_labels to 0 in RPN
            for data_sample in rpn_data_samples:
                data_sample.gt_instances.labels = \
                    torch.zeros_like(data_sample.gt_instances.labels)

            rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(
                x, rpn_data_samples, proposal_cfg=proposal_cfg)
            # avoid getting the same loss keys as roi_head
            keys = rpn_losses.keys()
            for key in list(keys):
                if 'loss' in key and 'rpn' not in key:
                    rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
            losses.update(rpn_losses)
        else:
            assert batch_data_samples[0].get('proposals', None) is not None
            # use pre-defined proposals in InstanceData for the second stage
            # to extract ROI features.
            rpn_results_list = [
                data_sample.proposals for data_sample in batch_data_samples
            ]

        if self.sam_head:
            roi_losses = self.roi_head.loss(x, rpn_results_list,
                                            batch_data_samples,
                                            sam, img_seg_feat)
        else:
            roi_losses = self.roi_head.loss(x, rpn_results_list,
                                            batch_data_samples)
        losses.update(roi_losses)

        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                sam,
                rescale: bool = True
                ) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Return the detection results of the
            input images. The return value is DetDataSample, which usually
            contains 'pred_instances'. And the ``pred_instances`` usually
            contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        x = self.extract_feat(batch_inputs)
        img_seg_feat = batch_inputs[0]

        # If there are no pre-defined proposals, use RPN to get proposals
        if batch_data_samples[0].get('proposals', None) is None:
            rpn_results_list = self.rpn_head.predict(
                x, batch_data_samples, rescale=False)
        else:
            rpn_results_list = [
                data_sample.proposals for data_sample in batch_data_samples
            ]
        if self.sam_head:
            results_list = self.roi_head.predict(
                x, rpn_results_list, batch_data_samples,
                sam, img_seg_feat, rescale=rescale)
        else:
            results_list = self.roi_head.predict(
                x, rpn_results_list, batch_data_samples, rescale=rescale)

        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples
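
# Illustrative sketch, not used by the model: `SAMAnchorInstanceHead.loss`
# prefixes RPN loss keys with `rpn_` so they do not collide with RoI-head
# losses of the same name when merged into a single dict. The key names below
# are assumptions following mmdet's conventions.
def _demo_rpn_loss_key_renaming():
    rpn_losses = {'loss_cls': 0.5, 'loss_bbox': 0.2, 'loss_rpn_obj': 0.1}
    for key in list(rpn_losses.keys()):
        if 'loss' in key and 'rpn' not in key:
            rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
    assert set(rpn_losses) == {'rpn_loss_cls', 'rpn_loss_bbox', 'loss_rpn_obj'}
    return rpn_losses
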

@MODELS.register_module()
class SAMAnchorPromptRoIHead(StandardRoIHead):

    def __init__(
            self,
            positional_encoding=dict(num_feats=128, normalize=True),
            *args,
            **kwargs
    ):
        super(StandardRoIHead, self).__init__(*args, **kwargs)
        self.generator_pe = SinePositionalEncoding(**positional_encoding)

    def _mask_forward(self,
                      x: Tuple[Tensor],
                      rois: Tensor = None,
                      pos_inds: Optional[Tensor] = None,
                      bbox_feats: Optional[Tensor] = None,
                      sam=None,
                      img_seg_feat=None
                      ) -> dict:
        """Mask head forward function used in both training and testing.

        Args:
            x (tuple[Tensor]): Tuple of multi-level img features.
            rois (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.
            pos_inds (Tensor, optional): Indices of positive samples.
                Defaults to None.
            bbox_feats (Tensor): Extract bbox RoI features. Defaults to None.

        Returns:
            dict[str, Tensor]: Usually returns a dictionary with keys:

                - `mask_preds` (Tensor): Mask prediction.
                - `mask_feats` (Tensor): Extract mask RoI features.
        """
        assert ((rois is not None) ^
                (pos_inds is not None and bbox_feats is not None))
        if rois is not None:
            mask_feats = self.mask_roi_extractor(
                x[:self.mask_roi_extractor.num_inputs], rois)
            if self.with_shared_head:
                mask_feats = self.shared_head(mask_feats)
        else:
            assert bbox_feats is not None
            mask_feats = bbox_feats[pos_inds]

        mask_preds = self.mask_head(mask_feats, sam, img_seg_feat,
                                    img_flag_ids=rois[:, 0])
        mask_results = dict(mask_preds=mask_preds[0], mask_iou=mask_preds[1],
                            mask_feats=mask_feats)
        return mask_results

    def mask_loss(self, x: Tuple[Tensor],
                  sampling_results: List[SamplingResult],
                  bbox_feats: Tensor,
                  batch_gt_instances: InstanceList,
                  sam, img_seg_feat
                  ) -> dict:
        """Perform forward propagation and loss calculation of the mask head
        on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Tuple of multi-level img features.
            sampling_results (list[obj:`SamplingResult`]): Sampling results.
            bbox_feats (Tensor): Extract bbox RoI features.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.

        Returns:
            dict: Usually returns a dictionary with keys:

                - `mask_preds` (Tensor): Mask prediction.
                - `mask_feats` (Tensor): Extract mask RoI features.
                - `mask_targets` (Tensor): Mask target of each positive
                  proposals in the image.
                - `loss_mask` (dict): A dictionary of mask loss components.
        """
        if not self.share_roi_extractor:
            pos_rois = bbox2roi([res.pos_priors for res in sampling_results])
            mask_results = self._mask_forward(
                x, pos_rois, sam=sam, img_seg_feat=img_seg_feat)
        else:
            pos_inds = []
            device = bbox_feats.device
            for res in sampling_results:
                pos_inds.append(
                    torch.ones(
                        res.pos_priors.shape[0],
                        device=device,
                        dtype=torch.uint8))
                pos_inds.append(
                    torch.zeros(
                        res.neg_priors.shape[0],
                        device=device,
                        dtype=torch.uint8))
            pos_inds = torch.cat(pos_inds)
            mask_results = self._mask_forward(
                x, pos_inds=pos_inds, bbox_feats=bbox_feats)

        mask_loss_and_target = self.mask_head.loss_and_target(
            mask_preds=mask_results['mask_preds'],
            sampling_results=sampling_results,
            batch_gt_instances=batch_gt_instances,
            rcnn_train_cfg=self.train_cfg)

        mask_results.update(loss_mask=mask_loss_and_target['loss_mask'])
        return mask_results
    def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
             batch_data_samples: List[DetDataSample],
             sam, img_seg_feat
             ) -> dict:
        """Perform forward propagation and loss calculation of the detection
        roi on the features of the upstream network.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # add sine positional encoding to every feature level before
        # assignment, sampling and head forward
        x = list(x)
        bs, _, h, w = x[-1].shape
        mask_pe = torch.zeros((bs, h, w), device=x[0].device, dtype=torch.bool)
        img_feats_pe = self.generator_pe(mask_pe)
        for i in range(len(x)):
            x[i] = x[i] + torch.nn.functional.interpolate(
                img_feats_pe, size=x[i].shape[-2:], mode='bilinear')

        assert len(rpn_results_list) == len(batch_data_samples)
        outputs = unpack_gt_instances(batch_data_samples)
        batch_gt_instances, batch_gt_instances_ignore, _ = outputs

        # assign gts and sample proposals
        num_imgs = len(batch_data_samples)
        sampling_results = []
        for i in range(num_imgs):
            # rename rpn_results.bboxes to rpn_results.priors
            rpn_results = rpn_results_list[i]
            rpn_results.priors = rpn_results.pop('bboxes')

            assign_result = self.bbox_assigner.assign(
                rpn_results, batch_gt_instances[i],
                batch_gt_instances_ignore[i])
            sampling_result = self.bbox_sampler.sample(
                assign_result,
                rpn_results,
                batch_gt_instances[i],
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)

        losses = dict()
        # bbox head loss
        if self.with_bbox:
            bbox_results = self.bbox_loss(x, sampling_results)
            losses.update(bbox_results['loss_bbox'])

        # mask head forward and loss
        if self.with_mask:
            mask_results = self.mask_loss(x, sampling_results,
                                          bbox_results['bbox_feats'],
                                          batch_gt_instances,
                                          sam, img_seg_feat)
            losses.update(mask_results['loss_mask'])

        return losses

    def predict_mask(self,
                     x: Tuple[Tensor],
                     batch_img_metas: List[dict],
                     results_list: InstanceList,
                     rescale: bool = False,
                     sam=None,
                     img_seg_feat=None
                     ) -> InstanceList:
        """Perform forward propagation of the mask head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            batch_img_metas (list[dict]): List of image information.
            results_list (list[:obj:`InstanceData`]): Detection results of
                each image.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        # don't need to consider aug_test.
        bboxes = [res.bboxes for res in results_list]
        mask_rois = bbox2roi(bboxes)
        if mask_rois.shape[0] == 0:
            results_list = empty_instances(
                batch_img_metas,
                mask_rois.device,
                task_type='mask',
                instance_results=results_list,
                mask_thr_binary=self.test_cfg.mask_thr_binary)
            return results_list

        mask_results = self._mask_forward(x, mask_rois, sam=sam,
                                          img_seg_feat=img_seg_feat)
        mask_preds = mask_results['mask_preds']
        # split batch mask prediction back to each image
        num_mask_rois_per_img = [len(res) for res in results_list]
        mask_preds = mask_preds.split(num_mask_rois_per_img, 0)

        # TODO: Handle the case where rescale is false
        results_list = self.mask_head.predict_by_feat(
            mask_preds=mask_preds,
            results_list=results_list,
            batch_img_metas=batch_img_metas,
            rcnn_test_cfg=self.test_cfg,
            rescale=rescale)
        return results_list
    def predict(self,
                x: Tuple[Tensor],
                rpn_results_list: InstanceList,
                batch_data_samples: SampleList,
                sam, img_seg_feat,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the roi head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from upstream network. Each
                has shape (N, C, H, W).
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results to the original
                image. Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        # add the same sine positional encoding as in `loss`
        x = list(x)
        bs, _, h, w = x[-1].shape
        mask_pe = torch.zeros((bs, h, w), device=x[0].device, dtype=torch.bool)
        img_feats_pe = self.generator_pe(mask_pe)
        for i in range(len(x)):
            x[i] = x[i] + torch.nn.functional.interpolate(
                img_feats_pe, size=x[i].shape[-2:], mode='bilinear')

        assert self.with_bbox, 'Bbox head must be implemented.'
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        # TODO: nms_op in mmcv needs to be enhanced; the bbox result may
        #  differ when not rescaled in bbox_head

        # If it has the mask branch, the bbox branch does not need
        # to be scaled to the original image scale, because the mask
        # branch will scale both bbox and mask at the same time.
        bbox_rescale = rescale if not self.with_mask else False
        results_list = self.predict_bbox(
            x,
            batch_img_metas,
            rpn_results_list,
            rcnn_test_cfg=self.test_cfg,
            rescale=bbox_rescale)

        if self.with_mask:
            results_list = self.predict_mask(
                x, batch_img_metas, results_list, rescale=rescale,
                sam=sam, img_seg_feat=img_seg_feat)

        return results_list
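
# Illustrative sketch, not used by the model: `SAMAnchorPromptRoIHead.
# _mask_forward` passes `rois[:, 0]` (the per-RoI batch index) to the mask
# head, which turns it into per-image RoI counts via bincount, zero-pads for
# trailing images without RoIs, and repeats each image embedding once per RoI.
# The tensor sizes below are assumptions for demonstration only.
def _demo_roi_to_image_embedding():
    img_seg_feat = torch.randn(3, 256, 64, 64)  # (num_imgs, C, H, W)
    roi_batch_ids = torch.tensor([0, 0, 1])  # two RoIs in image 0, one in image 1
    counts = torch.bincount(roi_batch_ids.long())
    padding = torch.zeros((len(img_seg_feat) - len(counts),),
                          device=counts.device, dtype=counts.dtype)
    counts = torch.cat([counts, padding])  # image 2 has zero RoIs
    img_embeddings = torch.repeat_interleave(img_seg_feat, counts, dim=0)
    assert img_embeddings.shape[0] == roi_batch_ids.numel()  # one per RoI
    return img_embeddings
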

@MODELS.register_module()
class SAMPromptMaskHead(FCNMaskHead):

    def __init__(self,
                 per_query_point: int = 5,
                 with_sincos: bool = True,
                 class_agnostic: bool = False,
                 loss_mask: ConfigType = dict(
                     type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
                 *args,
                 **kwargs
                 ) -> None:
        # skip FCNMaskHead.__init__: the conv mask branch is replaced by a
        # point-embedding MLP that prompts SAM's mask decoder
        super(BaseModule, self).__init__()
        self.per_query_point = per_query_point
        self.with_sincos = with_sincos
        self.class_agnostic = class_agnostic
        self.loss_mask = MODELS.build(loss_mask)

        if with_sincos:
            sincos = 2
        else:
            sincos = 1
        self.point_emb = nn.Sequential(
            nn.Conv2d(256, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(7 * 7 * 256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256 * sincos * per_query_point)
        )

    def forward(self, x, sam, img_seg_feat,
                img_flag_ids) -> Tuple[Tensor, Tensor]:
        batch_size = x.shape[0]
        point_emb = self.point_emb(x)
        point_emb = point_emb.view(batch_size, self.per_query_point, -1)
        if self.with_sincos:
            point_emb = torch.sin(point_emb[..., ::2]) + point_emb[..., 1::2]

        nomask_dense_embeddings = sam.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            point_emb.shape[0], -1, *img_seg_feat.shape[-2:]
        )

        # per-image RoI counts; pad with zeros for trailing images that
        # received no RoIs
        img_flag_ids = torch.bincount(img_flag_ids.long())
        padding = torch.zeros((len(img_seg_feat) - len(img_flag_ids),),
                              device=img_flag_ids.device,
                              dtype=img_flag_ids.dtype)
        img_flag_ids = torch.cat([img_flag_ids, padding])

        img_embeddings = torch.repeat_interleave(img_seg_feat, img_flag_ids, dim=0)
        img_pe = sam.prompt_encoder.get_dense_pe()
        img_pe = repeat(img_pe, 'b c h w -> (b n) c h w', n=img_embeddings.shape[0])

        res_img_feat = None
        low_res_masks, iou_predictions = sam.mask_decoder.forward_batch(
            image_embeddings=img_embeddings,
            image_pe=img_pe,
            sparse_prompt_embeddings=point_emb,
            dense_prompt_embeddings=nomask_dense_embeddings,
            multimask_output=False,
            res_img_feat=res_img_feat,
        )
        mask_pred = low_res_masks.squeeze(1)
        iou_predictions = iou_predictions.squeeze(1)
        return mask_pred, iou_predictions

    def get_targets(self, sampling_results: List[SamplingResult],
                    batch_gt_instances: InstanceList,
                    rcnn_train_cfg: ConfigDict) -> Tensor:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Args:
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

        Returns:
            Tensor: Mask target of each positive proposals in the image.
        """
        pos_proposals = [res.pos_priors for res in sampling_results]
        pos_assigned_gt_inds = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]
        gt_masks = [res.masks for res in batch_gt_instances]

        mask_targets_list = []
        mask_size = (rcnn_train_cfg.mask_size,) * 2
        device = pos_proposals[0].device
        for pos_gt_inds, gt_mask in zip(pos_assigned_gt_inds, gt_masks):
            if len(pos_gt_inds) == 0:
                mask_targets = torch.zeros((0,) + mask_size, device=device,
                                           dtype=torch.float32)
            else:
                mask_targets = gt_mask[pos_gt_inds.cpu()].to_tensor(
                    dtype=torch.float32, device=device)
            mask_targets_list.append(mask_targets)
        mask_targets = torch.cat(mask_targets_list)
        return mask_targets

    def loss_and_target(self, mask_preds: Tensor,
                        sampling_results: List[SamplingResult],
                        batch_gt_instances: InstanceList,
                        rcnn_train_cfg: ConfigDict) -> dict:
        """Calculate the loss based on the features extracted by the mask
        head.

        Args:
            mask_preds (Tensor): Predicted foreground masks, has shape
                (num_pos, num_classes, h, w).
            sampling_results (List[obj:SamplingResult]): Assign results of
                all images in a batch after sampling.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

        Returns:
            dict: A dictionary of loss and targets components.
        """
        mask_targets = self.get_targets(
            sampling_results=sampling_results,
            batch_gt_instances=batch_gt_instances,
            rcnn_train_cfg=rcnn_train_cfg)

        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])

        mask_preds = torch.nn.functional.interpolate(
            mask_preds.unsqueeze(1), size=mask_targets.shape[-2:],
            mode='bilinear', align_corners=False)

        loss = dict()
        if mask_preds.size(0) == 0:
            loss_mask = mask_preds.sum()
        else:
            if self.class_agnostic:
                loss_mask = self.loss_mask(mask_preds, mask_targets,
                                           torch.zeros_like(pos_labels))
            else:
                loss_mask = self.loss_mask(mask_preds, mask_targets,
                                           pos_labels)
        loss['loss_mask'] = loss_mask
        # TODO: which algorithm requires mask_targets?
        return dict(loss_mask=loss, mask_targets=mask_targets)
    def _predict_by_feat_single(self,
                                mask_preds: Tensor,
                                bboxes: Tensor,
                                labels: Tensor,
                                img_meta: dict,
                                rcnn_test_cfg: ConfigDict,
                                rescale: bool = False,
                                activate_map: bool = False) -> Tensor:
        """Get segmentation masks from mask_preds and bboxes.

        Args:
            mask_preds (Tensor): Predicted foreground masks, has shape
                (n, num_classes, h, w).
            bboxes (Tensor): Predicted bboxes, has shape (n, 4)
            labels (Tensor): Labels of bboxes, has shape (n, )
            img_meta (dict): image information.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            activate_map (bool): Whether to get results with augmentation
                test. If True, the `mask_preds` will not be processed with
                sigmoid. Defaults to False.

        Returns:
            Tensor: Encoded masks, has shape (n, img_h, img_w).

        Example:
            >>> from mmengine.config import Config
            >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import *  # NOQA
            >>> N = 7  # N = number of extracted ROIs
            >>> C, H, W = 11, 32, 32
            >>> # Create example instance of FCN Mask Head.
            >>> self = FCNMaskHead(num_classes=C, num_convs=0)
            >>> inputs = torch.rand(N, self.in_channels, H, W)
            >>> mask_preds = self.forward(inputs)
            >>> # Each input is associated with some bounding box
            >>> bboxes = torch.Tensor([[1, 1, 42, 42]] * N)
            >>> labels = torch.randint(0, C, size=(N,))
            >>> rcnn_test_cfg = Config({'mask_thr_binary': 0, })
            >>> ori_shape = (H * 4, W * 4)
            >>> scale_factor = (1, 1)
            >>> rescale = False
            >>> img_meta = {'scale_factor': scale_factor,
            ...             'ori_shape': ori_shape}
            >>> # Encoded masks are a list for each category.
            >>> encoded_masks = self._predict_by_feat_single(
            ...     mask_preds, bboxes, labels,
            ...     img_meta, rcnn_test_cfg, rescale)
            >>> assert encoded_masks.size()[0] == N
            >>> assert encoded_masks.size()[1:] == ori_shape
        """
        scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
            (1, 2))
        img_h, img_w = img_meta['ori_shape'][:2]
        device = bboxes.device

        if not activate_map:
            mask_preds = mask_preds.sigmoid()
        else:
            # In AugTest, has been activated before
            mask_preds = bboxes.new_tensor(mask_preds)

        if rescale:  # in-place rescale the bboxes
            bboxes /= scale_factor
        else:
            w_scale, h_scale = scale_factor[0, 0], scale_factor[0, 1]
            img_h = np.round(img_h * h_scale.item()).astype(np.int32)
            img_w = np.round(img_w * w_scale.item()).astype(np.int32)

        threshold = rcnn_test_cfg.mask_thr_binary

        im_mask = torch.nn.functional.interpolate(
            mask_preds.unsqueeze(1),
            size=(img_h, img_w), mode='bilinear',
            align_corners=False).squeeze(1)

        if threshold >= 0:
            im_mask = im_mask >= threshold
        else:
            # for visualization and debugging
            im_mask = (im_mask * 255).to(dtype=torch.uint8)
        return im_mask
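
# Illustrative sketch, not used by the model: `_predict_by_feat_single`
# binarizes sigmoid mask probabilities with `mask_thr_binary`; a negative
# threshold instead keeps soft masks scaled to uint8 for visualization. The
# probabilities below are assumptions for demonstration only.
def _demo_mask_binarization():
    probs = torch.tensor([[0.2, 0.6], [0.8, 0.4]])
    hard = probs >= 0.5  # e.g. mask_thr_binary = 0.5 -> boolean mask
    soft = (probs * 255).to(dtype=torch.uint8)  # mask_thr_binary < 0
    return hard, soft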