File size: 6,028 Bytes
37170d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import onnxruntime
import numpy as np
import cv2
from typing import Tuple, List, Union
from .base_onnx import BaseONNX


def center_crop(image: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
    """
    Center crop the image to the target size.

    Args:
        image (np.ndarray): The input image, indexed as (height, width[, channels]).
            Both 2-D (grayscale) and 3-D (multi-channel) arrays are accepted.
        target_size (Tuple[int, int]): The desired output size as (width, height),
            i.e. the same (w, h) order OpenCV uses for ``dsize``.

    Returns:
        np.ndarray: A slice (view) of the centered region. Along any axis where
        the image is smaller than the target, the full extent of that axis is
        kept instead of producing a wrapped-around crop.
    """
    h, w = image.shape[:2]
    target_w, target_h = target_size

    # Clamp to 0: a negative slice start would index from the END of the axis
    # and silently return a wrong (usually tiny) region.
    start_x = max(0, int(w // 2 - target_w // 2))
    start_y = max(0, int(h // 2 - target_h // 2))

    return image[start_y:start_y + target_h, start_x:start_x + target_w]


# Classifier class name -> one-character board symbol.
# Uppercase symbols are red pieces and lowercase symbols are black pieces;
# 'point' maps to '.' and 'other' maps to 'x'.
# Key order is significant: FULL_CLASSIFIER_ONNX builds its label list from
# the keys in insertion order, so the i-th key names model output index i.
dict_cate_names = dict(
    point='.',
    other='x',
    # Red side.
    red_king='K',
    red_advisor='A',
    red_bishop='B',
    red_knight='N',
    red_rook='R',
    red_cannon='C',
    red_pawn='P',
    # Black side.
    black_king='k',
    black_advisor='a',
    black_bishop='b',
    black_knight='n',
    black_rook='r',
    black_cannon='c',
    black_pawn='p',
)

class FULL_CLASSIFIER_ONNX(BaseONNX):
    """Classify every point of a 10x9 board image with a single ONNX model.

    ``pred`` requires the model output's trailing shape to be (90, 16): one
    score vector over the 16 classes in ``classes_labels`` for each of the 90
    points, which ``pred`` regroups row-major into 10 rows of 9 columns.
    """

    # Class name -> one-character board symbol.
    label_2_short = dict_cate_names

    # Class names in model output order: position i names output channel i.
    classes_labels = list(dict_cate_names.keys())

    def __init__(self,
                 model_path,
                 # Size the image is resized to before inference, as (w, h).
                 input_size=(280, 315),
                 # Size of the center crop taken before resizing, as (w, h).
                 crop_size=(400, 450),
                 ):
        # NOTE(review): BaseONNX presumably loads the model and provides the
        # ``session``, ``input_name`` and ``input_size`` attributes used below
        # — confirm against base_onnx.
        super().__init__(model_path, input_size)

        self.crop_size = crop_size

    def preprocess_image(self, img_bgr: cv2.UMat, is_rgb: bool = True):
        """Convert an image into the model's NCHW float32 input tensor.

        Args:
            img_bgr: Input image as an (H, W, 3) array. Despite the name, it
                is treated as RGB when ``is_rgb`` is True and as BGR otherwise.
            is_rgb: Whether ``img_bgr`` is already in RGB channel order.

        Returns:
            np.ndarray: Array of shape (1, 3, H, W) and dtype float32, where
            (W, H) is ``self.input_size``.
        """
        if not is_rgb:
            img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        else:
            img_rgb = img_bgr

        # ndarray.shape[:2] is (h, w) while crop_size is (w, h); compare in the
        # same order so images that already have the crop size skip the crop.
        h, w = img_rgb.shape[:2]
        if (w, h) != tuple(self.crop_size):
            img_rgb = center_crop(img_rgb, self.crop_size)  # crop_size is (w, h)

        # cv2.resize takes dsize as (w, h).
        img_rgb = cv2.resize(img_rgb, self.input_size)

        # Per-channel (RGB) mean/std normalization. Pixels are NOT divided by
        # 255 first, so these constants are on the raw pixel-value scale.
        img = (img_rgb - np.array([123.675, 116.28, 103.53])) / np.array([58.395, 57.12, 57.375])

        img = img.astype(np.float32)

        # (H, W, C) -> (C, H, W)
        img = np.transpose(img, (2, 0, 1))

        # Add the batch dimension: (1, C, H, W).
        img = np.expand_dims(img, axis=0)

        return img

    def run_inference(self, image: np.ndarray) -> np.ndarray:
        """
        Run the ONNX session on a preprocessed input tensor.

        Args:
            image (np.ndarray): Tensor produced by ``preprocess_image``.

        Returns:
            np.ndarray: The model's single output tensor.
        """
        # The trailing-comma unpack requires the session to return exactly
        # one output (it raises ValueError otherwise).
        outputs, = self.session.run(None, {self.input_name: image})

        return outputs

    def pred(self, image: Union[cv2.UMat, str], is_rgb: bool = True) -> Tuple[List[List[str]], List[List[str]], List[np.ndarray], str]:
        """
        Classify every board point in one image.

        Args:
            image: An image array, or a path to an image file. Images loaded
                from a path are read with ``cv2.imread`` and treated as BGR.
            is_rgb: Whether an in-memory ``image`` is in RGB channel order.
                Ignored when ``image`` is a path.

        Returns:
            A tuple ``(label_names, label_short, confidence, layout_str)``.
            The first three are 10 rows of 9 entries each: class names,
            one-character symbols, and the winning class score per point (each
            confidence row is a NumPy array). ``layout_str`` joins the symbol
            rows with newlines.

        Raises:
            FileNotFoundError: If ``image`` is a path that cannot be read.
        """
        if isinstance(image, str):
            img_bgr = cv2.imread(image)
            # cv2.imread signals failure by returning None instead of raising.
            if img_bgr is None:
                raise FileNotFoundError(f"Could not read image: {image}")
            is_rgb = False
        else:
            img_bgr = image.copy()

        image = self.preprocess_image(img_bgr, is_rgb)

        labels = self.run_inference(image)

        # Expected output: (batch, 90 board points, 16 classes).
        assert labels.shape[1:] == (90, 16), f"unexpected model output shape {labels.shape}"

        # Only the first batch item is used; shape (90, 16).
        first_batch_labels = labels[0]

        # Index of the highest-scoring class for each point; list[int].
        label_indexes = np.argmax(first_batch_labels, axis=-1).tolist()

        # Index -> class name -> one-character symbol; list[str] each.
        label_names = [self.classes_labels[index] for index in label_indexes]
        label_short = [self.label_2_short[name] for name in label_names]

        # Score of the selected class at each point, taken directly from the
        # model output (a probability only if the model ends in a softmax —
        # not visible here).
        confidence = first_batch_labels[np.arange(first_batch_labels.shape[0]), label_indexes]

        # Regroup the row-major 90 points into 10 rows of 9 columns.
        label_names_10x9 = [label_names[i * 9:(i + 1) * 9] for i in range(10)]
        label_short_10x9 = [label_short[i * 9:(i + 1) * 9] for i in range(10)]
        confidence_10x9 = [confidence[i * 9:(i + 1) * 9] for i in range(10)]

        layout_str = "\n".join("".join(row) for row in label_short_10x9)

        return label_names_10x9, label_short_10x9, confidence_10x9, layout_str

    def draw_pred(self, image: cv2.UMat, label_index: int, label_name: str, label_short: str, confidence: float) -> cv2.UMat:
        """Draw one prediction as text in the top-left corner of ``image``.

        Only ``label_short`` and ``confidence`` are rendered; ``label_index``
        and ``label_name`` are accepted but unused. ``cv2.putText`` draws in
        place, so ``image`` itself is modified and also returned.
        """
        cv2.putText(image, f"{label_short} {confidence:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

        return image

    def draw_pred_with_result(self, image: cv2.UMat, results: List[Tuple[int, str, str, float]], cells_xyxy: np.ndarray, is_rgb: bool = True) -> cv2.UMat:
        """Annotate each board cell with its predicted symbol.

        Args:
            image: Image to draw on.
            results: One ``(label_index, label_name, label_short, confidence)``
                tuple per cell, in the same order as ``cells_xyxy``.
            cells_xyxy: Array with one ``(x1, y1, x2, y2)`` box per result.
            is_rgb: Whether ``image`` is in RGB order. When False, the image
                is channel-swapped into a new array before drawing; when True,
                ``image`` itself is drawn on in place.

        Returns:
            The annotated image. Assuming ``is_rgb`` describes the input
            correctly, it is in RGB channel order in both cases.
        """
        assert len(results) == cells_xyxy.shape[0]

        if not is_rgb:
            # RGB2BGR is a plain channel swap, so this turns a BGR image into RGB.
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        for i, (label_index, label_name, label_short, confidence) in enumerate(results):
            # OpenCV drawing functions need integer coordinates.
            x1, y1, x2, y2 = map(int, cells_xyxy[i])

            # Colors are (R, G, B) because the image is RGB at this point.
            if label_name.startswith('red'):
                color = (180, 105, 255)  # red pieces: pink/violet
            elif label_name.startswith('black'):
                color = (0, 100, 50)  # black pieces: dark green
            else:
                color = (0, 0, 255)  # points and other: blue

            if confidence < 0.5:
                # Low-confidence predictions are highlighted in yellow.
                color = (255, 255, 0)

            # Append the score (two decimals) only when confidence is below 0.9.
            label_str = f"{label_short} {confidence:.2f}" if confidence < 0.9 else f"{label_short}"

            # Text is anchored near the bottom-left corner of the cell.
            cv2.putText(image, label_str, (x1 + 8, y2 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)

        return image