# Source: Hugging Face file view by yolo12138 — commit 37170d6 ("publish v1"), 3.41 kB.
import numpy as np
import cv2
from typing import Tuple, List, Union
from .base_onnx import BaseONNX
class RTMDET_ONNX(BaseONNX):
    """ONNX Runtime wrapper for an RTMDet detector that reports one box per image.

    Images are resized to ``input_size`` (no aspect-ratio preservation),
    normalized per BGR channel, and fed to the session as a (1, 3, H, W)
    float32 tensor. Boxes are mapped back to the original image size.
    """

    # Per-channel normalization constants in B, G, R order (the input image is
    # BGR). They appear to be the ImageNet mean/std scaled to 0-255 -- confirm
    # against the training config. Class attributes so subclasses can override.
    MEAN_BGR = np.array([103.53, 116.28, 123.675])
    STD_BGR = np.array([57.375, 57.12, 58.395])

    def __init__(self, model_path, input_size=(640, 640)):
        """Create the detector.

        Args:
            model_path: Path to the ONNX model, forwarded to ``BaseONNX``.
            input_size: Network input size as ``(width, height)``. It is passed
                to ``cv2.resize`` as ``dsize`` and unpacked as
                ``(input_w, input_h)`` when mapping boxes back.
        """
        super().__init__(model_path, input_size)

    def preprocess_image(self, img_bgr: np.ndarray) -> np.ndarray:
        """Resize, normalize and batch a BGR image for the network.

        Args:
            img_bgr: HxWx3 image in BGR channel order (e.g. from ``cv2.imread``).
                It is not modified.

        Returns:
            float32 array of shape (1, 3, input_h, input_w).
        """
        # cv2.resize takes dsize as (width, height).
        resized = cv2.resize(img_bgr, self.input_size)
        # Normalize in float64, then cast to the float32 dtype returned here.
        img = ((resized - self.MEAN_BGR) / self.STD_BGR).astype(np.float32)
        # (H, W, C) -> (C, H, W), then prepend the batch dimension.
        return np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)

    def run_inference(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session on a preprocessed batch.

        Args:
            image: Network input, as produced by ``preprocess_image``.

        Returns:
            ``(dets, labels)``, the model's two outputs in order. Expected
            shapes: ``dets`` is [batch, num_dets, 5] with rows
            ``[x1, y1, x2, y2, conf]`` in network-input pixels, and ``labels``
            is [batch, num_dets].
        """
        dets, labels = self.session.run(None, {self.input_name: image})
        return dets, labels

    def pred(self, image: Union[np.ndarray, str]) -> Tuple[List[int], float]:
        """Detect the single most confident box in an image.

        Args:
            image: A BGR image array (not modified), or a file path that is
                loaded with ``cv2.imread``.

        Returns:
            ``(xyxy, conf)``: ``xyxy`` is ``[x1, y1, x2, y2]`` in original-image
            pixels (truncated to int) and ``conf`` is the detection score.

        Raises:
            FileNotFoundError: ``image`` is a path OpenCV cannot read.
            IndexError: the model returned no detections.
        """
        if isinstance(image, str):
            img_bgr = cv2.imread(image)
            # cv2.imread signals failure by returning None instead of raising.
            if img_bgr is None:
                raise FileNotFoundError(f"Could not read image: {image}")
        else:
            img_bgr = image
        original_h, original_w = img_bgr.shape[:2]

        dets, _labels = self.run_inference(self.preprocess_image(img_bgr))
        boxes = dets[0]
        if len(boxes) == 0:
            raise IndexError("model returned no detections")
        # Select the highest-scoring box explicitly instead of trusting the
        # export to sort by score; for score-sorted output this is still row 0.
        best = int(np.argmax(boxes[:, 4]))
        x1, y1, x2, y2, conf = boxes[best]
        xyxy = self.transform_xyxy_to_original(
            [x1, y1, x2, y2], original_w, original_h
        )
        return xyxy, float(conf)

    def transform_xyxy_to_original(self, xyxy, original_w, original_h) -> List[int]:
        """Map an ``[x1, y1, x2, y2]`` box from network-input pixels to the original image.

        Each axis is scaled independently, mirroring the non-aspect-preserving
        resize in ``preprocess_image``. Results are truncated with ``int``.
        """
        x1, y1, x2, y2 = xyxy
        input_w, input_h = self.input_size
        scale_x = original_w / input_w
        scale_y = original_h / input_h
        return [
            int(x1 * scale_x),
            int(y1 * scale_y),
            int(x2 * scale_x),
            int(y2 * scale_y),
        ]

    def draw_pred(
        self, img: np.ndarray, xyxy: List[int], conf: float, is_rgb: bool = True
    ) -> np.ndarray:
        """Draw a box and its score on an image.

        Args:
            img: Image to annotate. When ``is_rgb`` is True it is drawn on in
                place. Otherwise it is first converted BGR->RGB into a new array,
                and the input is left untouched.
            xyxy: Box ``[x1, y1, x2, y2]`` in ``img`` pixels. Values are coerced
                to int because OpenCV drawing functions require integer points.
            conf: Score rendered with two decimals at the box's top-left corner.
            is_rgb: Whether ``img`` is already in RGB order.

        Returns:
            The annotated image in RGB order.
        """
        if not is_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        x1, y1, x2, y2 = (int(v) for v in xyxy)
        color = (0, 0, 255)  # blue, since the image is in RGB order here
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
        cv2.putText(img, f"{conf:.2f}", (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        return img