# File size: 3,413 Bytes
# 37170d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118


import numpy as np
import cv2
from typing import Tuple, List, Union
from .base_onnx import BaseONNX


class RTMDET_ONNX(BaseONNX):
    """Single-box RTMDet detector running on an ONNX session.

    The class relies on ``self.session``, ``self.input_name`` and
    ``self.input_size`` being set up by ``BaseONNX.__init__`` (not visible in
    this file). ``self.input_size`` is used as ``(width, height)``: it is
    passed as ``dsize`` to ``cv2.resize`` and unpacked in that order when
    boxes are mapped back to the original image.
    """

    # Per-channel normalisation constants, listed in the same channel order as
    # the BGR input image. Kept as float64 so the arithmetic (and therefore the
    # resulting float32 tensor) is identical to computing them inline.
    _PIXEL_MEAN_BGR = np.array([103.53, 116.28, 123.675])
    _PIXEL_STD_BGR = np.array([57.375, 57.12, 58.395])

    def __init__(self, model_path, input_size=(640, 640)):
        """Create the detector.

        Args:
            model_path: Path of the ONNX model, forwarded to ``BaseONNX``.
            input_size: Network input size as ``(width, height)``.
        """
        super().__init__(model_path, input_size)

    def preprocess_image(self, img_bgr: np.ndarray) -> np.ndarray:
        """Turn a BGR image into the network input tensor.

        Args:
            img_bgr: Image array in ``(H, W, C)`` layout with BGR channels.

        Returns:
            A ``float32`` array of shape ``(1, C, input_h, input_w)``.
        """
        # Resize straight to the network input size (aspect ratio is not kept).
        resized = cv2.resize(img_bgr, self.input_size)

        # Subtract the per-channel mean and divide by the per-channel std.
        # No additional /255 scaling is applied.
        normalized = (resized - self._PIXEL_MEAN_BGR) / self._PIXEL_STD_BGR
        normalized = normalized.astype(np.float32)

        # (H, W, C) -> (C, H, W), then prepend the batch dimension.
        chw = np.transpose(normalized, (2, 0, 1))
        return np.expand_dims(chw, axis=0)

    def run_inference(self, image: np.ndarray):
        """
        Run inference on the image.

        Args:
            image (np.ndarray): The preprocessed input tensor to run inference on.

        Returns:
            tuple: ``(dets, labels)`` exactly as returned by the session.
                dets: boxes, ``[batch, num_dets, [x1, y1, x2, y2, conf]]``.
                labels: class ids, ``[batch, num_dets]``.
        """
        outputs = self.session.run(None, {self.input_name: image})

        # The model is expected to expose exactly two outputs; unpacking
        # raises ValueError otherwise.
        dets, labels = outputs

        return dets, labels

    def pred(self, image: Union[np.ndarray, str]) -> Tuple[List[int], float]:
        """
        Predict the detection result of the image.

        Only the first detection returned by the model is used.

        Args:
            image (np.ndarray, str): A BGR image array, or the path of an
                image file to load with ``cv2.imread``.

        Returns:
           xyxy (list[int, int, int, int]): The box in original-image pixels.
           conf (float): The confidence of that box.

        Raises:
            ValueError: If ``image`` is a path that OpenCV cannot read.
            IndexError: If the model returns no detections.
        """
        if isinstance(image, str):
            img_bgr = cv2.imread(image)
            # cv2.imread signals failure by returning None instead of raising.
            if img_bgr is None:
                raise ValueError(f"Could not read image: {image!r}")
        else:
            # Work on a copy so the caller's array is never touched.
            img_bgr = image.copy()

        original_h, original_w = img_bgr.shape[0], img_bgr.shape[1]

        input_tensor = self.preprocess_image(img_bgr)
        dets, _labels = self.run_inference(input_tensor)

        first_image_dets = dets[0]
        if len(first_image_dets) == 0:
            raise IndexError("Model returned no detections")

        # Take the first box.
        # NOTE(review): presumably the model emits detections sorted by
        # descending confidence, making this the best box — confirm against
        # the exported model.
        x1, y1, x2, y2, conf = first_image_dets[0]

        xyxy = self.transform_xyxy_to_original(
            [x1, y1, x2, y2], original_w, original_h
        )

        # Return a plain Python float, as the signature promises, rather than
        # a NumPy scalar.
        return xyxy, float(conf)

    def transform_xyxy_to_original(self, xyxy, original_w, original_h) -> List[int]:
        """
        Map a box from network-input coordinates to original-image coordinates.

        Args:
            xyxy: ``[x1, y1, x2, y2]`` in the resized (network input) image.
            original_w: Width of the original image in pixels.
            original_h: Height of the original image in pixels.

        Returns:
            ``[x1, y1, x2, y2]`` scaled to the original image; each value is
            converted with ``int()``, which truncates toward zero.
        """
        x1, y1, x2, y2 = xyxy

        input_w, input_h = self.input_size
        ratio_w = original_w / input_w
        ratio_h = original_h / input_h

        return [
            int(x1 * ratio_w),
            int(y1 * ratio_h),
            int(x2 * ratio_w),
            int(y2 * ratio_h),
        ]

    def draw_pred(self, img: np.ndarray, xyxy: List[int], conf: float, is_rgb: bool = True) -> np.ndarray:
        """
        Draw the detection result on the image.

        Args:
            img: Image to draw on. When ``is_rgb`` is True it is drawn on in
                place; otherwise a converted copy is drawn on.
            xyxy: Box corners ``[x1, y1, x2, y2]`` as integers.
            conf: Confidence, rendered with two decimals at the box's
                top-left corner.
            is_rgb: If False, ``img`` is first converted with
                ``cv2.COLOR_BGR2RGB``.

        Returns:
            The image with the box and confidence text drawn.
        """
        if not is_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        x1, y1, x2, y2 = xyxy
        color = (0, 0, 255)

        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
        cv2.putText(img, f"{conf:.2f}", (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        return img