# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from .. import builder
from ..builder import POSENETS


@POSENETS.register_module()
class MultiTask(nn.Module):
    """Multi-task detectors.

    Args:
        backbone (dict): Backbone modules to extract feature.
        heads (list[dict]): heads to output predictions.
        necks (list[dict] | None): necks to process feature.
        head2neck (dict{int:int}): head index to neck index.
        pretrained (str): Path to the pretrained models.
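
    Example:
        >>> # A hedged usage sketch, not taken from this file: the module
        >>> # names below (ResNet, TopdownHeatmapSimpleHead,
        >>> # DeepposeRegressionHead, GlobalAveragePooling) are assumed to
        >>> # be registered in the MMPose registries.
        >>> model = MultiTask(
        ...     backbone=dict(type='ResNet', depth=50),
        ...     heads=[
        ...         dict(type='TopdownHeatmapSimpleHead',
        ...              in_channels=2048,
        ...              out_channels=17),
        ...         dict(type='DeepposeRegressionHead',
        ...              in_channels=2048,
        ...              num_joints=17),
        ...     ],
        ...     necks=[dict(type='GlobalAveragePooling')],
        ...     head2neck={1: 0})  # head 1 -> neck 0; head 0 -> identity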
    """

    def __init__(self,
                 backbone,
                 heads,
                 necks=None,
                 head2neck=None,
                 pretrained=None):
        super().__init__()

        self.backbone = builder.build_backbone(backbone)

        if head2neck is None:
            assert necks is None
            head2neck = {}

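        # Heads without an explicit neck assignment map to index -1, i.e.
        # the trailing nn.Identity() appended to ``self.necks`` below.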
        self.head2neck = {}
        for i in range(len(heads)):
            self.head2neck[i] = head2neck[i] if i in head2neck else -1

        self.necks = nn.ModuleList([])
        if necks is not None:
            for neck in necks:
                self.necks.append(builder.build_neck(neck))
        self.necks.append(nn.Identity())

        self.heads = nn.ModuleList([])
        assert heads is not None
        for head in heads:
            assert head is not None
            self.heads.append(builder.build_head(head))

        self.init_weights(pretrained=pretrained)

    @property
    def with_necks(self):
        """Check if has keypoint_head."""
        return hasattr(self, 'necks')

    def init_weights(self, pretrained=None):
        """Weight initialization for model."""
        self.backbone.init_weights(pretrained)
        if self.with_necks:
            for neck in self.necks:
                if hasattr(neck, 'init_weights'):
                    neck.init_weights()

        for head in self.heads:
            if hasattr(head, 'init_weights'):
                head.init_weights()

    def forward(self,
                img,
                target=None,
                target_weight=None,
                img_metas=None,
                return_loss=True,
                **kwargs):
        """Calls either forward_train or forward_test depending on whether
        return_loss=True. Note this setting will change the expected inputs.
        When `return_loss=True`, img and img_meta are single-nested (i.e.
        Tensor and List[dict]), and when `resturn_loss=False`, img and img_meta
        should be double nested (i.e.  List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.

        Note:
            - batch_size: N
            - num_keypoints: K
            - num_img_channel: C (Default: 3)
            - img height: imgH
            - img width: imgW
            - heatmaps height: H
            - heatmaps width: W

        Args:
            img (torch.Tensor[N,C,imgH,imgW]): Input images.
            target (list[torch.Tensor]): Targets.
            target_weight (List[torch.Tensor]): Weights.
            img_metas (list[dict]): Information about data augmentation.
                By default this includes:

                - "image_file": path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox
            return_loss (bool): Option to return losses. `return_loss=True`
                for training, `return_loss=False` for validation & test.

        Returns:
            dict|tuple: if `return_loss` is true, then return losses. \
                Otherwise, return predicted poses, boxes, image paths \
                and heatmaps.
        """
        if return_loss:
            return self.forward_train(img, target, target_weight, img_metas,
                                      **kwargs)
        return self.forward_test(img, img_metas, **kwargs)

    def forward_train(self, img, target, target_weight, img_metas, **kwargs):
        """Defines the computation performed at every call when training."""
        features = self.backbone(img)
        outputs = []

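        # Fan the shared backbone features out to every head, each through
        # its assigned neck (index -1 selects the identity neck).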
        for head_id, head in enumerate(self.heads):
            neck_id = self.head2neck[head_id]
            outputs.append(head(self.necks[neck_id](features)))

        # Accumulate per-head losses (and accuracies), asserting that no
        # two heads emit the same key.
        losses = dict()

        for head, output, gt, gt_weight in zip(self.heads, outputs, target,
                                               target_weight):
            loss = head.get_loss(output, gt, gt_weight)
            assert len(set(losses.keys()).intersection(set(loss.keys()))) == 0
            losses.update(loss)

            if hasattr(head, 'get_accuracy'):
                acc = head.get_accuracy(output, gt, gt_weight)
                assert len(set(losses.keys()).intersection(set(
                    acc.keys()))) == 0
                losses.update(acc)

        return losses

    def forward_test(self, img, img_metas, **kwargs):
        """Defines the computation performed at every call when testing."""
        assert img.size(0) == len(img_metas)
        batch_size, _, img_height, img_width = img.shape
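        # Batched inference needs a ``bbox_id`` per sample so that decoded
        # results can be matched back to their source bounding boxes.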
        if batch_size > 1:
            assert 'bbox_id' in img_metas[0]

        results = {}

        features = self.backbone(img)
        outputs = []

        for head_id, head in enumerate(self.heads):
            neck_id = self.head2neck[head_id]
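            # Prefer the head's own ``inference_model`` (which applies the
            # head's test-time post-processing); otherwise fall back to a
            # plain forward pass and move the output to numpy for decoding.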
            if hasattr(head, 'inference_model'):
                head_output = head.inference_model(
                    self.necks[neck_id](features), flip_pairs=None)
            else:
                head_output = head(
                    self.necks[neck_id](features)).detach().cpu().numpy()
            outputs.append(head_output)

        for head, output in zip(self.heads, outputs):
            result = head.decode(
                img_metas, output, img_size=[img_width, img_height])
            results.update(result)
        return results

    def forward_dummy(self, img):
        """Used for computing network FLOPs.

        See ``tools/get_flops.py``.

        Args:
            img (torch.Tensor): Input image.

        Returns:
            list[Tensor]: Outputs.
        """
        features = self.backbone(img)
        outputs = []
        for head_id, head in enumerate(self.heads):
            neck_id = self.head2neck[head_id]
            outputs.append(head(self.necks[neck_id](features)))
        return outputs