"""
Author: Siyuan Li
Licensed: Apache-2.0 License
"""
import copy
import os
import pickle
import warnings
from typing import Dict, List, Optional, Tuple, Union

import torch
from mmdet.models.mot.base import BaseMOTModel
from mmdet.registry import MODELS
from mmdet.structures import TrackSampleList
from mmdet.utils import OptConfigType, OptMultiConfig
from mmengine.structures import InstanceData
from torch import Tensor
class MASA(BaseMOTModel):
    """Matching Anything By Segmenting Anything.

    This multi object tracker is the implementation of `MASA
    https://arxiv.org/abs/2406.04221`.

    Args:
        backbone (dict, optional): Configuration of backbone. Defaults to None.
        detector (dict, optional): Configuration of detector. Defaults to None.
        masa_adapter (dict, optional): Configuration of MASA adapter.
            Defaults to None.
        rpn_head (dict, optional): Configuration of RPN head. Defaults to None.
        roi_head (dict, optional): Configuration of RoI head. Defaults to None.
        track_head (dict, optional): Configuration of track head.
            Defaults to None.
        tracker (dict, optional): Configuration of tracker. Defaults to None.
        freeze_detector (bool): If True, freeze the detector weights.
            Defaults to False.
        freeze_masa_backbone (bool): If True, freeze the MASA backbone
            weights. Defaults to False.
        freeze_masa_adapter (bool): If True, freeze the MASA adapter weights.
            Defaults to False.
        freeze_object_prior_distillation (bool): If True, freeze the RPN head
            and the RoI head. Defaults to False.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`TrackDataPreprocessor`. It usually includes,
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            Defaults to None.
        train_cfg (dict or ConfigDict, optional): Training configuration.
            Defaults to None.
        test_cfg (dict or ConfigDict, optional): Testing configuration.
            Defaults to None.
        init_cfg (dict or list[dict], optional): Configuration of
            initialization. Defaults to None.
        load_public_dets (bool): If True, load public detections from pickle
            files under ``public_det_path``. Defaults to False.
        public_det_path (str, optional): Path to public detections. Required
            if load_public_dets is True. Defaults to None.
        given_dets (bool): If True, detections are read from the
            ``det_bboxes`` / ``det_labels`` fields of each data sample.
            Defaults to False.
        with_segm (bool): If True, segmentation masks are included.
            Defaults to False.
        end_pkl_name (str): Suffix that replaces ``.jpg`` when deriving the
            pickle file name from the image path. Defaults to '.pth'.
        unified_backbone (bool): If True, the detector backbone also provides
            the features for the MASA adapter. Defaults to False.
        use_masa_backbone (bool): If True, use the standalone MASA backbone.
            Defaults to False.
        benchmark (str): Benchmark for evaluation. Only 'bdd' and 'tao' are
            handled when loading public detections. Defaults to 'tao'.
    """

    # NOTE(review): no ``@MODELS.register_module()`` decorator is visible on
    # this class -- confirm the model is registered elsewhere if it is built
    # from a config.

    def __init__(
        self,
        backbone: Optional[dict] = None,
        detector: Optional[dict] = None,
        masa_adapter: Optional[dict] = None,
        rpn_head: Optional[dict] = None,
        roi_head: Optional[dict] = None,
        track_head: Optional[dict] = None,
        tracker: Optional[dict] = None,
        freeze_detector: bool = False,
        freeze_masa_backbone: bool = False,
        freeze_masa_adapter: bool = False,
        freeze_object_prior_distillation: bool = False,
        data_preprocessor: OptConfigType = None,
        train_cfg: OptConfigType = None,
        test_cfg: OptConfigType = None,
        init_cfg: OptMultiConfig = None,
        load_public_dets: bool = False,
        public_det_path: Optional[str] = None,
        given_dets: bool = False,
        with_segm: bool = False,
        end_pkl_name: str = ".pth",
        unified_backbone: bool = False,
        use_masa_backbone: bool = False,
        benchmark: str = "tao",
    ) -> None:
        super().__init__(data_preprocessor, init_cfg)

        self.use_masa_backbone = use_masa_backbone
        if use_masa_backbone:
            assert (
                backbone is not None
            ), "backbone must be set when using MASA backbone."

        # Sub-modules are only created (and only exist as attributes) when
        # their configuration is provided.
        if backbone is not None:
            self.backbone = MODELS.build(backbone)

        if detector is not None:
            self.detector = MODELS.build(detector)

        if masa_adapter is not None:
            self.masa_adapter = MODELS.build(masa_adapter)

        if rpn_head is not None:
            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
            rpn_test_cfg = test_cfg.rpn if test_cfg is not None else None
            # Work on a copy so the caller's config is left untouched.
            rpn_head_ = rpn_head.copy()
            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=rpn_test_cfg)
            rpn_head_num_classes = rpn_head_.get("num_classes", None)
            if rpn_head_num_classes is not None and rpn_head_num_classes != 1:
                warnings.warn(
                    "The `num_classes` should be 1 in RPN, but get "
                    f"{rpn_head_num_classes}, please set "
                    "rpn_head.num_classes = 1 in your config file."
                )
            # The RPN is always built class-agnostic.
            rpn_head_.update(num_classes=1)
            self.rpn_head = MODELS.build(rpn_head_)

        if roi_head is not None:
            # update train and test cfg here for now
            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
            rcnn_test_cfg = test_cfg.rcnn if test_cfg is not None else None
            # Copy first, mirroring the RPN handling, instead of mutating the
            # caller's config in place.
            roi_head_ = roi_head.copy()
            roi_head_.update(train_cfg=rcnn_train_cfg, test_cfg=rcnn_test_cfg)
            self.roi_head = MODELS.build(roi_head_)

        if track_head is not None:
            self.track_head = MODELS.build(track_head)

        if tracker is not None:
            self.tracker = MODELS.build(tracker)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.freeze_detector = freeze_detector
        self.freeze_masa_adapter = freeze_masa_adapter
        self.freeze_object_prior_distillation = freeze_object_prior_distillation
        self.freeze_masa_backbone = freeze_masa_backbone

        def set_to_eval(module, _inputs):
            # Forward pre-hook: switch the module to eval mode right before
            # each forward call made through ``module(...)``.
            module.eval()

        if self.freeze_detector:
            assert (
                detector is not None
            ), "detector must be set when freeze_detector is True."
            self.freeze_module("detector")
            # self.detector.backbone.register_forward_pre_hook(set_to_eval)

        if self.freeze_masa_adapter:
            assert (
                masa_adapter is not None
            ), "masa_adapter must be set when freeze_masa_adapter is True."
            self.freeze_module("masa_adapter")
            self.masa_adapter.register_forward_pre_hook(set_to_eval)

        if self.freeze_object_prior_distillation:
            assert (
                roi_head is not None
            ), "roi_head must be set when freeze_object_prior_distillation is True."
            assert (
                rpn_head is not None
            ), "rpn_head must be set when freeze_object_prior_distillation is True."
            self.freeze_module("roi_head")
            self.freeze_module("rpn_head")

        if self.freeze_masa_backbone:
            assert (
                backbone is not None
            ), "backbone must be set when freeze_masa_backbone is True."
            self.freeze_module("backbone")
            self.backbone.register_forward_pre_hook(set_to_eval)

        if load_public_dets:
            assert (
                public_det_path is not None
            ), "load_public_dets and public_det_path must be set together."

        self.benchmark = benchmark
        self.load_public_dets = load_public_dets
        self.public_det_path = public_det_path
        self.with_segm = with_segm
        self.end_pkl_name = end_pkl_name
        self.given_dets = given_dets
        self.unified_backbone = unified_backbone

    # ``loss`` reads ``self.with_rpn`` as an attribute, so these must be
    # properties; as plain methods the bound method is always truthy.
    @property
    def with_rpn(self) -> bool:
        """bool: whether the detector has RPN"""
        return hasattr(self, "rpn_head") and self.rpn_head is not None

    @property
    def with_roi_head(self) -> bool:
        """bool: whether the detector has a RoI head"""
        return hasattr(self, "roi_head") and self.roi_head is not None

    def _read_public_dets(self, img_name: Optional[str]) -> dict:
        """Load the pickled public detections that belong to one image.

        The pickle file name is the image path with the benchmark-specific
        prefix removed and ``.jpg`` replaced by ``self.end_pkl_name``; it is
        looked up under ``self.public_det_path``.

        Args:
            img_name (str, optional): Image path stored on the data sample.

        Returns:
            The unpickled object; ``predict`` reads its ``det_labels``,
            ``det_bboxes`` and (with segmentation) ``det_masks`` entries.

        Raises:
            ValueError: If ``img_name`` is None or ``self.benchmark`` is not
                'bdd' or 'tao'.
        """
        if img_name is None:
            raise ValueError(
                "img_path must be set on the data sample when "
                "load_public_dets is True."
            )

        if self.benchmark == "bdd":
            pickle_name = img_name.replace(
                "data/bdd/bdd100k/images/track/val/", ""
            ).replace(".jpg", self.end_pkl_name)
        elif self.benchmark == "tao":
            pickle_name = img_name.replace("data/tao/frames/", "").replace(
                ".jpg", self.end_pkl_name
            )
        else:
            raise ValueError(
                f"Unsupported benchmark {self.benchmark!r} for public "
                "detections; expected 'bdd' or 'tao'."
            )

        path = os.path.join(self.public_det_path, pickle_name)
        # NOTE(security): unpickling can execute arbitrary code. Only point
        # ``public_det_path`` at detection files from a trusted source.
        with open(path, "rb") as det_file:
            return pickle.load(det_file)

    @staticmethod
    def _build_det_results(det_bboxes: Tensor, det_labels: Tensor) -> InstanceData:
        """Pack externally provided detections into an ``InstanceData``.

        Boxes with only four columns get a fifth column of ones that is used
        as the score. Empty detections are reshaped to ``(0, 5)`` so that the
        column slicing below does not fail.

        Args:
            det_bboxes (Tensor): Boxes with four or five columns; the fifth
                column, when present, is used as the score.
            det_labels (Tensor): Labels stored unchanged on the result.

        Returns:
            InstanceData: Holds ``labels``, ``bboxes`` and ``scores``.
        """
        if len(det_bboxes) == 0:
            det_bboxes = det_bboxes.reshape(0, 5)
        elif det_bboxes.size(1) == 4:
            confidence = torch.ones(
                det_bboxes.size(0), 1, device=det_bboxes.device
            )
            det_bboxes = torch.cat([det_bboxes, confidence], dim=1)

        det_results = InstanceData()
        det_results.labels = det_labels
        det_results.bboxes = det_bboxes[:, :4]
        det_results.scores = det_bboxes[:, 4]
        return det_results

    def _extract_masa_feats(self, single_img: Tensor):
        """Compute MASA adapter features for a frame with external detections.

        Args:
            single_img (Tensor): The frame passed to the backbone.

        Returns:
            The output of ``self.masa_adapter`` on the backbone features.

        Raises:
            NotImplementedError: If neither ``unified_backbone`` nor
                ``use_masa_backbone`` is enabled.
        """
        if self.unified_backbone:
            detector_backbone = self.detector.backbone
            if hasattr(detector_backbone, "with_text_model"):
                x = detector_backbone.forward_image(single_img)
            elif self.detector.__class__.__name__ == "SamMasa":
                x = detector_backbone.forward_base_multi_level(single_img)
            else:
                x = detector_backbone(single_img)
        elif self.use_masa_backbone:
            # Call the module (not ``.forward``) so that registered forward
            # pre-hooks, e.g. the eval-mode hook of a frozen backbone, run.
            x = self.backbone(single_img)
        else:
            raise NotImplementedError(
                "Either unified_backbone or use_masa_backbone must be "
                "enabled when detections are provided externally."
            )
        return self.masa_adapter(x)

    def predict(
        self,
        inputs: Tensor,
        data_samples: TrackSampleList,
        rescale: bool = True,
        **kwargs,
    ) -> TrackSampleList:
        """Predict results from a video and data samples with post- processing.

        Detections come from one of three sources, checked in this order:
        pickled public detections (``load_public_dets``), detections stored
        on the data sample (``given_dets``), or the detector itself.

        Args:
            inputs (Tensor): of shape (N, T, C, H, W) encoding
                input images. The N denotes batch size.
                The T denotes the number of frames in a video.
            data_samples (list[:obj:`TrackDataSample`]): The batch
                data samples. It usually includes information such
                as `video_data_samples`.
            rescale (bool, Optional): If False, then returned bboxes and masks
                will fit the scale of img, otherwise, returned bboxes and masks
                will fit the scale of original image shape. Defaults to True.

        Returns:
            TrackSampleList: Tracking results of the inputs.

        Raises:
            ValueError: If public detections are requested but the image path
                is missing or the benchmark is unsupported.
            NotImplementedError: If no supported backbone configuration is
                enabled for the selected detection source.
        """
        assert inputs.dim() == 5, "The img must be 5D Tensor (N, T, C, H, W)."
        assert (
            inputs.size(0) == 1
        ), "MASA inference only support 1 batch size per gpu for now."
        assert len(data_samples) == 1, "MASA only support 1 batch size per gpu for now."

        track_data_sample = data_samples[0]
        video_len = len(track_data_sample)
        # A frame id of 0 marks the start of a new video.
        if track_data_sample[0].frame_id == 0:
            self.tracker.reset()

        for frame_id in range(video_len):
            img_data_sample = track_data_sample[frame_id]
            single_img = inputs[:, frame_id].contiguous()

            if self.load_public_dets:
                pickle_res = self._read_public_dets(img_data_sample.img_path)
                # Keep the detections on the same device as the frame rather
                # than hard-coding "cuda".
                device = single_img.device
                det_labels = torch.tensor(pickle_res["det_labels"]).to(device)
                det_bboxes = (
                    torch.tensor(pickle_res["det_bboxes"]).to(device).to(torch.float32)
                )
                det_results = self._build_det_results(det_bboxes, det_labels)
                if self.with_segm:
                    det_results.masks = pickle_res["det_masks"]
                img_data_sample.pred_instances = det_results
                x_m = self._extract_masa_feats(single_img)

            elif self.given_dets:
                assert (
                    "det_bboxes" in img_data_sample
                ), "det_bboxes must be given when given_dets is True."
                assert (
                    "det_labels" in img_data_sample
                ), "det_labels must be given when given_dets is True."
                img_data_sample.pred_instances = self._build_det_results(
                    img_data_sample.det_bboxes, img_data_sample.det_labels
                )
                x_m = self._extract_masa_feats(single_img)

            else:
                if not self.unified_backbone:
                    raise NotImplementedError

                if hasattr(self.detector.backbone, "with_text_model"):
                    texts = img_data_sample.texts
                    ## fix some inconsistency caused by the implementation of yolo-world and mmdet
                    if isinstance(texts[0], list):
                        new_texts = [text[0] for text in texts]
                        del img_data_sample.texts
                        img_data_sample.set_field(
                            new_texts, "texts", field_type="metainfo"
                        )
                    (
                        backbone_feats,
                        img_feats,
                        text_feats,
                    ) = self.detector.extract_feat(single_img, [img_data_sample])
                    x_m = self.masa_adapter(backbone_feats)
                    img_data_sample = self.detector.predict(
                        single_img,
                        (img_feats, text_feats),
                        [img_data_sample],
                        rescale=rescale,
                    )[0]
                else:
                    x = self.detector.backbone(single_img)
                    # The adapter consumes the raw backbone features; the
                    # neck is applied afterwards for the detector only.
                    x_m = self.masa_adapter(x)
                    if self.detector.with_neck:
                        x = self.detector.neck(x)
                    img_data_sample = self.detector.predict(
                        single_img, x, [img_data_sample], rescale=rescale
                    )[0]

            frame_pred_track_instances = self.tracker.track(
                model=self,
                img=single_img,
                feats=x_m,
                data_sample=img_data_sample,
                with_segm=self.with_segm,
                **kwargs,
            )

            if self.with_segm:
                if frame_pred_track_instances.mask_inds is not None:
                    # Pick the masks of the detections selected by the tracker.
                    frame_pred_track_instances.masks = [
                        img_data_sample.pred_instances.masks[i]
                        for i in frame_pred_track_instances.mask_inds
                    ]

            img_data_sample.pred_track_instances = frame_pred_track_instances

        return [track_data_sample]

    def parse_tensors(
        self,
        tensor_tuple: Tuple[Tensor, ...],
        key_ids: List[int],
        ref_ids: List[int],
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Split every tensor into key rows and reference rows.

        Args:
            tensor_tuple (tuple[Tensor]): Tensors indexed along dimension 0.
            key_ids (list[int]): Row indices selected as key entries.
            ref_ids (list[int]): Row indices selected as reference entries.

        Returns:
            tuple[list[Tensor], list[Tensor]]: Key tensors and reference
            tensors, one per input tensor. Two empty lists are returned for
            an empty ``tensor_tuple``.
        """
        if len(tensor_tuple) == 0:
            return [], []

        # The index tensors are identical for every input, so build them once
        # on the device of the first tensor.
        device = tensor_tuple[0].device
        key_index = torch.tensor(key_ids, dtype=torch.long, device=device)
        ref_index = torch.tensor(ref_ids, dtype=torch.long, device=device)

        key_tensors = [tensor.index_select(0, key_index) for tensor in tensor_tuple]
        ref_tensors = [tensor.index_select(0, ref_index) for tensor in tensor_tuple]
        return key_tensors, ref_tensors

    def loss(
        self, inputs: Tensor, data_samples: TrackSampleList, **kwargs
    ) -> Union[dict, tuple]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): of shape (N, T, C, H, W) encoding
                input images. Typically these should be mean centered and std
                scaled. The N denotes batch size. The T denotes the number of
                frames.
            data_samples (list[:obj:`TrackDataSample`]): The batch
                data samples. It usually includes information such
                as `video_data_samples`.

        Returns:
            dict: A dictionary of loss components.

        Raises:
            NotImplementedError: If the model has no RPN head.
        """
        # modify the inputs shape to fit mmdet
        assert inputs.dim() == 5, "The img must be 5D Tensor (N, T, C, H, W)."
        assert (
            inputs.size(1) == 2
        ), "MASA can only have 1 key frame and 1 reference frame."

        # ``self.detector`` only exists when a detector config was given, so
        # look it up defensively.
        detector = getattr(self, "detector", None)
        if detector is not None:
            detector.eval()

        # split the data_samples into two aspects: key frames and reference
        # frames
        ref_data_samples, key_data_samples = [], []
        key_frame_inds, ref_frame_inds = [], []

        # set cat_id of gt_labels to 0 in RPN
        for track_data_sample in data_samples:
            key_frame_inds.append(track_data_sample.key_frames_inds[0])
            ref_frame_inds.append(track_data_sample.ref_frames_inds[0])
            key_data_sample = track_data_sample.get_key_frames()[0]
            key_data_sample.gt_instances.labels = torch.zeros_like(
                key_data_sample.gt_instances.labels
            )
            key_data_samples.append(key_data_sample)
            ref_data_sample = track_data_sample.get_ref_frames()[0]
            ref_data_samples.append(ref_data_sample)

        key_frame_inds = torch.tensor(key_frame_inds, dtype=torch.int64)
        ref_frame_inds = torch.tensor(ref_frame_inds, dtype=torch.int64)
        batch_inds = torch.arange(len(inputs))
        key_imgs = inputs[batch_inds, key_frame_inds].contiguous()
        ref_imgs = inputs[batch_inds, ref_frame_inds].contiguous()

        if self.use_masa_backbone:
            # Call the module (not ``.forward``) so the eval-mode forward
            # pre-hook registered for a frozen backbone actually runs.
            x = self.backbone(key_imgs)
            ref_x = self.backbone(ref_imgs)
        elif hasattr(self.detector.backbone, "with_text_model"):
            x = self.detector.backbone.forward_image(key_imgs)
            ref_x = self.detector.backbone.forward_image(ref_imgs)
        elif self.detector.__class__.__name__ == "SamMasa":
            x = self.detector.backbone.forward_base_multi_level(key_imgs)
            ref_x = self.detector.backbone.forward_base_multi_level(ref_imgs)
        else:
            x = self.detector.backbone(key_imgs)
            ref_x = self.detector.backbone(ref_imgs)

        x_m = self.masa_adapter(x)
        ref_x_m = self.masa_adapter(ref_x)

        losses = dict()

        if not self.with_rpn:
            raise NotImplementedError("MASA only support with_rpn for now.")

        proposal_cfg = self.train_cfg.get("rpn_proposal", self.test_cfg.rpn)
        key_rpn_data_samples = copy.deepcopy(key_data_samples)
        ref_rpn_data_samples = copy.deepcopy(ref_data_samples)
        # set cat_id of gt_labels to 0 in RPN
        for data_sample in key_rpn_data_samples:
            data_sample.gt_instances.labels = torch.zeros_like(
                data_sample.gt_instances.labels
            )
        for data_sample in ref_rpn_data_samples:
            data_sample.gt_instances.labels = torch.zeros_like(
                data_sample.gt_instances.labels
            )

        rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(
            x_m, key_rpn_data_samples, proposal_cfg=proposal_cfg, **kwargs
        )
        ref_rpn_results_list = self.rpn_head.predict(
            ref_x_m, ref_rpn_data_samples, **kwargs
        )

        # avoid get same name with roi_head loss
        # Iterate over a snapshot of the keys: the dict is modified in the
        # loop, which is not allowed while iterating the live key view.
        for key in list(rpn_losses):
            if "loss" in key and "rpn" not in key:
                rpn_losses[f"rpn_{key}"] = rpn_losses.pop(key)
        losses.update(rpn_losses)

        # roi_head loss
        losses_detect = self.roi_head.loss(
            x_m, rpn_results_list, key_data_samples, **kwargs
        )
        losses.update(losses_detect)

        # tracking head loss
        losses_track = self.track_head.loss(
            x_m, ref_x_m, rpn_results_list, ref_rpn_results_list, data_samples, **kwargs
        )
        losses.update(losses_track)

        return losses