File size: 3,549 Bytes
fca2efd
210ae8c
3647577
7204716
fca2efd
210ae8c
 
de42222
210ae8c
7204716
088dea4
 
fca2efd
210ae8c
fca2efd
 
 
 
0cb2d2a
 
 
 
fca2efd
 
 
9d1a8a7
7204716
088dea4
 
b17ce0d
fca2efd
210ae8c
 
 
 
 
 
 
 
fca2efd
 
 
 
30935b4
fca2efd
 
8a03fb5
de42222
 
 
 
 
7204716
 
 
 
 
 
 
 
 
 
 
 
 
 
210ae8c
fca2efd
 
 
 
 
 
 
 
9d1a8a7
7204716
9d1a8a7
fca2efd
 
7204716
 
 
 
 
fca2efd
 
 
 
 
 
 
 
 
 
7204716
8a03fb5
 
 
 
 
7204716
 
9d1a8a7
 
 
 
 
 
 
 
 
fca2efd
 
367d735
 
b17ce0d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from typing import Sequence

import numpy as np
from norfair import AbsolutePaths, Paths, Tracker, Video
from norfair.camera_motion import HomographyTransformationGetter, MotionEstimator
from norfair.distances import create_normalized_mean_euclidean_distance

from custom_models import YOLO, yolo_detections_to_norfair_detections
from demo_utils.configuration import (
    DISTANCE_THRESHOLD_BBOX,
    DISTANCE_THRESHOLD_CENTROID,
    examples,
    models_path,
    style,
)
from demo_utils.draw import center, draw


def _motion_mask(frame, detections, mask_boxes):
    """Build the mask handed to ``MotionEstimator.update`` for one frame.

    The mask is all ones, with the first two dimensions and the dtype of
    ``frame``. When ``mask_boxes`` is true, the rectangle spanned by each
    detection's first two points is zeroed. Per norfair's MotionEstimator
    documentation, zero-valued regions are not used when sampling points for
    camera-motion estimation.
    """
    mask = np.ones(frame.shape[:2], frame.dtype)
    if not mask_boxes:
        return mask
    for det in detections:
        # Assumes det.points is [[x1, y1], [x2, y2], ...] (column 0 = x).
        # Clamp at 0: a negative start index would wrap around in numpy
        # slicing and mask the wrong region (or nothing at all).
        corners = np.clip(det.points.astype(int), 0, None)
        x1, y1 = corners[0]
        x2, y2 = corners[1]
        mask[y1:y2, x1:x2] = 0
    return mask


def inference(
    input_video: str,
    model: str = "YOLOv7",
    features: Sequence[int] = (0, 1),
    track_points: str = "Bounding box",
    model_threshold: float = 0.25,
):
    """Detect and track objects in a video and write an annotated copy.

    Args:
        input_video: Path of the video to process. If it contains one of the
            keys of ``examples``, that entry overrides path fixing, the
            distance threshold and the detected classes.
        model: Key into ``models_path`` selecting the YOLO weights.
        features: Selected feature indices; only the first two entries are
            inspected. ``0`` enables camera-motion estimation, ``1`` enables
            path drawing. When both are on, absolute (camera-compensated)
            paths are drawn.
        track_points: Key into ``style``. The ``"Bounding box"`` style tracks
            with IoU distance; any other style uses Euclidean distance.
        model_threshold: Detector confidence threshold.

    Returns:
        The output file name, ``<input stem>_out.mp4``.
    """
    coord_transformations = None
    paths_drawer = None
    motion_estimator = None
    classes = None

    track_points = style[track_points]
    # Single source of truth for "are we tracking boxes?", used for the
    # distance function, the threshold and the motion mask.
    is_bbox = track_points == style["Bounding box"]
    model = YOLO(models_path[model])
    video = Video(input_path=input_video)

    selected = features[:2]
    motion_estimation = 0 in selected
    drawing_paths = 1 in selected

    if motion_estimation:
        motion_estimator = MotionEstimator(
            max_points=500,
            min_distance=7,
            transformations_getter=HomographyTransformationGetter(),
        )

    distance_function = "iou" if is_bbox else "euclidean"
    distance_threshold = (
        DISTANCE_THRESHOLD_BBOX if is_bbox else DISTANCE_THRESHOLD_CENTROID
    )

    fix_paths = motion_estimation and drawing_paths

    # Examples configuration: the first example whose key appears in the
    # input path overrides the defaults above.
    for example in examples:
        if example not in input_video:
            continue
        config = examples[example]
        # NOTE(review): this can enable fix_paths even when motion estimation
        # is off, in which case draw() receives coord_transformations=None.
        fix_paths = config["absolute_path"]
        distance_threshold = config["distance_threshold"]
        classes = config["classes"]

        print(f"Set config to {example}: {fix_paths} {distance_threshold} {classes}")
        break

    tracker = Tracker(
        distance_function=distance_function,
        distance_threshold=distance_threshold,
    )

    # Absolute paths take precedence over relative paths when both apply.
    if fix_paths:
        paths_drawer = AbsolutePaths(max_history=50, thickness=2)
    elif drawing_paths:
        paths_drawer = Paths(center, attenuation=0.01)

    for frame in video:
        yolo_detections = model(
            frame,
            conf_threshold=model_threshold,
            iou_threshold=0.45,
            image_size=720,
            classes=classes,
        )

        detections = yolo_detections_to_norfair_detections(
            yolo_detections, track_points=track_points
        )

        # Estimate this frame's camera motion *before* updating the tracker.
        # The tracker and the drawer then both use the current frame's
        # transformation, instead of the tracker lagging one frame behind.
        if motion_estimator is not None:
            mask = _motion_mask(frame, detections, mask_boxes=is_bbox)
            coord_transformations = motion_estimator.update(frame, mask)

        tracked_objects = tracker.update(
            detections=detections, coord_transformations=coord_transformations
        )

        frame = draw(
            paths_drawer,
            track_points,
            frame,
            detections,
            tracked_objects,
            coord_transformations,
            fix_paths,
        )
        video.write(frame)

    # NOTE(review): assumed to mirror the name norfair's Video derives for its
    # output file (stem before the first "." of the last "/" component);
    # confirm if Video's naming rule changes.
    base_file_name = input_video.split("/")[-1].split(".")[0]
    return base_file_name + "_out.mp4"