# productcounter / app.py
# Pedro Henrique Conrado
# 1 commit
# 8d6c487
# raw
# history blame
# 4.61 kB
import ultralytics
import onemetric
import supervision
import typing
import tqdm
import os
from ultralytics import YOLO
from dataclasses import dataclass
from onemetric.cv.utils.iou import box_iou_batch
from supervision import Point
from supervision import Detections, BoxAnnotator
from supervision import draw_text
from supervision import Color
from supervision import VideoInfo
from supervision import get_video_frames_generator
from supervision import VideoSink
import torch
from typing import List
import numpy as np
import gradio as gr
from tqdm import tqdm
import git
# Fetch YOLOX (it provides the ByteTrack tracker imported right after this
# block) and install it in editable mode.
# - The HTTPS URL is used so the clone works without SSH credentials.
# - The clone is skipped when a previous run already created ./YOLOX, because
#   cloning into an existing, non-empty directory fails.
if not os.path.isdir("./YOLOX"):
    git.Git("./").clone("https://github.com/Megvii-BaseDetection/YOLOX.git")
os.chdir("./YOLOX")
os.system("pip3 install -v -e .")
os.chdir("./..")
import yolox
from yolox.tracker.byte_tracker import BYTETracker, STrack
# --- File locations -------------------------------------------------------
MODEL = "./best.pt"               # weights file loaded into YOLO below
SOURCE_VIDEO_PATH = "./examples"  # handed to gr.Interface as its `examples`
TARGET_VIDEO_PATH = "test.mp4"    # annotated video written by ObjectDetection

# --- Classes --------------------------------------------------------------
CLASS_ID = [0, 1, 2, 3]
# Label lookup used when formatting annotations; it is the same list object
# as CLASS_ID, so each class is currently labelled with its own numeric id.
classes = CLASS_ID

# --- Detector -------------------------------------------------------------
# Loaded once at import time so every request reuses the same model instance.
# NOTE(review): fuse() is presumably ultralytics' layer-fusing inference
# optimisation — confirm against the ultralytics docs.
model = YOLO(MODEL)
model.fuse()
@dataclass(frozen=True)
class BYTETrackerArgs:
    """Immutable bundle of settings passed to ``BYTETracker`` on construction.

    NOTE(review): the attribute names are presumably the ones BYTETracker
    reads from its ``args`` object — confirm against
    ``yolox.tracker.byte_tracker`` before renaming or removing any of them.
    """

    track_thresh: float = 0.25
    track_buffer: int = 30
    match_thresh: float = 0.8
    aspect_ratio_thresh: float = 3.0
    min_box_area: float = 1.0
    mot20: bool = False
def detections2boxes(detections: Detections) -> np.ndarray:
    """Pack detections into the single array given to ``byte_tracker.update``.

    Returns:
        ``detections.xyxy`` with ``detections.confidence`` appended as one
        extra trailing column, i.e. one row per detection.
    """
    # Turn the 1-D confidence vector into a column so it can sit next to xyxy.
    confidence_column = detections.confidence[:, np.newaxis]
    return np.concatenate((detections.xyxy, confidence_column), axis=1)
def tracks2boxes(tracks: List[STrack]) -> np.ndarray:
    """Stack the ``tlbr`` box of every track into one float array.

    The result is what ``match_detections_with_tracks`` feeds to
    ``box_iou_batch``; rows follow the order of ``tracks``.
    """
    tlbr_boxes = [t.tlbr for t in tracks]
    return np.array(tlbr_boxes, dtype=float)
