import os

# ByteTrack is not published on PyPI, so install it at startup;
# this must happen before the yolox import below.
os.system("pip install git+https://github.com/ifzhang/ByteTrack")

from dataclasses import dataclass
from typing import List

import gradio as gr
import numpy as np
from tqdm import tqdm

from ultralytics import YOLO
from onemetric.cv.utils.iou import box_iou_batch
from supervision import (
    BoxAnnotator,
    Color,
    Detections,
    Point,
    VideoInfo,
    VideoSink,
    draw_text,
    get_video_frames_generator,
)
from yolox.tracker.byte_tracker import BYTETracker, STrack
MODEL = "./best.pt"               # custom-trained YOLO weights
SOURCE_VIDEO_PATH = "./examples"  # directory of example videos for the Gradio demo
TARGET_VIDEO_PATH = "test.mp4"    # annotated output video
CLASS_ID = [0, 1, 2, 3]           # class ids to keep

model = YOLO(MODEL)
model.fuse()  # fuse Conv+BN layers in place for faster inference

# labels below display the raw class id; model.model.names could map ids to readable names
classes = CLASS_ID
@dataclass
class BYTETrackerArgs:
    track_thresh: float = 0.25        # confidence threshold for starting new tracks
    track_buffer: int = 30            # frames to keep lost tracks alive
    match_thresh: float = 0.8         # IoU threshold for matching detections to tracks
    aspect_ratio_thresh: float = 3.0  # filter out boxes that are too tall and narrow
    min_box_area: float = 1.0         # filter out tiny boxes
    mot20: bool = False               # MOT20 evaluation mode
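# BYTETracker reads these settings as attributes off the args object (mirroring the
# argparse namespace used by the original ByteTrack CLI), so a plain dataclass works here.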
# converts Detections into the (x1, y1, x2, y2, score) array consumed by BYTETracker.update
def detections2boxes(detections: Detections) -> np.ndarray:
    return np.hstack((
        detections.xyxy,
        detections.confidence[:, np.newaxis]
    ))
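# e.g. two detections yield a (2, 5) array of [x1, y1, x2, y2, score] rows,
# the shape BYTETracker.update expects for output_results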
# converts List[STrack] into the (x1, y1, x2, y2) array consumed by match_detections_with_tracks
def tracks2boxes(tracks: List[STrack]) -> np.ndarray:
    return np.array([track.tlbr for track in tracks], dtype=float)
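# STrack.tlbr exposes the track's current box as (x1, y1, x2, y2), matching the
# xyxy layout of Detections so the two can be compared with box_iou_batch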
# matches ByteTrack tracks with detections via IoU; returns one tracker id (or None) per detection
def match_detections_with_tracks(
    detections: Detections,
    tracks: List[STrack],
) -> List:
    if not np.any(detections.xyxy) or len(tracks) == 0:
        # nothing to match: every detection stays untracked
        return [None] * len(detections)

    tracks_boxes = tracks2boxes(tracks=tracks)
    iou = box_iou_batch(tracks_boxes, detections.xyxy)
    track2detection = np.argmax(iou, axis=1)

    tracker_ids = [None] * len(detections)
    for tracker_index, detection_index in enumerate(track2detection):
        if iou[tracker_index, detection_index] != 0:
            tracker_ids[detection_index] = tracks[tracker_index].track_id
    return tracker_ids
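# Matching is greedy: each track claims the detection it overlaps most, so if two
# tracks share the same best detection, the later track's id overwrites the earlier one.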
def ObjectDetection(video_path):
    byte_tracker = BYTETracker(BYTETrackerArgs())
    video_info = VideoInfo.from_video_path(video_path)
    generator = get_video_frames_generator(video_path)
    box_annotator = BoxAnnotator(thickness=5, text_thickness=5, text_scale=1)
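    # VideoSink writes annotated frames to TARGET_VIDEO_PATH using the source video's
    # resolution and frame rate, so the output stays in sync with the input.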
    with VideoSink(TARGET_VIDEO_PATH, video_info) as sink:
        # loop over video frames
        for frame in tqdm(generator, total=video_info.total_frames):
            results = model(frame)
            detections = Detections(
                xyxy=results[0].boxes.xyxy.cpu().numpy(),
                confidence=results[0].boxes.conf.cpu().numpy(),
                class_id=results[0].boxes.cls.cpu().numpy().astype(int)
            )
            # filter out detections with unwanted classes
            detections = detections[np.isin(detections.class_id, CLASS_ID)]
            # update the tracker with this frame's detections
            tracks = byte_tracker.update(
                output_results=detections2boxes(detections=detections),
                img_info=frame.shape,
                img_size=frame.shape
            )
            tracker_id = match_detections_with_tracks(detections=detections, tracks=tracks)
            detections.tracker_id = np.array(tracker_id)
            # filter out detections that could not be matched to a track
            detections = detections[np.not_equal(detections.tracker_id, None)]
            # format custom labels: tracker id, class id, confidence
            labels = [
                f"#{tracker_id} {classes[class_id]} {confidence:0.2f}"
                for _, _, confidence, class_id, tracker_id
                in detections
            ]
            # overlay a running per-class count, one line per class
            class_ids, counts = np.unique(detections.class_id, return_counts=True)
            for class_id, count in zip(class_ids, counts):
                frame = draw_text(
                    scene=frame,
                    text=f"{classes[class_id]} : {count}",
                    text_anchor=Point(x=50, y=300 + (50 * class_id)),
                    text_scale=2,
                    text_thickness=4,
                    background_color=Color.white(),
                )
            # annotate and write the frame
            frame = box_annotator.annotate(scene=frame, detections=detections, labels=labels)
            sink.write_frame(frame)
    return TARGET_VIDEO_PATH
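# Gradio demo: upload a video (or pick one from SOURCE_VIDEO_PATH), run detection and
# tracking, and return the annotated video; cache_examples=False skips pre-rendering
# the examples at startup.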
demo = gr.Interface(
    fn=ObjectDetection,
    inputs=gr.Video(),
    outputs=gr.Video(),
    examples=SOURCE_VIDEO_PATH,
    cache_examples=False
)
demo.launch()