norfair-demo / inference.py
Diego Fernandez
fix: cache examples
0cb2d2a
raw
history blame
3.03 kB
import os
import tempfile
from typing import Sequence

import numpy as np
from norfair import AbsolutePaths, Paths, Tracker, Video
from norfair.camera_motion import HomographyTransformationGetter, MotionEstimator

from custom_models import YOLO, yolo_detections_to_norfair_detections
from demo_utils.configuration import (
    DISTANCE_THRESHOLD_BBOX,
    DISTANCE_THRESHOLD_CENTROID,
    models_path,
    style,
)
from demo_utils.distance_function import euclidean_distance, iou
from demo_utils.draw import center, draw
from demo_utils.files import get_files
def inference(
    input_video: str,
    model: str = "YOLOv7",
    features: Sequence[int] = (0, 1),
    track_points: str = "Bounding box",
    model_threshold: float = 0.25,
):
    """Run detection and tracking over a video and return the output file name.

    Args:
        input_video: Path of the video to process.
        model: Key into ``models_path`` selecting the YOLO weights to load.
        features: Indices of the enabled optional features: ``0`` turns on
            camera-motion estimation and ``1`` turns on path drawing.
        track_points: Key into ``style`` selecting how objects are tracked.
        model_threshold: Confidence threshold passed to the detector.

    Returns:
        A file name made of the input's base name followed by ``_out.mp4``.
    """
    track_points = style[track_points]
    detector = YOLO(models_path[model])
    video = Video(input_path=input_video)

    # Feature indices: 0 -> camera-motion estimation, 1 -> path drawing.
    motion_estimation = 0 in features
    drawing_paths = 1 in features
    # Paths are drawn in absolute coordinates only when camera motion is
    # also being estimated.
    fix_paths = motion_estimation and drawing_paths

    motion_estimator = None
    if motion_estimation:
        transformations_getter = HomographyTransformationGetter()
        motion_estimator = MotionEstimator(
            max_points=500,
            min_distance=7,
            transformations_getter=transformations_getter,
        )

    # Bounding-box tracking uses IoU; every other style uses Euclidean distance.
    if track_points == "bbox":
        distance_function = iou
        distance_threshold = DISTANCE_THRESHOLD_BBOX
    else:
        distance_function = euclidean_distance
        distance_threshold = DISTANCE_THRESHOLD_CENTROID

    tracker = Tracker(
        distance_function=distance_function,
        distance_threshold=distance_threshold,
    )

    # Build only the drawer that will actually be used.
    if fix_paths:
        paths_drawer = AbsolutePaths(max_history=15, thickness=2)
    elif drawing_paths:
        paths_drawer = Paths(center, attenuation=0.01)
    else:
        paths_drawer = None

    coord_transformations = None
    for frame in video:
        yolo_detections = detector(
            frame,
            conf_threshold=model_threshold,
            iou_threshold=0.45,
            image_size=720,
        )

        if motion_estimation:
            # All-ones mask: no region of the frame is excluded. Only
            # allocated when the motion estimator consumes it.
            mask = np.ones(frame.shape[:2], frame.dtype)
            coord_transformations = motion_estimator.update(frame, mask)

        detections = yolo_detections_to_norfair_detections(
            yolo_detections, track_points=track_points
        )
        tracked_objects = tracker.update(
            detections=detections, coord_transformations=coord_transformations
        )

        frame = draw(
            paths_drawer,
            track_points,
            frame,
            detections,
            tracked_objects,
            coord_transformations,
            fix_paths,
        )
        video.write(frame)

    # NOTE(review): this naming presumably has to match the file that
    # ``Video`` writes, so the derivation is kept exactly as-is -- confirm
    # against norfair before changing it.
    base_file_name = input_video.split("/")[-1].split(".")[0]
    return base_file_name + "_out.mp4"