Diego Fernandez
fix: paths with centroids
57a060c
raw
history blame
2.82 kB
import os
from typing import List, Optional, Union
import numpy as np
import torch
from norfair import Detection
class YOLO:
def __init__(self, model_path: str, device: Optional[str] = None):
if device is not None and "cuda" in device and not torch.cuda.is_available():
raise Exception("Selected device='cuda', but cuda is not available to Pytorch.")
# automatically set device if its None
elif device is None:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if not os.path.exists(model_path):
os.system(
f"wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/{os.path.basename(model_path)} -O {model_path}"
)
# load model
try:
self.model = torch.hub.load("WongKinYiu/yolov7", "custom", model_path)
except:
raise Exception("Failed to load model from {}".format(model_path))
def __call__(
self,
img: Union[str, np.ndarray],
conf_threshold: float = 0.25,
iou_threshold: float = 0.45,
image_size: int = 720,
classes: Optional[List[int]] = None,
) -> torch.tensor:
self.model.conf = conf_threshold
self.model.iou = iou_threshold
if classes is not None:
self.model.classes = classes
detections = self.model(img, size=image_size)
return detections
def yolo_detections_to_norfair_detections(
yolo_detections: torch.tensor, track_points: str = "centroid" # bbox or centroid
) -> List[Detection]:
"""convert detections_as_xywh to norfair detections"""
norfair_detections: List[Detection] = []
if track_points == "centroid":
detections_as_xywh = yolo_detections.xywh[0]
for detection_as_xywh in detections_as_xywh:
centroid = np.array(
[
[detection_as_xywh[0].item(), detection_as_xywh[1].item()],
[detection_as_xywh[0].item(), detection_as_xywh[1].item()],
]
)
scores = np.array([detection_as_xywh[4].item(), detection_as_xywh[4].item()])
norfair_detections.append(Detection(points=centroid, scores=scores))
elif track_points == "bbox":
detections_as_xyxy = yolo_detections.xyxy[0]
for detection_as_xyxy in detections_as_xyxy:
bbox = np.array(
[
[detection_as_xyxy[0].item(), detection_as_xyxy[1].item()],
[detection_as_xyxy[2].item(), detection_as_xyxy[3].item()],
]
)
scores = np.array([detection_as_xyxy[4].item(), detection_as_xyxy[4].item()])
norfair_detections.append(Detection(points=bbox, scores=scores))
return norfair_detections