def match_detections_with_tracks(
    detections: Detections,
    tracks: List[STrack],
) -> list:
    """Assign a tracker id to each detection using best IoU overlap.

    For every track, the detection with the highest IoU against the track's
    box is looked up; if that IoU is non-zero the detection receives the
    track's ``track_id``. When several tracks pick the same detection, the
    last one in ``tracks`` wins.

    Args:
        detections: Detections for the current frame.
        tracks: Tracks returned by the tracker for the same frame.

    Returns:
        A list with exactly one entry per detection, in detection order:
        the matched ``track_id`` or ``None`` when no track overlaps it.
    """
    tracker_ids = [None] * len(detections)

    # Nothing to match against. Still return one entry per detection so the
    # caller can build a mask whose length equals the number of detections;
    # returning a zero-length array here made that mask mismatch whenever
    # detections existed but the tracker reported no tracks.
    if not np.any(detections.xyxy) or len(tracks) == 0:
        return tracker_ids

    tracks_boxes = tracks2boxes(tracks=tracks)
    # iou[i, j] is the overlap between track i and detection j.
    iou = box_iou_batch(tracks_boxes, detections.xyxy)
    track2detection = np.argmax(iou, axis=1)

    for tracker_index, detection_index in enumerate(track2detection):
        # argmax returns an index even when every IoU in the row is zero,
        # so reject the zero-overlap case explicitly.
        if iou[tracker_index, detection_index] != 0:
            tracker_ids[detection_index] = tracks[tracker_index].track_id
    return tracker_ids
def ObjectDetection(video_path):
    """Detect, track and count objects in a video, writing an annotated copy.

    Each frame is run through the YOLO model, detections are restricted to
    ``CLASS_ID``, associated with tracker ids, and drawn onto the frame
    together with a per-class count overlay.

    Args:
        video_path: Path of the input video (the Gradio video input).

    Returns:
        ``TARGET_VIDEO_PATH``, the file the annotated video was written to.
    """
    byte_tracker = BYTETracker(BYTETrackerArgs())
    video_info = VideoInfo.from_video_path(video_path)
    generator = get_video_frames_generator(video_path)
    box_annotator = BoxAnnotator(thickness=5, text_thickness=5, text_scale=1)

    with VideoSink(TARGET_VIDEO_PATH, video_info) as sink:
        # Loop over video frames.
        for frame in tqdm(generator, total=video_info.total_frames):
            boxes = model(frame)[0].boxes
            detections = Detections(
                xyxy=boxes.xyxy.cpu().numpy(),
                confidence=boxes.conf.cpu().numpy(),
                class_id=boxes.cls.cpu().numpy().astype(int),
            )
            # Drop detections of unwanted classes. The module-level CLASS_ID
            # is used so the filter cannot drift from the configured classes.
            detections = detections[np.isin(detections.class_id, CLASS_ID)]

            # Update the tracker with this frame's detections.
            tracks = byte_tracker.update(
                output_results=detections2boxes(detections=detections),
                img_info=frame.shape,
                img_size=frame.shape,
            )
            tracker_ids = match_detections_with_tracks(
                detections=detections, tracks=tracks
            )
            detections.tracker_id = np.array(tracker_ids)
            # Drop detections that were not matched to any track.
            detections = detections[np.not_equal(detections.tracker_id, None)]

            # One label per remaining detection: "#<track> <class> <conf>".
            labels = [
                f"#{track_id} {classes[class_id]} {confidence:0.2f}"
                for _, _, confidence, class_id, track_id
                in detections
            ]

            # Per-class counts for this frame, one text line per class; the
            # vertical position is offset by the class id so lines don't overlap.
            class_ids, counts = np.unique(detections.class_id, return_counts=True)
            for class_id, count in zip(class_ids, counts):
                frame = draw_text(
                    background_color=Color.white(),
                    scene=frame,
                    text=f"{classes[class_id]} : {count}",
                    text_anchor=Point(x=50, y=300 + (50 * class_id)),
                    text_scale=2,
                    text_thickness=4,
                )

            # Draw the boxes and labels, then append the frame to the output.
            frame = box_annotator.annotate(
                scene=frame, detections=detections, labels=labels
            )
            sink.write_frame(frame)
    return TARGET_VIDEO_PATH
# Gradio front end: one uploaded video in, the annotated video out.
# Example inputs are taken from SOURCE_VIDEO_PATH and are not pre-computed.
demo = gr.Interface(
    fn=ObjectDetection,
    inputs=gr.Video(),
    outputs=gr.Video(),
    examples=SOURCE_VIDEO_PATH,
    cache_examples=False,
)
demo.launch()