File size: 2,994 Bytes
9cc3eb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch.nn.functional as F
from mmengine.model import BaseModel

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig


@MODELS.register_module()
class SAMSegmentor(BaseModel):
    """SAM-style segmentor built from registry-configured components.

    A backbone + neck produce image embeddings, a prompt encoder embeds
    point/box prompts, and a mask decoder turns embeddings + prompts into
    low-resolution masks (upsampled 4x here) plus per-mask class scores.

    Args:
        backbone (ConfigType): Config for the image backbone.
        neck (ConfigType): Config for the neck producing the decoder's
            image embeddings.
        prompt_encoder (ConfigType): Config for the prompt encoder.
        mask_decoder (ConfigType): Config for the mask decoder.
        data_preprocessor (OptConfigType): Passed through to ``BaseModel``.
        fpn_neck (OptConfigType): Optional extra FPN neck over backbone
            features; ``None`` disables it.
        init_cfg (OptMultiConfig): Passed through to ``BaseModel``.
        use_clip_feat (bool): Stored flag; not consumed in this file —
            presumably read by sub-modules or subclasses (verify).
        use_head_feat (bool): Stored flag; see ``use_clip_feat``.
        use_gt_prompt (bool): Stored flag; see ``use_clip_feat``.
        use_point (bool): Stored flag; see ``use_clip_feat``.
        enable_backbone (bool): If True, ``extract_masks`` additionally
            forwards backbone/FPN features and the backbone itself to the
            mask decoder.
        image_size (tuple): Image size ``(H, W)`` reported to the prompt
            encoder when embedding prompts. Defaults to ``(1024, 1024)``,
            which was previously hard-coded in ``extract_masks``.
    """

    # Threshold for binarizing sigmoid mask probabilities. NOTE(review):
    # not referenced anywhere in this file — presumably used by callers.
    MASK_THRESHOLD = 0.5

    def __init__(
            self,
            backbone: ConfigType,
            neck: ConfigType,
            prompt_encoder: ConfigType,
            mask_decoder: ConfigType,
            data_preprocessor: OptConfigType = None,
            fpn_neck: OptConfigType = None,
            init_cfg: OptMultiConfig = None,
            use_clip_feat: bool = False,
            use_head_feat: bool = False,
            use_gt_prompt: bool = False,
            use_point: bool = False,
            enable_backbone: bool = False,
            image_size: tuple = (1024, 1024),
    ) -> None:
        super().__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.pe = MODELS.build(prompt_encoder)
        self.mask_decoder = MODELS.build(mask_decoder)
        # Optional second neck over raw backbone features.
        self.fpn_neck = MODELS.build(fpn_neck) if fpn_neck is not None else None

        self.use_clip_feat = use_clip_feat
        self.use_head_feat = use_head_feat
        self.use_gt_prompt = use_gt_prompt
        self.use_point = use_point
        self.enable_backbone = enable_backbone

        # (H, W) passed to the prompt encoder; keeps the old hard-coded
        # (1024, 1024) default so existing configs behave identically.
        self.image_size = tuple(image_size)

    def extract_feat(self, inputs):
        """Compute and cache all feature maps for ``inputs``.

        Returns:
            dict: ``backbone_feat`` (raw backbone output), ``neck_feat``
            (decoder image embeddings) and ``fpn_feat`` (``None`` when no
            ``fpn_neck`` is configured).
        """
        backbone_feat = self.backbone(inputs)
        neck_feat = self.neck(backbone_feat)
        fpn_feat = (self.fpn_neck(backbone_feat)
                    if self.fpn_neck is not None else None)

        return dict(
            backbone_feat=backbone_feat,
            neck_feat=neck_feat,
            fpn_feat=fpn_feat,
        )

    def extract_masks(self, feat_cache, prompts):
        """Decode masks and class scores for ``prompts``.

        Args:
            feat_cache (dict): Output of :meth:`extract_feat`.
            prompts (dict): Prompt dict; keys ``'point_coords'`` and/or
                ``'bboxes'`` toggle point/box embedding in the prompt
                encoder. Exact schema is defined by the prompt encoder —
                TODO confirm against its implementation.

        Returns:
            tuple: ``(masks, cls_pred)`` where ``masks`` is a numpy array
            of sigmoid probabilities upsampled 4x from the decoder's
            low-res output, and ``cls_pred`` is a CPU tensor of softmax
            scores with the last channel (presumably background) dropped.
        """
        sparse_embed, dense_embed = self.pe(
            prompts,
            image_size=self.image_size,
            with_points='point_coords' in prompts,
            with_bboxes='bboxes' in prompts,
        )

        # Extra inputs the decoder needs when backbone interaction is on.
        kwargs = dict()
        if self.enable_backbone:
            kwargs['backbone_feats'] = feat_cache['backbone_feat']
            kwargs['backbone'] = self.backbone
            kwargs['fpn_feats'] = feat_cache['fpn_feat']
        low_res_masks, iou_predictions, cls_pred = self.mask_decoder(
            image_embeddings=feat_cache['neck_feat'],
            image_pe=self.pe.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embed,
            dense_prompt_embeddings=dense_embed,
            multi_mask_output=False,
            **kwargs,
        )
        # Upsample low-res decoder masks back toward input resolution.
        masks = F.interpolate(
            low_res_masks,
            scale_factor=4.,
            mode='bilinear',
            align_corners=False,
        )

        masks = masks.sigmoid()
        # Softmax over classes, dropping the last channel — presumably the
        # background/no-object logit; verify against the decoder head.
        cls_pred = cls_pred.softmax(-1)[..., :-1]
        return masks.detach().cpu().numpy(), cls_pred.detach().cpu()

    def forward(self, inputs):
        """Identity pass-through stub.

        NOTE(review): this overrides ``BaseModel.forward`` with a no-op;
        inference appears to go through :meth:`extract_feat` /
        :meth:`extract_masks` directly. Kept as-is to preserve behavior.
        """
        return inputs