|
from .tracker.byte_tracker import BYTETracker |
|
import cv2 |
|
import numpy as np |
|
|
|
class ByteTrack(object):
    """Couple an object detector with a BYTETracker to follow objects across frames.

    The detector is called once per frame; its detections are handed to the
    tracker, and the surviving tracks are filtered by area and aspect ratio
    before being returned to the caller.
    """

    # A track is discarded when its width exceeds this multiple of its height.
    _MAX_WIDTH_TO_HEIGHT = 1.6

    def __init__(self, detector, min_box_area=10):
        """Store the detector and create the tracker.

        Args:
            detector: Object exposing ``detect(image, conf_thres=, input_shape=,
                classes=)`` and a ``model`` whose ``get_inputs()[0].shape``
                ends with the network input size (presumably an ONNX Runtime
                session with an NCHW input -- TODO confirm).
            min_box_area: Tracks whose ``width * height`` is not strictly
                greater than this value are dropped.
        """
        self.min_box_area = min_box_area

        # Normalisation constants; not read inside this class but kept as
        # public attributes because other code may rely on them.
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

        self.detector = detector
        # Everything after the first two dimensions of the first model input.
        self.input_shape = tuple(detector.model.get_inputs()[0].shape[2:])
        self.tracker = BYTETracker(frame_rate=30)

    def inference(self, image, conf_thresh=0.25, classes=None):
        """Detect objects in ``image`` and update the tracker with them.

        Args:
            image: Frame forwarded unchanged to ``detector.detect``.
            conf_thresh: Confidence threshold forwarded as ``conf_thres``.
            classes: Optional class filter forwarded to the detector.

        Returns:
            Tuple ``(bboxes, ids, scores, class_ids)``. The first three lists
            describe the tracks kept by ``_tracker_update``; ``class_ids`` is
            the last column of the raw detections. All four are empty lists
            when the detector yields no detections, in which case the tracker
            is not updated for this frame.
        """
        dets, image_info = self.detector.detect(
            image,
            conf_thres=conf_thresh,
            input_shape=self.input_shape,
            classes=classes,
        )

        if not isinstance(dets, np.ndarray) or len(dets) == 0:
            return [], [], [], []

        # NOTE(review): class_ids is taken per *detection*, whereas the other
        # three lists are per *track* after filtering, so lengths and order
        # are not guaranteed to match -- verify how callers pair them up.
        class_ids = dets[:, -1].tolist()
        bboxes, ids, scores = self._tracker_update(dets, image_info)
        return bboxes, ids, scores, class_ids

    def get_id_color(self, index):
        """Return a deterministic 3-tuple colour derived from a track id."""
        seed = abs(int(index)) * 3
        return ((37 * seed) % 255, (17 * seed) % 255, (29 * seed) % 255)

    def draw_tracking_info(
        self,
        image,
        tlwhs,
        ids,
        scores,
        frame_id=0,
        elapsed_time=0.,
    ):
        """Draw each track's box and id label onto ``image`` in place.

        Args:
            image: Image passed to the OpenCV drawing calls; it is modified.
            tlwhs: Boxes as ``(left, top, width, height)`` sequences.
            ids: Track ids, paired positionally with ``tlwhs``.
            scores: Unused; kept for interface compatibility.
            frame_id: Unused; kept for interface compatibility.
            elapsed_time: Unused; kept for interface compatibility.

        Returns:
            The same ``image`` object that was passed in.
        """
        text_scale = 1.5
        text_thickness = 2
        line_thickness = 2

        for tlwh, track_id in zip(tlwhs, ids):
            x1, y1 = int(tlwh[0]), int(tlwh[1])
            x2, y2 = x1 + int(tlwh[2]), y1 + int(tlwh[3])
            cv2.rectangle(image, (x1, y1), (x2, y2),
                          self.get_id_color(track_id), line_thickness)

            label = str(track_id)
            anchor = (x1, y1 - 5)
            # Thick black pass first, then a thinner white pass on top, so the
            # label stays legible on any background.
            cv2.putText(image, label, anchor, cv2.FONT_HERSHEY_PLAIN,
                        text_scale, (0, 0, 0), text_thickness + 3)
            cv2.putText(image, label, anchor, cv2.FONT_HERSHEY_PLAIN,
                        text_scale, (255, 255, 255), text_thickness)
        return image

    def _tracker_update(self, dets, image_info):
        """Feed detections to the tracker and filter the resulting tracks.

        Args:
            dets: 2-D array of detections, or ``None``. The last column is
                stripped before the array is given to the tracker.
            image_info: Mapping with ``'height'`` and ``'width'`` keys.

        Returns:
            Tuple ``(tlwhs, ids, scores)`` of parallel lists for every track
            whose area is greater than ``min_box_area`` and whose width is at
            most ``_MAX_WIDTH_TO_HEIGHT`` times its height.
        """
        online_targets = []
        if dets is not None:
            frame_size = [image_info['height'], image_info['width']]
            # The same size is passed twice, as in the original call.
            online_targets = self.tracker.update(
                dets[:, :-1],
                frame_size,
                list(frame_size),
            )

        online_tlwhs = []
        online_ids = []
        online_scores = []
        for target in online_targets:
            tlwh = target.tlwh
            width, height = tlwh[2], tlwh[3]
            # Multiplicative form of ``width / height > ratio``: avoids a
            # division by zero when a track degenerates to zero height.
            too_wide = width > self._MAX_WIDTH_TO_HEIGHT * height
            if width * height > self.min_box_area and not too_wide:
                online_tlwhs.append(tlwh)
                online_ids.append(target.track_id)
                online_scores.append(target.score)

        return online_tlwhs, online_ids, online_scores