|
|
|
from typing import Tuple |
|
|
|
from torch import Tensor |
|
|
|
from mmdet.registry import MODELS |
|
from .standard_roi_head import StandardRoIHead |
|
|
|
|
|
@MODELS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head for `Double Head RCNN <https://arxiv.org/abs/1904.06493>`_.

    Two RoI feature maps are extracted per forward pass. The classification
    branch uses features pooled from the RoIs as given. The regression
    branch uses features pooled from the same RoIs, extended by
    ``reg_roi_scale_factor``. Both maps are then handed to the box head.

    Args:
        reg_roi_scale_factor (float): The scale factor used to extend the
            RoIs when extracting the regression features.
        **kwargs: All remaining keyword arguments, forwarded unchanged to
            :class:`StandardRoIHead`.
    """

    def __init__(self, reg_roi_scale_factor: float, **kwargs):
        super().__init__(**kwargs)
        # Read by ``_bbox_forward`` when pooling the regression features.
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
        """Run the box branch; shared by the training and testing paths.

        Args:
            x (tuple[Tensor]): Multi-level image features. Only the first
                ``bbox_roi_extractor.num_inputs`` levels are consumed.
            rois (Tensor): RoIs with shape (n, 5). The first column holds
                the batch index of each RoI.

        Returns:
            dict[str, Tensor]: Mapping with the keys:

                - `cls_score` (Tensor): Classification scores.
                - `bbox_pred` (Tensor): Box energies / deltas.
                - `bbox_feats` (Tensor): RoI features of the classification
                  branch.
        """
        extractor = self.bbox_roi_extractor
        level_feats = x[:extractor.num_inputs]

        # Classification features come from the RoIs as given.
        cls_feats = extractor(level_feats, rois)
        # Regression features come from RoIs extended by the scale factor.
        reg_feats = extractor(
            level_feats, rois, roi_scale_factor=self.reg_roi_scale_factor)

        # When present, the shared head is applied to each branch separately.
        if self.with_shared_head:
            cls_feats = self.shared_head(cls_feats)
            reg_feats = self.shared_head(reg_feats)

        cls_score, bbox_pred = self.bbox_head(cls_feats, reg_feats)

        # Only the classification-branch features are reported as bbox_feats.
        return {
            'cls_score': cls_score,
            'bbox_pred': bbox_pred,
            'bbox_feats': cls_feats,
        }
|
|