Spaces:
Paused
Paused
#!/usr/bin/env python3 | |
# Copyright 2004-present Facebook. All Rights Reserved. | |
from detectron2.config import configurable | |
from detectron2.utils.registry import Registry | |
from ..config.config import CfgNode as CfgNode_ | |
from ..structures import Instances | |
# Registry of tracker classes; build_tracker_head looks entries up by
# cfg.TRACKER_HEADS.TRACKER_NAME.
TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS")
TRACKER_HEADS_REGISTRY.__doc__ = "\nRegistry for tracking classes.\n"
class BaseTracker:
    """
    A parent class for all trackers.

    ``__init__`` is wrapped with ``@configurable`` so that a tracker can be
    constructed either from explicit keyword arguments or from a D2 config
    (``Tracker(cfg)``). In the config case, ``from_config`` is called to produce
    the keyword arguments. Subclasses are expected to implement ``from_config``
    and ``update``; the base versions only raise ``NotImplementedError``.
    """

    @configurable
    def __init__(self, **kwargs):
        # **kwargs is accepted but not used by the base class.
        self._prev_instances = None  # (D2) Instances for the previous frame
        self._matched_idx = set()  # indices in _prev_instances found matching
        self._matched_ID = set()  # identities in _prev_instances found matching
        self._untracked_prev_idx = set()  # indices in _prev_instances not found matching
        self._id_count = 0  # counter used to assign new IDs

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Translate a D2 config into keyword arguments for ``__init__``.

        ``@configurable`` requires this to be a classmethod. It is invoked when
        the tracker is constructed with a config instead of keyword arguments.

        Args:
            cfg: D2 CfgNode with tracker settings.
        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Calling BaseTracker::from_config")

    def update(self, predictions: Instances) -> Instances:
        """
        Assign track IDs to the predictions of the current frame.

        Args:
            predictions: D2 Instances for predictions of the current frame
        Return:
            D2 Instances for predictions of the current frame with ID assigned

        _prev_instances and instances will have the following fields:
            .pred_boxes (shape=[N, 4])
            .scores (shape=[N,])
            .pred_classes (shape=[N,])
            .pred_keypoints (shape=[N, M, 3], Optional)
            .pred_masks (shape=List[2D_MASK], Optional) 2D_MASK: shape=[H, W]
            .ID (shape=[N,])

        N: # of detected bboxes
        M: # of keypoints per instance
        H and W: height and width of 2D mask

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Calling BaseTracker::update")
def build_tracker_head(cfg: CfgNode_) -> BaseTracker:
    """
    Instantiate the tracker named by ``cfg.TRACKER_HEADS.TRACKER_NAME``.

    The name is looked up in ``TRACKER_HEADS_REGISTRY`` and the resulting class
    is called with ``cfg`` as its only positional argument.

    Args:
        cfg: D2 CfgNode, config file with tracker information
    Return:
        tracker object constructed from ``cfg``
    """
    tracker_cls = TRACKER_HEADS_REGISTRY.get(cfg.TRACKER_HEADS.TRACKER_NAME)
    return tracker_cls(cfg)