File size: 21,362 Bytes
f1dd031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
"""
Author: Siyuan Li
Licensed: Apache-2.0 License
"""

import copy
import os
import pickle
import warnings
from typing import Dict, List, Optional, Tuple, Union

import torch
from mmdet.models.mot.base import BaseMOTModel
from mmdet.registry import MODELS
from mmdet.structures import TrackSampleList
from mmdet.utils import OptConfigType, OptMultiConfig
from mmengine.structures import InstanceData
from torch import Tensor


@MODELS.register_module()
class MASA(BaseMOTModel):

    """Matching Anything By Segmenting Anything.

    This multi-object tracker is the implementation of `MASA
    <https://arxiv.org/abs/2406.04221>`_.

    Args:
        backbone (dict, optional): Configuration of backbone. Defaults to None.
        detector (dict, optional): Configuration of detector. Defaults to None.
        masa_adapter (dict, optional): Configuration of MASA adapter. Defaults to None.
        rpn_head (dict, optional): Configuration of RPN head. Defaults to None.
        roi_head (dict, optional): Configuration of RoI head. Defaults to None.
        track_head (dict, optional): Configuration of track head. Defaults to None.
        tracker (dict, optional): Configuration of tracker. Defaults to None.
        freeze_detector (bool): If True, freeze the detector weights. Defaults to False.
        freeze_masa_backbone (bool): If True, freeze the MASA backbone weights. Defaults to False.
        freeze_masa_adapter (bool): If True, freeze the MASA adapter weights. Defaults to False.
        freeze_object_prior_distillation (bool): If True, freeze the object prior distillation. Defaults to False.
        data_preprocessor (dict or ConfigDict, optional): The pre-process config of :class:`TrackDataPreprocessor`.
            It usually includes, ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. Defaults to None.
        train_cfg (dict or ConfigDict, optional): Training configuration. Defaults to None.
        test_cfg (dict or ConfigDict, optional): Testing configuration. Defaults to None.
        init_cfg (dict or list[dict], optional): Configuration of initialization. Defaults to None.
        load_public_dets (bool): If True, load public detections. Defaults to False.
        public_det_path (str, optional): Path to public detections. Required if load_public_dets is True. Defaults to None.
        given_dets (bool): If True, detections are given. Defaults to False.
        with_segm (bool): If True, segmentation masks are included. Defaults to False.
        end_pkl_name (str): Suffix for pickle file names. Defaults to '.pth'.
        unified_backbone (bool): If True, use a unified backbone. Defaults to False.
        use_masa_backbone (bool): If True, use the MASA backbone. Defaults to False.
        benchmark (str): Benchmark for evaluation. Defaults to 'tao'.
    """

    def __init__(
        self,
        backbone: Optional[dict] = None,
        detector: Optional[dict] = None,
        masa_adapter: Optional[dict] = None,
        rpn_head: Optional[dict] = None,
        roi_head: Optional[dict] = None,
        track_head: Optional[dict] = None,
        tracker: Optional[dict] = None,
        freeze_detector: bool = False,
        freeze_masa_backbone: bool = False,
        freeze_masa_adapter: bool = False,
        freeze_object_prior_distillation: bool = False,
        data_preprocessor: OptConfigType = None,
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        init_cfg: OptMultiConfig = None,
        load_public_dets: bool = False,
        public_det_path: Optional[str] = None,
        given_dets: bool = False,
        with_segm: bool = False,
        end_pkl_name: str = ".pth",
        unified_backbone: bool = False,
        use_masa_backbone: bool = False,
        benchmark: str = "tao",
    ) -> None:
        """Build the configured sub-modules and apply the freezing options.

        Every sub-module (``backbone``, ``detector``, ``masa_adapter``,
        ``rpn_head``, ``roi_head``, ``track_head``, ``tracker``) is built with
        ``MODELS.build`` only when its config is given; otherwise the
        attribute is not created. See the class docstring for the meaning of
        each argument.

        Raises:
            AssertionError: If an option is enabled without the config it
                depends on (e.g. ``use_masa_backbone`` without ``backbone``,
                ``load_public_dets`` without ``public_det_path``).
        """
        super().__init__(data_preprocessor, init_cfg)

        self.use_masa_backbone = use_masa_backbone
        if use_masa_backbone:
            assert (
                backbone is not None
            ), "backbone must be set when using MASA backbone."

        if backbone is not None:
            self.backbone = MODELS.build(backbone)

        if detector is not None:
            self.detector = MODELS.build(detector)

        if masa_adapter is not None:
            self.masa_adapter = MODELS.build(masa_adapter)

        if rpn_head is not None:
            # Either cfg may be missing; guard both the same way.
            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
            rpn_test_cfg = test_cfg.rpn if test_cfg is not None else None
            # Work on a copy so the caller's config is left untouched.
            rpn_head_ = rpn_head.copy()
            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=rpn_test_cfg)
            rpn_head_num_classes = rpn_head_.get("num_classes", None)
            if rpn_head_num_classes is not None and rpn_head_num_classes != 1:
                warnings.warn(
                    "The `num_classes` should be 1 in RPN, but get "
                    f"{rpn_head_num_classes}, please set "
                    "rpn_head.num_classes = 1 in your config file."
                )
            # The RPN is always built with a single class.
            rpn_head_.update(num_classes=1)
            self.rpn_head = MODELS.build(rpn_head_)

        if roi_head is not None:
            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
            rcnn_test_cfg = test_cfg.rcnn if test_cfg is not None else None
            # Copy first (as done for rpn_head) instead of mutating the
            # caller's config in place.
            roi_head_ = roi_head.copy()
            roi_head_.update(train_cfg=rcnn_train_cfg, test_cfg=rcnn_test_cfg)
            self.roi_head = MODELS.build(roi_head_)

        if track_head is not None:
            self.track_head = MODELS.build(track_head)

        if tracker is not None:
            self.tracker = MODELS.build(tracker)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.freeze_detector = freeze_detector
        self.freeze_masa_adapter = freeze_masa_adapter
        self.freeze_object_prior_distillation = freeze_object_prior_distillation
        self.freeze_masa_backbone = freeze_masa_backbone

        def set_to_eval(module, input):
            # Forward pre-hook: put the module back into eval mode right
            # before each call of the module.
            module.eval()

        if self.freeze_detector:
            assert (
                detector is not None
            ), "detector must be set when freeze_detector is True."
            self.freeze_module("detector")
            # NOTE: no eval pre-hook is registered for the detector; `loss`
            # calls `self.detector.eval()` explicitly instead.

        if self.freeze_masa_adapter:
            assert (
                masa_adapter is not None
            ), "masa_adapter must be set when freeze_masa_adapter is True."
            self.freeze_module("masa_adapter")
            self.masa_adapter.register_forward_pre_hook(set_to_eval)

        if self.freeze_object_prior_distillation:
            assert (
                roi_head is not None
            ), "roi_head must be set when freeze_object_prior_distillation is True."
            assert (
                rpn_head is not None
            ), "rpn_head must be set when freeze_object_prior_distillation is True."
            self.freeze_module("roi_head")
            self.freeze_module("rpn_head")

        if self.freeze_masa_backbone:
            assert (
                backbone is not None
            ), "backbone must be set when freeze_masa_backbone is True."
            self.freeze_module("backbone")
            self.backbone.register_forward_pre_hook(set_to_eval)

        if load_public_dets:
            assert (
                public_det_path is not None
            ), "load_public_dets and public_det_path must be set together."
        # Always define the attribute so reading it never raises, even when
        # public detections are disabled.
        self.benchmark = benchmark
        self.load_public_dets = load_public_dets
        self.public_det_path = public_det_path
        self.with_segm = with_segm
        self.end_pkl_name = end_pkl_name
        self.given_dets = given_dets

        self.unified_backbone = unified_backbone

    @property
    def with_rpn(self) -> bool:
        """bool: whether the detector has RPN"""
        return hasattr(self, "rpn_head") and self.rpn_head is not None

    @property
    def with_roi_head(self) -> bool:
        """bool: whether the detector has a RoI head"""
        return hasattr(self, "roi_head") and self.roi_head is not None

    def predict(
        self,
        inputs: Tensor,
        data_samples: TrackSampleList,
        rescale: bool = True,
        **kwargs,
    ) -> TrackSampleList:
        """Predict results from a video and data samples with post-processing.

        Detections for each frame come from one of three sources, checked in
        this order: a pickle file of public detections
        (``load_public_dets``), detections already attached to the data
        sample (``given_dets``), or the wrapped detector (requires
        ``unified_backbone``). The adapter features and the detections are
        then handed to ``self.tracker.track``.

        Args:
            inputs (Tensor): of shape (N, T, C, H, W) encoding
                input images. The N denotes batch size.
                The T denotes the number of frames in a video.
            data_samples (list[:obj:`TrackDataSample`]): The batch
                data samples. It usually includes information such
                as `video_data_samples`.
            rescale (bool, Optional): If False, then returned bboxes and masks
                will fit the scale of img, otherwise, returned bboxes and masks
                will fit the scale of original image shape. Defaults to True.

        Returns:
            TrackSampleList: Tracking results of the inputs.

        Raises:
            AssertionError: If ``inputs`` is not 5D, the batch size is not 1,
                or ``given_dets`` is set but the sample lacks
                ``det_bboxes`` / ``det_labels``.
            ValueError: If public detections are requested but the frame has
                no ``img_path`` or ``self.benchmark`` is not supported.
            NotImplementedError: If no feature source is configured for the
                requested mode.
        """
        assert inputs.dim() == 5, "The img must be 5D Tensor (N, T, C, H, W)."
        assert (
            inputs.size(0) == 1
        ), "MASA inference only support 1 batch size per gpu for now."

        assert len(data_samples) == 1, "MASA only support 1 batch size per gpu for now."

        track_data_sample = data_samples[0]
        video_len = len(track_data_sample)
        # A first frame with frame_id 0 marks the start of a new video.
        if track_data_sample[0].frame_id == 0:
            self.tracker.reset()

        for frame_id in range(video_len):
            img_data_sample = track_data_sample[frame_id]
            single_img = inputs[:, frame_id].contiguous()
            if self.load_public_dets:
                pickle_res = self._read_public_det_file(img_data_sample.img_path)
                # Follow the input's device instead of hard-coding "cuda".
                device = single_img.device
                det_labels = torch.tensor(pickle_res["det_labels"]).to(device)
                det_bboxes = torch.tensor(pickle_res["det_bboxes"]).to(
                    device=device, dtype=torch.float32
                )
                det_results = self._pack_det_results(det_bboxes, det_labels)

                if self.with_segm:
                    det_results.masks = pickle_res["det_masks"]

                img_data_sample.pred_instances = det_results
                x_m = self._extract_masa_feats(single_img)

            elif self.given_dets:
                assert (
                    "det_bboxes" in img_data_sample
                ), "det_bboxes must be given when given_dets is True."
                assert (
                    "det_labels" in img_data_sample
                ), "det_labels must be given when given_dets is True."
                img_data_sample.pred_instances = self._pack_det_results(
                    img_data_sample.det_bboxes, img_data_sample.det_labels
                )
                x_m = self._extract_masa_feats(single_img)
            else:
                if not self.unified_backbone:
                    raise NotImplementedError
                if hasattr(self.detector.backbone, "with_text_model"):
                    texts = img_data_sample.texts
                    # Fix an inconsistency between the yolo-world and mmdet
                    # text formats: flatten a list of single-item lists.
                    if isinstance(texts[0], list):
                        new_texts = [text[0] for text in texts]
                        del img_data_sample.texts
                        img_data_sample.set_field(
                            new_texts, "texts", field_type="metainfo"
                        )
                    (
                        backbone_feats,
                        img_feats,
                        text_feats,
                    ) = self.detector.extract_feat(single_img, [img_data_sample])
                    x_m = self.masa_adapter(backbone_feats)
                    img_data_sample = self.detector.predict(
                        single_img,
                        (img_feats, text_feats),
                        [img_data_sample],
                        rescale=rescale,
                    )[0]
                else:
                    x = self.detector.backbone(single_img)
                    # The adapter consumes the raw backbone features, taken
                    # before the detector neck is applied.
                    x_m = self.masa_adapter(x)
                    if self.detector.with_neck:
                        x = self.detector.neck(x)

                    img_data_sample = self.detector.predict(
                        single_img, x, [img_data_sample], rescale=rescale
                    )[0]

            frame_pred_track_instances = self.tracker.track(
                model=self,
                img=single_img,
                feats=x_m,
                data_sample=img_data_sample,
                with_segm=self.with_segm,
                **kwargs,
            )
            if self.with_segm and frame_pred_track_instances.mask_inds is not None:
                # Re-order the detection masks to follow the tracked instances.
                frame_pred_track_instances.masks = [
                    img_data_sample.pred_instances.masks[i]
                    for i in frame_pred_track_instances.mask_inds
                ]

            img_data_sample.pred_track_instances = frame_pred_track_instances

        return [track_data_sample]

    def _read_public_det_file(self, img_path):
        """Load the public-detection pickle that belongs to ``img_path``.

        The pickle path is derived by stripping the benchmark-specific image
        root from ``img_path``, swapping ``.jpg`` for ``self.end_pkl_name``
        and joining the result onto ``self.public_det_path``.

        Raises:
            ValueError: If ``img_path`` is None or ``self.benchmark`` has no
                known image root, rather than reusing a stale file name.
        """
        image_roots = {
            "bdd": "data/bdd/bdd100k/images/track/val/",
            "tao": "data/tao/frames/",
        }
        if img_path is None:
            raise ValueError(
                "img_path must be available on the data sample when "
                "load_public_dets is True."
            )
        if self.benchmark not in image_roots:
            raise ValueError(
                f"Unsupported benchmark {self.benchmark!r} for public "
                f"detections; expected one of {sorted(image_roots)}."
            )
        pickle_name = img_path.replace(image_roots[self.benchmark], "").replace(
            ".jpg", self.end_pkl_name
        )
        path = os.path.join(self.public_det_path, pickle_name)
        # NOTE(security): unpickling can execute arbitrary code; only point
        # public_det_path at detection files from a trusted source.
        with open(path, "rb") as f:
            return pickle.load(f)

    @staticmethod
    def _pack_det_results(det_bboxes, det_labels):
        """Wrap boxes and labels into an ``InstanceData`` with scores.

        ``det_bboxes`` is split into ``bboxes`` (first four columns) and
        ``scores`` (fifth column). Boxes given with four columns get a score
        of 1 appended. An empty tensor of any shape is reshaped to ``(0, 5)``
        so the column slicing below cannot fail on a frame with no detections.
        """
        if det_bboxes.numel() == 0:
            det_bboxes = det_bboxes.reshape(0, 5)
        elif det_bboxes.size(1) == 4:
            scores = torch.ones(det_bboxes.size(0), 1, device=det_bboxes.device)
            det_bboxes = torch.cat([det_bboxes, scores], dim=1)

        det_results = InstanceData()
        det_results.labels = det_labels
        det_results.bboxes = det_bboxes[:, :4]
        det_results.scores = det_bboxes[:, 4]
        return det_results

    def _extract_masa_feats(self, single_img):
        """Run the configured backbone on one frame and apply the MASA adapter.

        Used when detections come from outside the detector, so only the
        image features are needed.

        Raises:
            NotImplementedError: If neither ``unified_backbone`` nor
                ``use_masa_backbone`` is enabled.
        """
        if self.unified_backbone:
            det_backbone = self.detector.backbone
            if hasattr(det_backbone, "with_text_model"):
                x = det_backbone.forward_image(single_img)
            elif self.detector.__class__.__name__ == "SamMasa":
                x = det_backbone.forward_base_multi_level(single_img)
            else:
                x = det_backbone(single_img)
        elif self.use_masa_backbone:
            # Call the module rather than `.forward` so registered hooks run.
            x = self.backbone(single_img)
        else:
            raise NotImplementedError(
                "Either unified_backbone or use_masa_backbone must be enabled "
                "to extract features for externally provided detections."
            )
        return self.masa_adapter(x)

    def parse_tensors(self, tensor_tuple, key_ids, ref_ids):
        """Split every tensor along dim 0 into key rows and reference rows.

        Args:
            tensor_tuple (Sequence[Tensor]): Tensors to split. The index
                tensors are created on the device of the first tensor.
            key_ids (Sequence[int]): Row indices selected for the key output.
            ref_ids (Sequence[int]): Row indices selected for the reference
                output.

        Returns:
            tuple[list[Tensor], list[Tensor]]: ``(key_tensors, ref_tensors)``,
            each holding one ``index_select`` result per input tensor, in
            input order. Both lists are empty when ``tensor_tuple`` is empty.
        """
        if len(tensor_tuple) == 0:
            return [], []

        # Build the index tensors once instead of once per input tensor.
        device = tensor_tuple[0].device
        key_index = torch.tensor(key_ids, dtype=torch.long, device=device)
        ref_index = torch.tensor(ref_ids, dtype=torch.long, device=device)

        key_tensors = [tensor.index_select(0, key_index) for tensor in tensor_tuple]
        ref_tensors = [tensor.index_select(0, ref_index) for tensor in tensor_tuple]
        return key_tensors, ref_tensors

    def loss(
        self, inputs: Tensor, data_samples: TrackSampleList, **kwargs
    ) -> Union[dict, tuple]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): of shape (N, T, C, H, W) encoding
                input images. Typically these should be mean centered and std
                scaled. The N denotes batch size. The T denotes the number of
                frames and must be 2 (one key frame, one reference frame).
            data_samples (list[:obj:`TrackDataSample`]): The batch
                data samples. It usually includes information such
                as `video_data_samples`.

        Returns:
            dict: A dictionary of loss components (RPN, RoI head and track
            head losses merged together).

        Raises:
            AssertionError: If ``inputs`` is not 5D or does not hold exactly
                two frames per sample.
            NotImplementedError: If the model has no RPN head.
        """
        # modify the inputs shape to fit mmdet
        assert inputs.dim() == 5, "The img must be 5D Tensor (N, T, C, H, W)."
        assert (
            inputs.size(1) == 2
        ), "MASA can only have 1 key frame and 1 reference frame."
        # `detector` only exists as an attribute when a detector config was
        # given, so look it up defensively. It is kept in eval mode here.
        detector = getattr(self, "detector", None)
        if detector is not None:
            detector.eval()

        # split the data_samples into two aspects: key frames and reference
        # frames
        ref_data_samples, key_data_samples = [], []
        key_frame_inds, ref_frame_inds = [], []
        for track_data_sample in data_samples:
            key_frame_inds.append(track_data_sample.key_frames_inds[0])
            ref_frame_inds.append(track_data_sample.ref_frames_inds[0])
            key_data_sample = track_data_sample.get_key_frames()[0]
            # Overwrite the key-frame gt labels with zeros, so every
            # ground-truth box carries class 0 from here on.
            key_data_sample.gt_instances.labels = torch.zeros_like(
                key_data_sample.gt_instances.labels
            )
            key_data_samples.append(key_data_sample)
            ref_data_sample = track_data_sample.get_ref_frames()[0]
            ref_data_samples.append(ref_data_sample)

        # Pick the key / reference frame of every sample in the batch.
        index_device = inputs.device
        key_frame_inds = torch.tensor(
            key_frame_inds, dtype=torch.int64, device=index_device
        )
        ref_frame_inds = torch.tensor(
            ref_frame_inds, dtype=torch.int64, device=index_device
        )
        batch_inds = torch.arange(len(inputs), device=index_device)
        key_imgs = inputs[batch_inds, key_frame_inds].contiguous()
        ref_imgs = inputs[batch_inds, ref_frame_inds].contiguous()

        # Call the modules (not `.forward`) so the eval pre-hook registered
        # in `__init__` for a frozen backbone actually runs.
        if self.use_masa_backbone:
            x = self.backbone(key_imgs)
            ref_x = self.backbone(ref_imgs)
        else:
            det_backbone = self.detector.backbone
            if hasattr(det_backbone, "with_text_model"):
                x = det_backbone.forward_image(key_imgs)
                ref_x = det_backbone.forward_image(ref_imgs)
            elif self.detector.__class__.__name__ == "SamMasa":
                x = det_backbone.forward_base_multi_level(key_imgs)
                ref_x = det_backbone.forward_base_multi_level(ref_imgs)
            else:
                x = det_backbone(key_imgs)
                ref_x = det_backbone(ref_imgs)

        x_m = self.masa_adapter(x)
        ref_x_m = self.masa_adapter(ref_x)

        losses = dict()

        if not self.with_rpn:
            raise NotImplementedError("MASA  only support with_rpn for now.")

        # Only fall back to (and only dereference) test_cfg when the train
        # cfg carries no `rpn_proposal` entry.
        if "rpn_proposal" in self.train_cfg:
            proposal_cfg = self.train_cfg.get("rpn_proposal")
        else:
            proposal_cfg = self.test_cfg.rpn
        key_rpn_data_samples = copy.deepcopy(key_data_samples)
        ref_rpn_data_samples = copy.deepcopy(ref_data_samples)
        # set cat_id of gt_labels to 0 in RPN
        for data_sample in key_rpn_data_samples + ref_rpn_data_samples:
            data_sample.gt_instances.labels = torch.zeros_like(
                data_sample.gt_instances.labels
            )

        rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(
            x_m, key_rpn_data_samples, proposal_cfg=proposal_cfg, **kwargs
        )
        ref_rpn_results_list = self.rpn_head.predict(
            ref_x_m, ref_rpn_data_samples, **kwargs
        )

        # avoid get same name with roi_head loss; iterate over a snapshot of
        # the keys because the dict is modified inside the loop.
        for key in list(rpn_losses):
            if "loss" in key and "rpn" not in key:
                rpn_losses[f"rpn_{key}"] = rpn_losses.pop(key)
        losses.update(rpn_losses)

        # roi_head loss
        losses_detect = self.roi_head.loss(
            x_m, rpn_results_list, key_data_samples, **kwargs
        )
        losses.update(losses_detect)

        # tracking head loss
        losses_track = self.track_head.loss(
            x_m, ref_x_m, rpn_results_list, ref_rpn_results_list, data_samples, **kwargs
        )
        losses.update(losses_track)

        return losses