from .tracker.byte_tracker import BYTETracker
import cv2
import numpy as np

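class ByteTrack(object):
    """Thin wrapper that feeds detector output into a BYTETracker."""

    def __init__(self, detector, min_box_area=10):
        # Tracks whose box area (in pixels) is not above this value are dropped.
        self.min_box_area = min_box_area

        # ImageNet normalization constants (not used anywhere in this file).
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

        self.detector = detector
        # Spatial size of the model input, assuming an NCHW input layout.
        self.input_shape = tuple(detector.model.get_inputs()[0].shape[2:])
        self.tracker = BYTETracker(frame_rate=30)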

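    def inference(self, image, conf_thresh=0.25, classes=None):
        """Detect objects in `image`, update the tracker, and return
        (bboxes, ids, scores, class_ids) with bboxes in tlwh format."""
        dets, image_info = self.detector.detect(
            image,
            conf_thres=conf_thresh,
            input_shape=self.input_shape,
            classes=classes,
        )

        class_ids = []
        ids = []
        bboxes = []
        scores = []

        if isinstance(dets, np.ndarray) and len(dets) > 0:
            # NOTE: class_ids are read from the raw detections, so they are not
            # guaranteed to align index-for-index with the filtered tracks.
            class_ids = dets[:, -1].tolist()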
            bboxes, ids, scores = self._tracker_update(
                dets,
                image_info,
            )
            # image = self.draw_tracking_info(
            #     image,
            #     bboxes,
            #     ids,
            #     scores,
            # )

        # return image, len(bboxes), class_ids
        return bboxes, ids, scores, class_ids
        
    def get_id_color(self, index):
        """Map a track id to a deterministic color tuple for drawing."""
        temp_index = abs(int(index)) * 3
        color = ((37 * temp_index) % 255, (17 * temp_index) % 255,
                 (29 * temp_index) % 255)
        return color

    def draw_tracking_info(
        self,
        image,
        tlwhs,
        ids,
        scores,
        frame_id=0,
        elapsed_time=0.,
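    ):
        """Draw each track's box and id on `image` in place and return it.

        `frame_id` and `elapsed_time` are only used by the disabled overlay below.
        """
        text_scale = 1.5
        text_thickness = 2
        line_thickness = 2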

        # text = 'frame: %d ' % (frame_id)
        # text += 'elapsed time: %.0fms ' % (elapsed_time * 1000)
        # text += 'num: %d' % (len(tlwhs))
        # cv2.putText(
        #     image,
        #     text,
        #     (0, int(15 * text_scale)),
        #     cv2.FONT_HERSHEY_PLAIN,
        #     2,
        #     (0, 255, 0),
        #     thickness=text_thickness,
        # )
        
        for index, tlwh in enumerate(tlwhs):
            x1, y1 = int(tlwh[0]), int(tlwh[1])
            x2, y2 = x1 + int(tlwh[2]), y1 + int(tlwh[3])
            color = self.get_id_color(ids[index])
            cv2.rectangle(image, (x1, y1), (x2, y2), color, line_thickness)

            # Draw the id twice: a thick black outline, then white text on top.
            text = str(ids[index])
            cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN,
                        text_scale, (0, 0, 0), text_thickness + 3)
            cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN,
                        text_scale, (255, 255, 255), text_thickness)
        return image

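    def _tracker_update(self, dets, image_info):
        """Run one tracker step and return filtered (tlwhs, ids, scores)."""
        online_targets = []
        if dets is not None:
            # Drop the last column (class id) before handing boxes to the tracker.
            # The same [height, width] is passed twice, so both size arguments agree.
            online_targets = self.tracker.update(
                dets[:, :-1],
                [image_info['height'], image_info['width']],
                [image_info['height'], image_info['width']],
            )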
        online_tlwhs = []
        online_ids = []
        online_scores = []
        for online_target in online_targets:
            tlwh = online_target.tlwh
            track_id = online_target.track_id
            # width / height > 1.6 means the box is much wider than it is tall.
            too_wide = tlwh[2] / tlwh[3] > 1.6
            if tlwh[2] * tlwh[3] > self.min_box_area and not too_wide:
                online_tlwhs.append(tlwh)
                online_ids.append(track_id)
                online_scores.append(online_target.score)

        return online_tlwhs, online_ids, online_scores
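

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module.
#
# The filter applied at the end of `_tracker_update` can be isolated as a pure
# function, which makes it easy to check without a detector or a tracker.
# The name `keep_box` and the `aspect_ratio_thresh` parameter are hypothetical;
# the defaults simply mirror the values used above (area 10, ratio 1.6).
# The explicit zero-height guard is an added assumption, not original behavior.
# ---------------------------------------------------------------------------
def keep_box(tlwh, min_box_area=10, aspect_ratio_thresh=1.6):
    """Return True if a (x, y, width, height) box passes the size/shape filter."""
    width, height = float(tlwh[2]), float(tlwh[3])
    if height <= 0:
        # A degenerate box has no area, so it would be rejected anyway.
        return False
    too_wide = width / height > aspect_ratio_thresh
    return width * height > min_box_area and not too_wide


# Example: a tall 10x20 box is kept, a wide 40x10 box is rejected.
# keep_box([0, 0, 10, 20])
# keep_box([0, 0, 40, 10])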