import numpy as np
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling.roi_heads import CascadeROIHeads, StandardROIHeads
from detectron2.data.transforms import ResizeShortestEdge
from detectron2 import model_zoo
from detectron2.config import instantiate
from detectron2.config import LazyCall as L
from PIL import Image
import tops
import functools
from torchvision.transforms.functional import resize


def get_rn50_fpn_keypoint_rcnn(weight_path: str):
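    """Build an R50-FPN Mask R-CNN with a 17-keypoint COCO person head.

    Extends the model-zoo mask_rcnn_fpn config with a keypoint pooler and head,
    loads the checkpoint at `weight_path`, and returns (model, test augmentations).
    """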
    from detectron2.modeling.poolers import ROIPooler
    from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead
    from detectron2.layers import ShapeSpec
    model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model
    model.roi_heads.update(
        num_classes=1,
        keypoint_in_features=["p2", "p3", "p4", "p5"],
        keypoint_pooler=L(ROIPooler)(
            output_size=14,
            scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
            sampling_ratio=0,
            pooler_type="ROIAlignV2",
        ),
        keypoint_head=L(KRCNNConvDeconvUpsampleHead)(
            input_shape=ShapeSpec(channels=256, width=14, height=14),
            num_keypoints=17,
            conv_dims=[512] * 8,
            loss_normalizer="visible",
        ),
    )

    # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2.
    # 1000 proposals per-image is found to hurt box AP.
    # Therefore we increase it to 1500 per-image.
    model.proposal_generator.post_nms_topk = (1500, 1000)

    # Keypoint AP degrades (though box AP improves) when using plain L1 loss
    model.roi_heads.box_predictor.smooth_l1_beta = 0.5
    model = instantiate(model)

    dataloader = model_zoo.get_config("common/data/coco_keypoint.py").dataloader
    test_transform = instantiate(dataloader.test.mapper.augmentations)
    DetectionCheckpointer(model).load(weight_path)
    return model, test_transform


models = {
    "rn50_fpn_maskrcnn": functools.partial(
        get_rn50_fpn_keypoint_rcnn,
        weight_path="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth",
    ),
}


class KeypointMaskRCNN:
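    """Single-image keypoint/mask inference wrapper around a detectron2 R-CNN."""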

    def __init__(self, model_name: str, score_threshold: float) -> None:
        assert model_name in models, f"Unknown model {model_name!r}; available: {list(models)}"
        model, test_transform = models[model_name]()
        self.model = model.eval().to(tops.get_device())
        # Apply the score threshold to every box predictor (cascade heads have one per stage).
        if isinstance(self.model.roi_heads, CascadeROIHeads):
            for head in self.model.roi_heads.box_predictors:
                assert hasattr(head, "test_score_thresh")
                head.test_score_thresh = score_threshold
        else:
            assert isinstance(self.model.roi_heads, StandardROIHeads)
            assert hasattr(self.model.roi_heads.box_predictor, "test_score_thresh")
            self.model.roi_heads.box_predictor.test_score_thresh = score_threshold

        # Unwrap the single test-time augmentation; it must be a bilinear ResizeShortestEdge.
        assert len(test_transform) == 1
        self.test_transform = test_transform[0]
        assert isinstance(self.test_transform, ResizeShortestEdge)
        assert self.test_transform.interp == Image.BILINEAR
        self.image_format = self.model.input_format

    def resize_im(self, im):
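        """Resize `im` as the ResizeShortestEdge test transform would, staying in torch."""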
        H, W = im.shape[-2:]
        # Follow ResizeShortestEdge's sampling logic; the test config uses a single
        # short_edge_length, so in practice this is deterministic.
        if self.test_transform.is_range:
            size = np.random.randint(
                self.test_transform.short_edge_length[0],
                self.test_transform.short_edge_length[1] + 1)
        else:
            size = np.random.choice(self.test_transform.short_edge_length)
        newH, newW = ResizeShortestEdge.get_output_shape(H, W, size, self.test_transform.max_size)
        return resize(im, [newH, newW], antialias=True)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    @torch.no_grad()
    def forward(self, im: torch.Tensor):
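        """Detect people in a single CHW RGB tensor (pixel values in [0, 255]).

        Returns per-instance scores, boolean masks, and (17, 3) COCO keypoints
        (x, y, score), all mapped back to the original image resolution.
        """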
        assert im.ndim == 3, f"Expected a single CHW image, got shape {tuple(im.shape)}"
        # Inputs are RGB; flip channels when the model expects BGR.
        if self.image_format == "BGR":
            im = im.flip(0)
        H, W = im.shape[-2:]
        im = im.float()
        im = self.resize_im(im)

        inputs = dict(image=im, height=H, width=W)
        # The returned Instances holds pred_boxes, scores, pred_classes,
        # pred_masks, pred_keypoints, and pred_keypoint_heatmaps.
        instances = self.model([inputs])[0]["instances"]
        return dict(
            scores=instances.get("scores").cpu(),
            segmentation=instances.get("pred_masks").cpu(),
            keypoints=instances.get("pred_keypoints").cpu()
        )
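

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. "example.jpg" is a
    # hypothetical path; any RGB image works. read_image yields the uint8 CHW RGB
    # tensor in [0, 255] that forward() expects.
    from torchvision.io import read_image

    detector = KeypointMaskRCNN("rn50_fpn_maskrcnn", score_threshold=0.5)
    im = read_image("example.jpg").to(tops.get_device())  # hypothetical image path
    outputs = detector(im)
    print("scores:", outputs["scores"])
    print("keypoints:", outputs["keypoints"].shape)  # (num_instances, 17, 3): x, y, score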