StableVITON / detectron2 /tracking /base_tracker.py
rlawjdghek's picture
det2 (#6)
1527335 verified
raw
history blame
2.21 kB
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
from detectron2.config import configurable
from detectron2.utils.registry import Registry
from ..config.config import CfgNode as CfgNode_
from ..structures import Instances
# Module-level registry of tracker classes. `build_tracker_head` in this file
# resolves the class named by `cfg.TRACKER_HEADS.TRACKER_NAME` through it.
TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS")
TRACKER_HEADS_REGISTRY.__doc__ = """
Registry for tracking classes.
"""
class BaseTracker:
    """
    Common parent class of all trackers.

    It only initializes the bookkeeping state shared by trackers; both
    `from_config` and `update` raise NotImplementedError here and are meant
    to be overridden by concrete subclasses.
    """

    @configurable
    def __init__(self, **kwargs):
        """
        Initialize the per-tracker bookkeeping state.

        Args:
            kwargs: accepted but not used by this base implementation.

        NOTE(review): `build_tracker_head` calls the class with a single
        positional cfg, so the `@configurable` decorator presumably routes
        that cfg through `from_config` -- confirm against detectron2.config.
        """
        # (D2) Instances of the previous frame; None until one is stored.
        self._prev_instances = None
        # Indices in _prev_instances that were found matching.
        self._matched_idx = set()
        # Identities in _prev_instances that were found matching.
        self._matched_ID = set()
        # Indices in _prev_instances that were not found matching.
        self._untracked_prev_idx = set()
        # Counter used to assign new ids.
        self._id_count = 0

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Placeholder for the config-based constructor hook.

        Args:
            cfg: D2 CfgNode.
        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Calling BaseTracker::from_config")

    def update(self, predictions: Instances) -> Instances:
        """
        Assign IDs to the predictions of the current frame.

        Args:
            predictions: D2 Instances for predictions of the current frame
        Return:
            D2 Instances for predictions of the current frame with ID assigned
        Raises:
            NotImplementedError: always, in this base class.

        _prev_instances and instances will have the following fields:
          .pred_boxes      (shape=[N, 4])
          .scores          (shape=[N,])
          .pred_classes    (shape=[N,])
          .pred_keypoints  (shape=[N, M, 3], Optional)
          .pred_masks      (shape=List[2D_MASK], Optional)
                           2D_MASK: shape=[H, W]
          .ID              (shape=[N,])

        N: # of detected bboxes
        H and W: height and width of 2D mask
        """
        raise NotImplementedError("Calling BaseTracker::update")
def build_tracker_head(cfg: CfgNode_) -> BaseTracker:
    """
    Instantiate the tracker head selected by `cfg.TRACKER_HEADS.TRACKER_NAME`.

    The name is looked up in TRACKER_HEADS_REGISTRY and the resulting class
    is called with `cfg` as its only argument.

    Args:
        cfg: D2 CfgNode, config file with tracker information
    Return:
        tracker object
    """
    tracker_cls = TRACKER_HEADS_REGISTRY.get(cfg.TRACKER_HEADS.TRACKER_NAME)
    return tracker_cls(cfg)