|
|
|
|
|
from detectron2.config import configurable |
|
from detectron2.utils.registry import Registry |
|
|
|
from ..config.config import CfgNode as CfgNode_ |
|
from ..structures import Instances |
|
|
|
# Module-level registry of tracker classes. `build_tracker_head` resolves the
# configured tracker name against this registry via `.get(name)`.
TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS")
TRACKER_HEADS_REGISTRY.__doc__ = "\nRegistry for tracking classes.\n"
|
|
|
|
|
class BaseTracker:
    """
    Common parent of every tracker.

    The base class only initialises shared bookkeeping attributes. Its two
    hooks, ``from_config`` and ``update``, raise ``NotImplementedError`` here.
    """

    @configurable
    def __init__(self, **kwargs):
        """
        Reset the tracker bookkeeping to its empty state.

        Args:
            **kwargs: accepted for signature compatibility; this base class
                does not read any of them
        """
        # Nothing in this base class reads the attributes below after they
        # are set; presumably subclasses maintain them while matching
        # detections across frames (confirm against concrete trackers).

        # ID bookkeeping
        self._id_count = 0
        self._matched_ID = set()

        # index bookkeeping
        self._matched_idx = set()
        self._untracked_prev_idx = set()

        # no instances are stored at construction time
        self._prev_instances = None

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Config hook; not implemented in the base class.

        Args:
            cfg: D2 CfgNode
        Raises:
            NotImplementedError: always
        """
        message = "Calling BaseTracker::from_config"
        raise NotImplementedError(message)

    def update(self, predictions: Instances) -> Instances:
        """
        Tracking hook; not implemented in the base class.

        Args:
            predictions: D2 Instances for predictions of the current frame
        Return:
            D2 Instances for predictions of the current frame with ID assigned

        Fields carried by both _prev_instances and instances:
          .pred_boxes               (shape=[N, 4])
          .scores                   (shape=[N,])
          .pred_classes             (shape=[N,])
          .pred_keypoints           (shape=[N, M, 3], Optional)
          .pred_masks               (shape=List[2D_MASK], Optional)
                                    2D_MASK: shape=[H, W]
          .ID                       (shape=[N,])

        N: # of detected bboxes
        H and W: height and width of 2D mask

        Raises:
            NotImplementedError: always
        """
        message = "Calling BaseTracker::update"
        raise NotImplementedError(message)
|
|
|
|
|
def build_tracker_head(cfg: CfgNode_) -> BaseTracker:
    """
    Instantiate the tracker head selected by `cfg.TRACKER_HEADS.TRACKER_NAME`.

    The name is resolved through `TRACKER_HEADS_REGISTRY`, and the class it
    yields is called with `cfg` as its only argument.

    Args:
        cfg: D2 CfgNode, config file with tracker information
    Return:
        tracker object
    """
    head_cls = TRACKER_HEADS_REGISTRY.get(cfg.TRACKER_HEADS.TRACKER_NAME)
    return head_cls(cfg)
|
|