yolo12138 committed
Commit 37170d6 · 1 Parent(s): bb29eec

publish v1
.gitignore ADDED
@@ -0,0 +1,3 @@
+__pycache__/
+*.pyc
+*.DS_Store
app.py ADDED
@@ -0,0 +1,107 @@
+import gradio as gr
+from core.chessboard_detector import ChessboardDetector
+
+detector = ChessboardDetector(
+    det_model_path="onnx/det/v1.onnx",
+    pose_model_path="onnx/pose/v1.onnx",
+    full_classifier_model_path="onnx/layout_recognition/v1.onnx"
+)
+
+
+# Mapping from Chinese piece names to FEN-style short labels
+dict_cate_names = {
+    '.': '.',
+    'x': 'x',
+    '红帅': 'K',
+    '红士': 'A',
+    '红相': 'B',
+    '红马': 'N',
+    '红车': 'R',
+    '红炮': 'C',
+    '红兵': 'P',
+
+    '黑将': 'k',
+    '黑仕': 'a',
+    '黑象': 'b',
+    '黑傌': 'n',
+    '黑車': 'r',
+    '黑砲': 'c',
+    '黑卒': 'p',
+}
+
+dict_cate_names_reverse = {v: k for k, v in dict_cate_names.items()}
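+
+# Note: the layout string coming back from the classifier uses these short
+# labels, one rank per line (e.g. "rnbakabnr" for an untouched black back rank,
+# "........." for an empty rank); the reverse dict maps them back to Chinese
+# names for the results table.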
+
+
+with gr.Blocks(
+    css="""
+    .image {
+        max-height: 512px;
+    }
+    """
+) as demo:
+    gr.Markdown("""
+    ## 棋盘检测, 棋子识别
+
+    步骤:
+    1. 流程分成两步,第一步检测边缘
+    2. 对整个棋盘画面进行棋子分类预测
+    """
+    )
+
+    with gr.Row():
+        with gr.Column():
+            image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")
+
+        with gr.Column():
+            original_image_with_keypoints = gr.Image(
+                label="step1: 原图带关键点",
+                interactive=False,
+                visible=True,
+                elem_classes="image"
+            )
+
+    with gr.Row():
+        with gr.Column():
+            transformed_image = gr.Image(
+                label="step2: 拉伸棋盘",
+                interactive=False,
+                visible=True,
+                elem_classes="image"
+            )
+
+        with gr.Column():
+            use_time = gr.Textbox(
+                label="用时",
+                interactive=False,
+                visible=True,
+            )
+            layout_pred_info = gr.Dataframe(
+                label="棋子识别",
+                interactive=False,
+                visible=True,
+            )
+
+    def detect_chessboard(image):
+        # Clearing the image fires change() with None; skip inference then
+        if image is None:
+            return None, None, None, None
+
+        original_image_with_keypoints, transformed_image, cells_labels_str, scores, time_info = detector.pred_detect_board_and_classifier(image)
+
+        # The layout string holds one rank per line; split it into 10 rows
+        annotation_10_rows = cells_labels_str.split("\n")
+        # Expand each row string into a list of 9 single-character labels (10 x 9 grid)
+        annotation_arr_10_9 = [list(item) for item in annotation_10_rows]
+
+        # Map the short labels back to Chinese piece names for display
+        annotation_arr_10_9 = [[dict_cate_names_reverse[item] for item in row] for row in annotation_arr_10_9]
+
+        return original_image_with_keypoints, transformed_image, annotation_arr_10_9, time_info
+
+    image_input.change(fn=detect_chessboard,
+                       inputs=[image_input],
+                       outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time])
+
+if __name__ == "__main__":
+    demo.launch()
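+
+# Headless usage sketch (no Gradio), assuming a local test photo; the path is
+# hypothetical:
+#   import cv2
+#   rgb = cv2.cvtColor(cv2.imread("board_photo.jpg"), cv2.COLOR_BGR2RGB)
+#   _, _, layout_str, scores, t = detector.pred_detect_board_and_classifier(rgb)
+#   print(layout_str)  # 10 lines of 9 short labels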
core/chessboard_detector.py ADDED
@@ -0,0 +1,137 @@
+
+import time
+import numpy as np
+import cv2
+from typing import List, Tuple, Union
+from pandas import DataFrame
+from .runonnx.rtmdet import RTMDET_ONNX
+from .runonnx.rtmpose import RTMPOSE_ONNX
+from .runonnx.full_classifier import FULL_CLASSIFIER_ONNX
+
+from .helper_34 import extract_chessboard
+
+class ChessboardDetector:
+    def __init__(self,
+                 det_model_path: str,
+                 pose_model_path: str,
+                 full_classifier_model_path: Union[str, None] = None
+                 ):
+
+        self.det = RTMDET_ONNX(
+            model_path=det_model_path,
+        )
+
+        self.pose = RTMPOSE_ONNX(
+            model_path=pose_model_path,
+        )
+
+        if full_classifier_model_path is not None:
+            self.full_classifier = FULL_CLASSIFIER_ONNX(
+                model_path=full_classifier_model_path,
+            )
+
+        self.board_positions = []  # board grid coordinates
+        self.current_image = None
+        self.current_filename = None
+
+    # Detect the Chinese chess board and its keypoints
+    def pred_detect_and_keypoints(self, image_bgr: Union[np.ndarray, None] = None) -> Tuple[List[int], float, List[List[int]], List[float]]:
+
+        xyxy, conf = self.det.pred(image_bgr)
+
+        # Predict the keypoints inside the detected bbox
+        keypoints, scores = self.pose.pred(image=image_bgr, bbox=xyxy)
+
+        return xyxy, conf, keypoints, scores
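+
+    # Pipeline note: RTMDet proposes the board bounding box, then RTMPose
+    # regresses the 34 grid keypoints inside that box (a top-down pose setup).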
+
+    def draw_pred_with_keypoints(self, image_rgb: Union[np.ndarray, None] = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+        if image_rgb is None:
+            return None, None, None, None
+
+        image_rgb = image_rgb.copy()
+
+        original_image = image_rgb.copy()
+
+        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
+
+        xyxy, conf, keypoints, scores = self.pred_detect_and_keypoints(image_bgr)
+
+        # Draw the board bounding box
+        draw_image = self.det.draw_pred(image_rgb, xyxy, conf)
+
+        # Draw the keypoints
+        draw_image = self.pose.draw_pred(img=draw_image, keypoints=keypoints, scores=scores)
+
+        # Pair self.pose.bone_names with the keypoints, then build a DataFrame
+        keypoint_list = []
+        for bone_name, keypoint in zip(self.pose.bone_names, keypoints):
+            keypoint_list.append({"name": bone_name, "x": keypoint[0], "y": keypoint[1]})
+
+        keypoint_df = DataFrame(keypoint_list)
+
+        return draw_image, original_image, [xyxy], keypoint_df
+
+    # Rectify the board via perspective transform, then classify the layout
+    def extract_chessboard_and_classifier_layout(self,
+                                                 image_rgb: Union[np.ndarray, None] = None,
+                                                 keypoints: Union[np.ndarray, None] = None
+                                                 ) -> Tuple[np.ndarray, List[List[str]], List[List[float]]]:
+
+        # Extract the board and the extent of each grid position
+        transformed_image, _transformed_keypoints, _corner_points = extract_chessboard(img=image_rgb, keypoints=keypoints)
+
+        transformed_image_copy = transformed_image.copy()
+
+        # Predict the piece class at each of the 90 positions
+        _, _, scores, pred_result = self.full_classifier.pred(transformed_image_copy, is_rgb=True)
+
+        return transformed_image, pred_result, scores
+
+    # Detect the board, then classify every position
+    def pred_detect_board_and_classifier(self,
+                                         image_rgb: Union[np.ndarray, None] = None,
+                                         ) -> Tuple[np.ndarray, np.ndarray, str, List[List[float]], str]:
+        """
+        @param image_rgb: input RGB image
+        @return:
+            - original_image_with_keypoints  # original image with keypoints drawn
+            - transformed_image              # rectified (perspective-corrected) board
+            - layout_pred_info               # piece class per position, one rank per line
+            - scores                         # confidence per position
+            - time_info                      # inference time
+        """
+
+        if image_rgb is None:
+            return None, None, [], [], ""
+
+        image_rgb_for_extract = image_rgb.copy()
+
+        start_time = time.time()
+
+        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
+
+        xyxy, conf, keypoints, scores = self.pred_detect_and_keypoints(image_bgr)
+
+        # Draw the board bounding box
+        draw_image = self.det.draw_pred(image_rgb, xyxy, conf)
+
+        # Draw the keypoints on the original image
+        original_image_with_keypoints = self.pose.draw_pred(img=draw_image, keypoints=keypoints, scores=scores)
+
+        transformed_image, cells_labels, scores = self.extract_chessboard_and_classifier_layout(image_rgb=image_rgb_for_extract, keypoints=keypoints)
+
+        use_time = time.time() - start_time
+
+        time_info = f"推理用时: {use_time:.2f}s"
+
+        return original_image_with_keypoints, transformed_image, cells_labels, scores, time_info
core/helper_34.py ADDED
@@ -0,0 +1,316 @@
+import numpy as np
+import cv2
+from typing import Tuple, List
+
+
+BONE_NAMES = [
+    "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8",
+    "J0", "J1", "J2", "J3", "J4", "J5", "J6", "J7", "J8",
+    "B0", "C0", "D0", "E0", "F0", "G0", "H0", "I0",
+    "B8", "C8", "D8", "E8", "F8", "G8", "H8", "I8",
+]
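+
+# The 34 annotated keypoints are the outer frame of the 10 x 9 grid: all of
+# rows A and J, plus columns 0 and 8 of rows B-I. Interior intersections are
+# reconstructed geometrically in build_cells_xywh below.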
+
+def check_keypoints(keypoints: np.ndarray):
+    """
+    Validate the keypoint array
+    @param keypoints: keypoint coordinates, shape (34, 2)
+    """
+    if keypoints.shape != (34, 2):
+        raise Exception(f"keypoints shape error: {keypoints.shape}")
+
+
+def build_cells_xywh_by_corners(corner_points: np.ndarray, padding: int = 3) -> np.ndarray:
+    """
+    Compute the xywh of every position from the four board corner points
+    @param corner_points: board corner coordinates, shape (4, 2)
+    @param padding: border padding per cell
+
+    @return: board cells as xywh, shape (10, 9, 4); the 4 values are center_x, center_y, w, h
+    """
+
+    if corner_points.shape != (4, 2):
+        raise Exception(f"corner_points shape error: {corner_points.shape}")
+
+    top_left_xy = corner_points[0]
+    top_right_xy = corner_points[1]
+    bottom_left_xy = corner_points[2]
+    bottom_right_xy = corner_points[3]
+
+    # Width and height of a single cell
+    item_w = (top_right_xy[0] - top_left_xy[0]) / (9 - 1)
+    item_h = (bottom_left_xy[1] - top_left_xy[1]) / (10 - 1)
+
+    item_w_with_padding = item_w - padding * 2
+    item_h_with_padding = item_h - padding * 2
+
+    # Center coordinates of every cell
+    cells_xywh = np.zeros((10, 9, 4))
+
+    for i in range(10):
+        for j in range(9):
+            center_x = top_left_xy[0] + item_w * j
+            center_y = top_left_xy[1] + item_h * i
+
+            cells_xywh[i, j] = [center_x, center_y, item_w_with_padding, item_h_with_padding]
+
+    return cells_xywh
+
+
+# todo: needs refinement
+def build_cells_xywh(keypoints: np.ndarray, width: int = 450, height: int = 500, padding: int = 3) -> np.ndarray:
+    """
+    @param keypoints: keypoint coordinates, shape (34, 2)
+    @param width: board width
+    @param height: board height
+    @param padding: border padding per cell
+    @return: board cells as xywh, shape (10, 9, 4); the 4 values are center_x, center_y, w, h
+    """
+    check_keypoints(keypoints)
+
+    # Generate the coordinates of A0 through J8; e.g. B1 is the intersection
+    # of line A1-J1 with line B0-B8
+    cells_xywh = np.zeros((10, 9, 4), dtype=np.int16)
+
+    # Fill in the center coordinate of every grid point
+    for i in range(10):
+        for j in range(9):
+            row_name = chr(ord('A') + i)
+            col_name = str(j)
+            flag_name = f"{row_name}{col_name}"
+            if flag_name in BONE_NAMES:
+                # Annotated frame point: use its keypoint directly
+                cur_xy = keypoints[BONE_NAMES.index(flag_name)]
+                cells_xywh[i, j] = [cur_xy[0], cur_xy[1], 0, 0]
+            else:
+                # Interior point: intersect its row line with its column line
+                row_start_name = f"{row_name}0"
+                row_end_name = f"{row_name}8"
+
+                col_start_name = f"A{col_name}"
+                col_end_name = f"J{col_name}"
+
+                row_start_xy = keypoints[BONE_NAMES.index(row_start_name)]
+                row_end_xy = keypoints[BONE_NAMES.index(row_end_name)]
+
+                col_start_xy = keypoints[BONE_NAMES.index(col_start_name)]
+                col_end_xy = keypoints[BONE_NAMES.index(col_end_name)]
+
+                # Intersect line row_start_xy -> row_end_xy with line
+                # col_start_xy -> col_end_xy
+                x1, y1 = row_start_xy  # horizontal line start
+                x2, y2 = row_end_xy    # horizontal line end
+                x3, y3 = col_start_xy  # vertical line start
+                x4, y4 = col_end_xy    # vertical line end
+
+                # Solve for the intersection with Cramer's rule
+                denominator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
+
+                # x coordinate of the intersection
+                x = ((x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)) / denominator
+                # y coordinate of the intersection
+                y = ((x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)) / denominator
+
+                cells_xywh[i, j] = [int(x), int(y), 0, 0]
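+                # Derivation sketch: these are the standard determinant-form
+                # formulas for the intersection of two lines given by two points
+                # each; denominator == 0 would mean the row and column lines are
+                # parallel, which a valid board never produces.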
+
+    # Compute the w/h of every grid point
+    for i in range(10):
+        for j in range(9):
+            cur_xy = cells_xywh[i, j]
+            # Take the 4 neighbours (up/down/left/right); the cell size is half
+            # the x1y1x2y2 extent spanned by those 4 points
+            if i == 0:
+                # mirror of [i+1, j]
+                up_xy = 2 * cur_xy - cells_xywh[i+1, j]
+            else:
+                up_xy = cells_xywh[i - 1, j]
+
+            if i == 9:
+                # mirror of [i-1, j]
+                down_xy = 2 * cur_xy - cells_xywh[i-1, j]
+            else:
+                down_xy = cells_xywh[i+1, j]
+
+            if j == 0:
+                left_xy = 2 * cur_xy - cells_xywh[i, j+1]
+            else:
+                left_xy = cells_xywh[i, j-1]
+
+            if j == 8:
+                right_xy = 2 * cur_xy - cells_xywh[i, j-1]
+            else:
+                right_xy = cells_xywh[i, j+1]
+
+            min_x = min(up_xy[0].tolist(), down_xy[0].tolist(), left_xy[0].tolist(), right_xy[0].tolist())
+            min_y = min(up_xy[1].tolist(), down_xy[1].tolist(), left_xy[1].tolist(), right_xy[1].tolist())
+
+            min_x += padding
+            min_y += padding
+
+            # Keep min_x and min_y at least 1
+            min_x = max(min_x, 1)
+            min_y = max(min_y, 1)
+
+            max_x = max(up_xy[0].tolist(), down_xy[0].tolist(), left_xy[0].tolist(), right_xy[0].tolist())
+            max_y = max(up_xy[1].tolist(), down_xy[1].tolist(), left_xy[1].tolist(), right_xy[1].tolist())
+
+            max_x -= padding
+            max_y -= padding
+
+            # Clamp max_x and max_y to the board extent
+            max_x = min(max_x, width - 1)
+            max_y = min(max_y, height - 1)
+
+            w = (max_x - min_x) / 2
+            h = (max_y - min_y) / 2
+
+            cells_xywh[i, j] = [int(cur_xy[0]), int(cur_xy[1]), int(w), int(h)]
+
+    return cells_xywh
+
+
+def perspective_transform(
+        image: cv2.UMat,
+        src_points: np.ndarray,
+        keypoints: np.ndarray,
+        dst_size=(450, 500)) -> Tuple[cv2.UMat, np.ndarray, np.ndarray]:
+    """
+    Perspective transform
+    @param image: input image
+    @param src_points: source corner coordinates
+    @param keypoints: keypoint coordinates
+    @param dst_size: target size (width, height) for the 10-row, 9-column board
+
+    @return:
+        result: the warped image
+        transformed_keypoints: keypoints after the warp
+        corner_points: board corner coordinates, shape (4, 2), order A0, A8, J0, J8
+    """
+
+    check_keypoints(keypoints)
+
+    # Source and destination points
+    src = np.float32(src_points)
+    padding = 50
+    corner_points = np.float32([
+        # top left
+        [padding, padding],
+        # top right
+        [dst_size[0]-padding, padding],
+        # bottom left
+        [padding, dst_size[1]-padding],
+        # bottom right
+        [dst_size[0]-padding, dst_size[1]-padding]])
+
+    # Perspective transform matrix
+    matrix = cv2.getPerspectiveTransform(src, corner_points)
+
+    # Apply the warp
+    result = cv2.warpPerspective(image, matrix, dst_size)
+
+    # Reshape to the (N, 1, 2) layout cv2.perspectiveTransform expects
+    keypoints_reshaped = keypoints.reshape(-1, 1, 2).astype(np.float32)
+    transformed_keypoints = cv2.perspectiveTransform(keypoints_reshaped, matrix)
+    # Restore the original shape
+    transformed_keypoints = transformed_keypoints.reshape(-1, 2)
+
+    return result, transformed_keypoints, corner_points
+
+
+def get_board_corner_points(keypoints: np.ndarray) -> np.ndarray:
+    """
+    Compute the four corner points of the board
+    @param keypoints: keypoint coordinates, shape (34, 2)
+    @return: corner coordinates, shape (4, 2)
+    """
+    check_keypoints(keypoints)
+
+    # Look up the coordinates of A0, A8, J0 and J8
+    a0_index = BONE_NAMES.index("A0")
+    a8_index = BONE_NAMES.index("A8")
+    j0_index = BONE_NAMES.index("J0")
+    j8_index = BONE_NAMES.index("J8")
+
+    a0_xy = keypoints[a0_index]
+    a8_xy = keypoints[a8_index]
+    j0_xy = keypoints[j0_index]
+    j8_xy = keypoints[j8_index]
+
+    # Assemble the four corner coordinates
+    dst_points = np.array([
+        a0_xy,
+        a8_xy,
+        j0_xy,
+        j8_xy
+    ], dtype=np.float32)
+
+    return dst_points
+
+def extract_chessboard(img: cv2.UMat, keypoints: np.ndarray) -> Tuple[cv2.UMat, np.ndarray, np.ndarray]:
+    """
+    Extract the board region
+    @param img: input image
+    @param keypoints: keypoint coordinates, shape (34, 2)
+    @return:
+        transformed_image: the warped image
+        transformed_keypoints: keypoints after the warp
+        transformed_corner_points: board corner coordinates, shape (4, 2), order A0, A8, J0, J8
+    """
+
+    check_keypoints(keypoints)
+
+    source_corner_points = get_board_corner_points(keypoints)
+
+    transformed_image, transformed_keypoints, transformed_corner_points = perspective_transform(img, source_corner_points, keypoints)
+
+    return transformed_image, transformed_keypoints, transformed_corner_points
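+
+# Minimal usage sketch (assuming `img` is an RGB board photo and `kpts` the
+# (34, 2) array from the pose model):
+#   board, kpts_t, corners = extract_chessboard(img, kpts)
+#   cells = build_cells_xywh(kpts_t)          # (10, 9, 4) in the 450x500 warp
+#   vis = draw_cells_box(board.copy(), cells)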
+
+def collect_cells_images(image: cv2.UMat, cells_xywh: np.ndarray) -> List[List[np.ndarray]]:
+    """
+    Collect the image crop for every cell in cells_xywh
+    """
+    width = image.shape[1]
+    height = image.shape[0]
+    crop_cells: List[List[np.ndarray]] = []
+
+    for i in range(10):
+        row_cells = []
+        for j in range(9):
+            x, y, w, h = cells_xywh[i, j]
+
+            x_0 = max(int(x-w/2), 0)
+            y_0 = max(int(y-h/2), 0)
+            x_1 = min(int(x+w/2), width-1)
+            y_1 = min(int(y+h/2), height-1)
+
+            crop_img = image[y_0:y_1, x_0:x_1]
+            row_cells.append(crop_img)
+        crop_cells.append(row_cells)
+
+    return crop_cells
+
+def draw_cells_box(image: cv2.UMat, cells_xywh: np.ndarray) -> cv2.UMat:
+    """
+    Draw the rectangle of every cell in cells_xywh
+    """
+    width = image.shape[1]
+    height = image.shape[0]
+    for i in range(10):
+        for j in range(9):
+            x, y, w, h = cells_xywh[i, j]
+
+            x_0 = max(int(x-w/2), 0)
+            y_0 = max(int(y-h/2), 0)
+            x_1 = min(int(x+w/2), width-1)
+            y_1 = min(int(y+h/2), height-1)
+
+            cv2.rectangle(image, (x_0, y_0), (x_1, y_1), (0, 0, 255), 1)
+
+    return image
core/helper_cls.py ADDED
@@ -0,0 +1,23 @@
+
+
+dict_cate_names = {
+    'point': '.',
+    'other': 'x',
+    'red_king': 'K',
+    'red_advisor': 'A',
+    'red_bishop': 'B',
+    'red_knight': 'N',
+    'red_rook': 'R',
+    'red_cannon': 'C',
+    'red_pawn': 'P',
+    'black_king': 'k',
+    'black_advisor': 'a',
+    'black_bishop': 'b',
+    'black_knight': 'n',
+    'black_rook': 'r',
+    'black_cannon': 'c',
+    'black_pawn': 'p',
+}
+
+
+full_cate_names = list(dict_cate_names.keys())
core/kpt_34_with_xanything.py ADDED
@@ -0,0 +1,197 @@
+import cv2
+import os
+import json
+import numpy as np
+
+from .helper_34 import BONE_NAMES
+
+class Shape:
+
+    @staticmethod
+    def init_from_dict(data: dict):
+        shape_ins = Shape(data["label"], data["points"], data["group_id"], data["shape_type"])
+
+        return shape_ins
+
+    def __init__(self, label="", points=None, group_id=1, shape_type=""):
+        self.label = label
+        self.score = None
+        self.points = points
+        self.group_id = group_id
+        self.description = ""
+        self.difficult = False
+        self.shape_type = shape_type
+        self.flags = {}
+        self.attributes = {}
+
+    def to_dict(self):
+        return {
+            "label": self.label,
+            "score": self.score,
+            "points": self.points,
+            "group_id": self.group_id,
+            "description": self.description,
+            "difficult": self.difficult,
+            "shape_type": self.shape_type,
+            "flags": self.flags,
+            "attributes": self.attributes
+        }
+
+class KeyPoint(Shape):
+    def __init__(self, label="", point_xy=None, group_id=1):
+        # point_xy must be a 2-element list [x, y]
+        if point_xy is None or len(point_xy) != 2:
+            raise ValueError("point_xy must be a list of 2 elements")
+        super().__init__(label, [point_xy], group_id, "point")
+
+class Rectangle(Shape):
+    def __init__(self, label="A1", xyxy=None, group_id=1):
+
+        if xyxy is None or len(xyxy) != 4:
+            raise ValueError("xyxy must be a list of 4 elements")
+
+        """
+        bbox corner order: top left, top right, bottom right, bottom left,
+        i.e. [[x1,y1],[x2,y2],[x3,y3],[x4,y4]]
+        """
+        x1, y1, x2, y2 = xyxy
+
+        bbox = [
+            [x1, y1],
+            [x2, y1],
+            [x2, y2],
+            [x1, y2]
+        ]
+
+        super().__init__(label, bbox, group_id, "rectangle")
+
+class Annotation:
+    @staticmethod
+    def init_from_dict(data: dict):
+        """
+        Build an Annotation from a dict
+        """
+
+        image_height = data["imageHeight"]
+        image_width = data["imageWidth"]
+
+        ann = Annotation(image_path=data["imagePath"], image_width=image_width, image_height=image_height)
+
+        for shape in data["shapes"]:
+            if shape["shape_type"] == "rectangle":
+                ann.add_shape(Rectangle.init_from_dict(shape))
+            elif shape["shape_type"] == "point":
+                ann.add_shape(KeyPoint.init_from_dict(shape))
+
+        return ann
+
+    def __init__(self, image_path="", image_width=-1, image_height=-1):
+        self.version = "2.4.4"
+        self.flags = {}
+        self.shapes = []
+        self.image_data = None
+
+        self.image_path = image_path
+        self.image_height = image_height
+        self.image_width = image_width
+
+    def add_shape(self, shape: Rectangle | KeyPoint):
+        self.shapes.append(shape.to_dict())
+
+    def to_dict(self):
+        if self.image_path == "":
+            raise ValueError("image_path must not be empty")
+        if self.image_height == -1 or self.image_width == -1:
+            raise ValueError("image_height and image_width must not be -1")
+
+        return {
+            "version": self.version,
+            "flags": self.flags,
+            "shapes": self.shapes,
+            "imagePath": self.image_path,
+            "imageData": self.image_data,
+            "imageHeight": self.image_height,
+            "imageWidth": self.image_width
+        }
+
+
+def save_kpt_34_with_xanything(image_input: np.ndarray, image_ann_path, bbox: list[float], kpt_34: list[tuple[str, float, float]], save_dir: str):
+    """
+    Save the 34 keypoints plus one bbox to an xanything json file
+    """
+    x1, y1, x2, y2 = bbox
+    x1, y1, x2, y2 = float(x1), float(y1), float(x2), float(y2)
+
+    if image_input is None:
+        raise ValueError("image_input must not be None")
+
+    image_height, image_width = image_input.shape[:2]
+
+    # Default image_ann_path to a .json extension
+    if not image_ann_path.endswith(".json"):
+        image_ann_path = image_ann_path + ".json"
+
+    # File name portion of image_ann_path
+    file_name = os.path.basename(image_ann_path)
+
+    annotation = Annotation(file_name, image_width, image_height)
+
+    kpt_34_dict = {}
+    for bone_name, x, y in kpt_34:
+        kpt_34_dict[bone_name] = [float(x), float(y)]
+
+    for bone_name in BONE_NAMES:
+        x, y = kpt_34_dict[bone_name]
+        annotation.add_shape(KeyPoint(bone_name, [x, y]))
+
+    # Add the bbox
+    annotation.add_shape(Rectangle("bbox", [x1, y1, x2, y2]))
+
+    ann_file_path = os.path.join(save_dir, file_name)
+    # Write the annotation
+    with open(ann_file_path, "w") as f:
+        json.dump(annotation.to_dict(), f)
+
+    # Write the image alongside it, converting RGB to BGR for cv2.imwrite
+    image_input_bgr = image_input.copy()[:, :, ::-1]
+
+    # print('ann_file_path:', ann_file_path.replace(".json", ".jpg"))
+
+    cv2.imwrite(ann_file_path.replace(".json", ".jpg"), image_input_bgr)
+
+def read_xanything_to_json(json_path) -> tuple[list[tuple[str, float, float]], list[float]]:
+    """
+    Read an xanything json file
+    """
+    data = {}
+    with open(json_path, "r") as f:
+        data = json.load(f)
+
+    annotation = Annotation.init_from_dict(data)
+
+    keypoints_34_dict: dict[str, list[float]] = {}
+    # x1, y1, x2, y2
+    bbox: list[float] = []
+
+    for shape in annotation.shapes:
+        if shape["shape_type"] == "point":
+            keypoints_34_dict[shape["label"]] = [shape["points"][0][0], shape["points"][0][1]]
+        elif shape["shape_type"] == "rectangle":
+            bbox = [shape["points"][0][0], shape["points"][0][1], shape["points"][2][0], shape["points"][2][1]]
+
+    keypoints_34: list[tuple[str, float, float]] = []
+
+    for item in BONE_NAMES:
+        keypoints_34.append((item, keypoints_34_dict[item][0], keypoints_34_dict[item][1]))
+
+    return keypoints_34, bbox
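+
+# Round-trip sketch (paths are hypothetical):
+#   save_kpt_34_with_xanything(img, "0001.json", bbox, kpt_34, "ann_dir")
+#   kpt_34, bbox = read_xanything_to_json("ann_dir/0001.json")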
core/runonnx/base_onnx.py ADDED
@@ -0,0 +1,88 @@
+import onnxruntime
+import numpy as np
+import cv2
+from abc import ABC, abstractmethod
+from typing import Any, Tuple, Union, List
+
+class BaseONNX(ABC):
+    def __init__(self, model_path: str, input_size: Tuple[int, int]):
+        """Base class for the ONNX models
+
+        Args:
+            model_path (str): path to the ONNX model
+            input_size (tuple): model input size (width, height)
+        """
+        self.session = onnxruntime.InferenceSession(model_path)
+        self.input_name = self.session.get_inputs()[0].name
+        self.input_size = input_size
+
+    def load_image(self, image: Union[cv2.UMat, str]) -> cv2.UMat:
+        """Load an image
+
+        Args:
+            image (Union[cv2.UMat, str]): image path or cv2 image object
+
+        Returns:
+            cv2.UMat: the loaded image
+        """
+        if isinstance(image, str):
+            return cv2.imread(image)
+        return image.copy()
+
+    @abstractmethod
+    def preprocess_image(self, img_bgr: cv2.UMat, *args, **kwargs) -> np.ndarray:
+        """Abstract image preprocessing hook
+
+        Args:
+            img_bgr (cv2.UMat): input image in BGR order
+
+        Returns:
+            np.ndarray: the preprocessed image
+        """
+        pass
+
+    @abstractmethod
+    def run_inference(self, image: np.ndarray) -> Any:
+        """Abstract inference hook
+
+        Args:
+            image (np.ndarray): preprocessed input image
+
+        Returns:
+            Any: raw model outputs
+        """
+        pass
+
+    @abstractmethod
+    def pred(self, image: Union[cv2.UMat, str], *args, **kwargs) -> Any:
+        """Abstract prediction hook
+
+        Args:
+            image (Union[cv2.UMat, str]): input image or image path
+
+        Returns:
+            Any: prediction results
+        """
+        pass
+
+    @abstractmethod
+    def draw_pred(self, img: cv2.UMat, *args, **kwargs) -> cv2.UMat:
+        """Abstract visualisation hook
+
+        Args:
+            img (cv2.UMat): image to draw on
+
+        Returns:
+            cv2.UMat: image with the predictions drawn
+        """
+        pass
+
+
+    def check_images_list(self, images: List[Union[cv2.UMat, str, np.ndarray]]):
+        """
+        Validate a list of images
+        """
+        for image in images:
+            if not isinstance(image, cv2.UMat) and not isinstance(image, str) and not isinstance(image, np.ndarray):
+                raise ValueError("The images must be a list of cv2.UMat or str or np.ndarray.")
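+
+# Subclassing sketch (hypothetical; a concrete model must fill in the four
+# abstract hooks above):
+#   class MyModelONNX(BaseONNX):
+#       def preprocess_image(self, img_bgr): ...
+#       def run_inference(self, image): ...
+#       def pred(self, image): ...
+#       def draw_pred(self, img): ...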
core/runonnx/full_classifier.py ADDED
@@ -0,0 +1,203 @@
+import onnxruntime
+import numpy as np
+import cv2
+from typing import Tuple, List, Union
+from .base_onnx import BaseONNX
+
+
+def center_crop(image: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
+    """
+    Center crop the image to the target size.
+
+    Args:
+        image (np.ndarray): The input image.
+        target_size (Tuple[int, int]): The desired output size (width, height).
+
+    Returns:
+        np.ndarray: The cropped image.
+    """
+    h, w, _ = image.shape
+    target_w, target_h = target_size
+
+    center_x = w // 2
+    center_y = h // 2
+
+    start_x = int(center_x - target_w // 2)
+    start_y = int(center_y - target_h // 2)
+
+    cropped_image = image[start_y:start_y + target_h, start_x:start_x + target_w]
+
+    return cropped_image
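+
+# Worked example: the rectified board from helper_34 is 450 wide x 500 high,
+# so with the default crop_size (400, 450) the crop starts at x = 225 - 200 =
+# 25 and y = 250 - 225 = 25, trimming 25 px from every side of the warp padding.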
+
+
+dict_cate_names = {
+    'point': '.',
+    'other': 'x',
+    'red_king': 'K',
+    'red_advisor': 'A',
+    'red_bishop': 'B',
+    'red_knight': 'N',
+    'red_rook': 'R',
+    'red_cannon': 'C',
+    'red_pawn': 'P',
+    'black_king': 'k',
+    'black_advisor': 'a',
+    'black_bishop': 'b',
+    'black_knight': 'n',
+    'black_rook': 'r',
+    'black_cannon': 'c',
+    'black_pawn': 'p',
+}
+
+class FULL_CLASSIFIER_ONNX(BaseONNX):
+
+    label_2_short = dict_cate_names
+
+    classes_labels = list(dict_cate_names.keys())
+
+    def __init__(self,
+                 model_path,
+                 # model input size
+                 input_size=(280, 315),  # (w, h)
+                 # crop size applied before resizing
+                 crop_size=(400, 450),  # (w, h)
+                 ):
+        super().__init__(model_path, input_size)
+
+        self.crop_size = crop_size
+
+    def preprocess_image(self, img_bgr: cv2.UMat, is_rgb: bool = True):
+
+        if not is_rgb:
+            img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+        else:
+            img_rgb = img_bgr
+
+        # Compare (w, h) against crop_size, which is also (w, h)
+        if (img_rgb.shape[1], img_rgb.shape[0]) != self.crop_size:
+            # Center crop to crop_size
+            img_rgb = center_crop(img_rgb, self.crop_size)
+
+        # Resize to input_size
+        img_rgb = cv2.resize(img_rgb, self.input_size)
+
+        # Normalize with the ImageNet RGB mean and std
+        img = (img_rgb - np.array([123.675, 116.28, 103.53])) / np.array([58.395, 57.12, 57.375])
+
+        img = img.astype(np.float32)
+
+        # Reorder dimensions (H,W,C) -> (C,H,W)
+        img = np.transpose(img, (2, 0, 1))
+
+        # Add the batch dimension
+        img = np.expand_dims(img, axis=0)
+
+        return img
+
+    def run_inference(self, image: np.ndarray) -> np.ndarray:
+        """
+        Run inference on the image.
+
+        Args:
+            image (np.ndarray): The image to run inference on.
+
+        Returns:
+            np.ndarray: the per-position class scores.
+        """
+        outputs, = self.session.run(None, {self.input_name: image})
+
+        return outputs
+
+    def pred(self, image: Union[cv2.UMat, str], is_rgb: bool = True) -> Tuple[List[List[str]], List[List[str]], List[List[float]], str]:
+        """
+        Predict the piece layout of a rectified board image.
+
+        Args:
+            image (cv2.UMat, str): The image to predict.
+
+        Returns:
+            label_names_10x9: full class names, 10 rows x 9 columns
+            label_short_10x9: short labels, 10 rows x 9 columns
+            confidence_10x9: confidences, 10 rows x 9 columns
+            layout_str: the short labels joined into one rank per line
+        """
+        if isinstance(image, str):
+            img_bgr = cv2.imread(image)
+            is_rgb = False
+        else:
+            img_bgr = image.copy()
+
+        image = self.preprocess_image(img_bgr, is_rgb)
+
+        labels = self.run_inference(image)
+
+        # Validate the output shape
+        assert labels.shape[1:] == (90, 16)
+
+        # shape (90, 16)
+        first_batch_labels = labels[0]
+
+        # Index of the highest-scoring class per position
+        # list[int]
+        label_indexes = np.argmax(first_batch_labels, axis=-1).tolist()
+
+        # Map class indexes to class names
+        # list[str]
+        label_names = [self.classes_labels[index] for index in label_indexes]
+
+        # list[str]
+        label_short = [self.label_2_short[name] for name in label_names]
+
+        # Confidence of the chosen class per position
+        confidence = first_batch_labels[np.arange(first_batch_labels.shape[0]), label_indexes]
+
+        label_names_10x9 = [label_names[i*9:(i+1)*9] for i in range(10)]
+        label_short_10x9 = [label_short[i*9:(i+1)*9] for i in range(10)]
+        confidence_10x9 = [confidence[i*9:(i+1)*9] for i in range(10)]
+
+        layout_str = "\n".join(["".join(row) for row in label_short_10x9])
+
+        return label_names_10x9, label_short_10x9, confidence_10x9, layout_str
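+
+    # Note: the confidences above assume the exported graph already ends in a
+    # softmax, so each row of the (90, 16) output sums to 1; if the model emits
+    # raw logits instead, a softmax would be needed before comparing against
+    # the 0.5/0.9 thresholds used in draw_pred_with_result below.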
+
+    def draw_pred(self, image: cv2.UMat, label_index: int, label_name: str, label_short: str, confidence: float) -> cv2.UMat:
+
+        # Draw the prediction onto the image
+        cv2.putText(image, f"{label_short} {confidence:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
+
+        return image
+
+
+    def draw_pred_with_result(self, image: cv2.UMat, results: List[Tuple[int, str, str, float]], cells_xyxy: np.ndarray, is_rgb: bool = True) -> cv2.UMat:
+
+        assert len(results) == cells_xyxy.shape[0]
+
+        if not is_rgb:
+            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+
+        for i, (label_index, label_name, label_short, confidence) in enumerate(results):
+            # Make sure the coordinates are integers
+            x1, y1, x2, y2 = map(int, cells_xyxy[i])
+
+            if label_name.startswith('red'):
+                color = (180, 105, 255)  # pink
+            elif label_name.startswith('black'):
+                color = (0, 100, 50)  # dark green
+            else:
+                color = (0, 0, 255)  # blue
+
+            if confidence < 0.5:
+                # yellow for low-confidence cells
+                color = (255, 255, 0)
+
+            # Only show the confidence value when it is below 0.9
+            label_str = f"{label_short} {confidence:.2f}" if confidence < 0.9 else f"{label_short}"
+
+            cv2.putText(image, label_str, (x1 + 8, y2 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)
+
+        return image
core/runonnx/rtmdet.py ADDED
@@ -0,0 +1,117 @@
+
+
+import numpy as np
+import cv2
+from typing import Tuple, List, Union
+from .base_onnx import BaseONNX
+
+
+class RTMDET_ONNX(BaseONNX):
+
+    def __init__(self, model_path, input_size=(640, 640)):
+        super().__init__(model_path, input_size)
+
+    def preprocess_image(self, img_bgr: cv2.UMat):
+        # Resize to the model input size
+        img_bgr = cv2.resize(img_bgr, self.input_size)
+
+        # Normalize with the ImageNet mean and std in BGR order
+        img = (img_bgr - np.array([103.53, 116.28, 123.675])) / np.array([57.375, 57.12, 58.395])
+
+        img = img.astype(np.float32)
+
+        # Reorder dimensions (H,W,C) -> (C,H,W)
+        img = np.transpose(img, (2, 0, 1))
+
+        # Add the batch dimension
+        img = np.expand_dims(img, axis=0)
+
+        return img
+
+    def run_inference(self, image: np.ndarray):
+        """
+        Run inference on the image.
+
+        Args:
+            image (np.ndarray): The image to run inference on.
+
+        Returns:
+            tuple: A tuple containing the detection results and labels.
+        """
+        outputs = self.session.run(None, {self.input_name: image})
+
+        """
+        dets: detection boxes [batch, num_dets, [x1, y1, x2, y2, conf]]
+        labels: class labels [batch, num_dets]
+        """
+        dets, labels = outputs
+
+        return dets, labels
+
+    def pred(self, image: Union[cv2.UMat, str]) -> Tuple[List[int], float]:
+        """
+        Predict the detection results of the image.
+
+        Args:
+            image (cv2.UMat, str): The image to predict.
+
+        Returns:
+            xyxy (list[int]): The detection box.
+            conf (float): The confidence of the detection box.
+        """
+        if isinstance(image, str):
+            img_bgr = cv2.imread(image)
+        else:
+            img_bgr = image.copy()
+
+        original_w, original_h = img_bgr.shape[1], img_bgr.shape[0]
+
+        image = self.preprocess_image(img_bgr)
+        dets, labels = self.run_inference(image)
+
+        # Take the highest-confidence detection (index 0)
+        x1, y1, x2, y2, conf = dets[0][0]
+
+        xyxy = [x1, y1, x2, y2]
+
+        xyxy = self.transform_xyxy_to_original(xyxy, original_w, original_h)
+
+        return xyxy, conf
+
+    def transform_xyxy_to_original(self, xyxy, original_w, original_h) -> List[int]:
+        """
+        Map a detection box from the model input size back to the original image size
+        """
+        x1, y1, x2, y2 = xyxy
+
+        input_w, input_h = self.input_size
+        ratio_w, ratio_h = original_w / input_w, original_h / input_h
+
+        x1, y1, x2, y2 = x1 * ratio_w, y1 * ratio_h, x2 * ratio_w, y2 * ratio_h
+        # Truncate to integers
+        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+
+        return [x1, y1, x2, y2]
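+
+    # Example: with input_size (640, 640) and a 1280x960 photo, ratio_w = 2.0
+    # and ratio_h = 1.5, so a box predicted at (100, 100, 400, 400) maps back
+    # to (200, 150, 800, 600) on the original image.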
+
+    def draw_pred(self, img: cv2.UMat, xyxy: List[int], conf: float, is_rgb: bool = True) -> cv2.UMat:
+        """
+        Draw the detection results on the image.
+        """
+
+        if not is_rgb:
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+        x1, y1, x2, y2 = xyxy
+
+        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
+        cv2.putText(img, f"{conf:.2f}", (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
+
+        return img
core/runonnx/rtmpose.py ADDED
@@ -0,0 +1,356 @@
+import numpy as np
+import cv2
+from typing import Tuple, List, Union
+from .base_onnx import BaseONNX
+
+class RTMPOSE_ONNX(BaseONNX):
+
+    bone_names = [
+        "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8",
+        "J0", "J1", "J2", "J3", "J4", "J5", "J6", "J7", "J8",
+        "B0", "C0", "D0", "E0", "F0", "G0", "H0", "I0",
+        "B8", "C8", "D8", "E8", "F8", "G8", "H8", "I8",
+    ]
+
+    def __init__(self, model_path, input_size=(256, 256), padding=1.25):
+        super().__init__(model_path, input_size)
+        self.padding = padding
+
+    def get_bbox_center_scale(self, bbox: List[int]):
+        """Convert bounding box to center and scale.
+
+        The center is the coordinates of the bbox center, and the scale is the
+        bbox width and height scaled up by the padding factor.
+
+        Args:
+            bbox: Bounding box in format [x1, y1, x2, y2]
+
+        Returns:
+            tuple: A tuple containing:
+                - center (numpy.ndarray): Center coordinates [x, y]
+                - scale (numpy.ndarray): Scale [width, height]
+        """
+
+        # Get bbox center
+        x1, y1, x2, y2 = bbox
+        center = np.array([(x1 + x2) / 2.0, (y1 + y2) / 2.0])
+
+        # Get bbox scale (width and height)
+        w = x2 - x1
+        h = y2 - y1
+
+        # Convert to scaled width/height
+        scale = np.array([w, h]) * self.padding
+
+        return center, scale
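+
+    # Example: bbox [0, 0, 200, 100] gives center (100, 50) and, with the
+    # default padding factor of 1.25, scale (250, 125).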
+
+    @staticmethod
+    def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
+        """Rotate a point by an angle.
+
+        Args:
+            pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
+            angle_rad (float): rotation angle in radian
+
+        Returns:
+            np.ndarray: Rotated point in shape (2, )
+        """
+
+        sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+        rot_mat = np.array([[cs, -sn], [sn, cs]])
+        return rot_mat @ pt
+
+    @staticmethod
+    def _get_3rd_point(a: np.ndarray, b: np.ndarray):
+        """To calculate the affine matrix, three pairs of points are required. This
+        function is used to get the 3rd point, given 2D points a & b.
+
+        The 3rd point is defined by rotating vector `a - b` by 90 degrees
+        anticlockwise, using b as the rotation center.
+
+        Args:
+            a (np.ndarray): The 1st point (x,y) in shape (2, )
+            b (np.ndarray): The 2nd point (x,y) in shape (2, )
+
+        Returns:
+            np.ndarray: The 3rd point.
+        """
+        direction = a - b
+        c = b + np.r_[-direction[1], direction[0]]
+        return c
+
+    @staticmethod
+    def get_warp_matrix(
+        center: np.ndarray,
+        scale: np.ndarray,
+        rot: float,
+        output_size: Tuple[int, int],
+        shift: Tuple[float, float] = (0., 0.),
+        inv: bool = False,
+        fix_aspect_ratio: bool = True,
+    ) -> np.ndarray:
+        """Calculate the affine transformation matrix that can warp the bbox area
+        in the input image to the output size.
+
+        Args:
+            center (np.ndarray[2, ]): Center of the bounding box (x, y).
+            scale (np.ndarray[2, ]): Scale of the bounding box
+                wrt [width, height].
+            rot (float): Rotation angle (degree).
+            output_size (np.ndarray[2, ] | list(2,)): Size of the
+                destination heatmaps.
+            shift (0-100%): Shift translation ratio wrt the width/height.
+                Default (0., 0.).
+            inv (bool): Option to inverse the affine transform direction.
+                (inv=False: src->dst or inv=True: dst->src)
+            fix_aspect_ratio (bool): Whether to fix aspect ratio during transform.
+                Defaults to True.
+
+        Returns:
+            np.ndarray: A 2x3 transformation matrix
+        """
+        assert len(center) == 2
+        assert len(scale) == 2
+        assert len(output_size) == 2
+        assert len(shift) == 2
+
+        shift = np.array(shift)
+        src_w, src_h = scale[:2]
+        dst_w, dst_h = output_size[:2]
+
+        rot_rad = np.deg2rad(rot)
+        src_dir = RTMPOSE_ONNX._rotate_point(np.array([src_w * -0.5, 0.]), rot_rad)
+        dst_dir = np.array([dst_w * -0.5, 0.])
+
+        src = np.zeros((3, 2), dtype=np.float32)
+        src[0, :] = center + scale * shift
+        src[1, :] = center + src_dir + scale * shift
+
+        dst = np.zeros((3, 2), dtype=np.float32)
+        dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+        dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+
+        if fix_aspect_ratio:
+            src[2, :] = RTMPOSE_ONNX._get_3rd_point(src[0, :], src[1, :])
+            dst[2, :] = RTMPOSE_ONNX._get_3rd_point(dst[0, :], dst[1, :])
+        else:
+            src_dir_2 = RTMPOSE_ONNX._rotate_point(np.array([0., src_h * -0.5]), rot_rad)
+            dst_dir_2 = np.array([0., dst_h * -0.5])
+            src[2, :] = center + src_dir_2 + scale * shift
+            dst[2, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir_2
+
+        if inv:
+            warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+        else:
+            warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+        return warp_mat
+
+    def get_warp_size_with_input_size(self,
+                                      bbox_center: List[int],
+                                      bbox_scale: List[int],
+                                      inv: bool = False,
+                                      ):
+        """
+        Build the affine warp matrix for the model input size
+        """
+
+        w, h = self.input_size
+        warp_size = self.input_size
+
+        # Fix the aspect ratio
+        scale_w, scale_h = bbox_scale
+        aspect_ratio = w / h
+        if scale_w > scale_h * aspect_ratio:
+            bbox_scale = [scale_w, scale_w / aspect_ratio]
+        else:
+            bbox_scale = [scale_h * aspect_ratio, scale_h]
+
+        # Build the affine matrix, making sure the dtypes are correct
+        center = np.array(bbox_center, dtype=np.float32)
+        scale = np.array(bbox_scale, dtype=np.float32)
+
+        rot = 0.0  # rotation is not considered
+
+        warp_mat = self.get_warp_matrix(center, scale, rot, output_size=warp_size, inv=inv)
+
+        return warp_mat
+
+    def topdown_affine(self, img: cv2.UMat, bbox_center: List[int], bbox_scale: List[int]):
+        """Simplified top-down affine transform
+
+        Args:
+            img: input image
+
+        Returns:
+            the warped image
+        """
+
+        warp_mat = self.get_warp_size_with_input_size(bbox_center, bbox_scale)
+
+        # Apply the affine warp
+        dst_img = cv2.warpAffine(img, warp_mat, self.input_size, flags=cv2.INTER_LINEAR)
+
+        return dst_img
+
+    # Get the best predicted position of each keypoint
+    def get_simcc_maximum(self, simcc_x, simcc_y):
+
+        # Index of the maximum along the last axis
+        x_indices = np.argmax(simcc_x[0], axis=1)  # (34,)
+        y_indices = np.argmax(simcc_y[0], axis=1)  # (34,)
+
+        input_w, input_h = self.input_size
+
+        # Convert bin indices to coordinates normalized to 0-1
+        x_coords = x_indices / (input_w * 2)
+        y_coords = y_indices / (input_h * 2)
+
+        # Stack into coordinate pairs
+        keypoints = np.stack([x_coords, y_coords], axis=1)  # (34, 2)
+
+        # Per-keypoint confidence score
+        scores = np.max(simcc_x[0], axis=1) * np.max(simcc_y[0], axis=1)
+
+        return keypoints, scores
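+
+    # SimCC background: instead of a 2-D heatmap, the head predicts two 1-D
+    # distributions per keypoint (one over x bins, one over y bins); with a
+    # 256x256 input and the usual 2x bin split there are 512 bins per axis,
+    # which is why the argmax index is divided by input_size * 2.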
+
+    def preprocess_image(self, img_bgr: cv2.UMat, bbox_center: List[int], bbox_scale: List[int]):
+        """
+        Preprocess the image
+
+        Args:
+            img_bgr (cv2.UMat): input image
+            bbox_center (list[int]): bbox center [x, y]
+            bbox_scale (list[int]): bbox scale [w, h]
+        """
+
+        affine_img_bgr = self.topdown_affine(img_bgr, bbox_center, bbox_scale)
+
+        # Convert to RGB and normalize
+        affine_img_rgb = cv2.cvtColor(affine_img_bgr, cv2.COLOR_BGR2RGB)
+        # Normalize with the ImageNet mean and std
+        affine_img_rgb_norm = (affine_img_rgb - np.array([123.675, 116.28, 103.53])) / np.array([58.395, 57.12, 57.375])
+        # Convert to float32
+        img = affine_img_rgb_norm.astype(np.float32)
+        # Reorder dimensions (H,W,C) -> (C,H,W)
+        img = np.transpose(img, (2, 0, 1))
+        # Add the batch dimension
+        img = np.expand_dims(img, axis=0)
+
+        return img
+
+    def run_inference(self, image: np.ndarray):
+        """
+        Run inference on the image.
+
+        Args:
+            image (np.ndarray): The image to run inference on.
+
+        Returns:
+            tuple: A tuple containing the SimCC x and y distributions.
+        """
+        outputs = self.session.run(None, {self.input_name: image})
+        """
+        simcc_x: float32[batch, 34, 512]
+        simcc_y: float32[batch, 34, 512]
+        """
+        simcc_x, simcc_y = outputs
+
+        return simcc_x, simcc_y
+
+    def pred(self, image: Union[cv2.UMat, str], bbox: List[int]) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Predict the keypoints results of the image.
+
+        Args:
+            image (str | cv2.UMat): The image to predict.
+            bbox (list[int]): The bounding box to predict within.
+
+        Returns:
+            keypoints (np.ndarray): The predicted keypoints.
+            scores (np.ndarray): The predicted scores.
+        """
+        if isinstance(image, str):
+            img_bgr = cv2.imread(image)
+        else:
+            img_bgr = image.copy()
+
+        bbox_center, bbox_scale = self.get_bbox_center_scale(bbox)
+
+        image = self.preprocess_image(img_bgr, bbox_center, bbox_scale)
+        simcc_x, simcc_y = self.run_inference(image)
+
+        # Take the SimCC argmax: keypoint coordinates plus confidence scores,
+        # normalized with respect to the input_size (256, 256)
+        keypoints, scores = self.get_simcc_maximum(simcc_x, simcc_y)
+
+        # Map the keypoints from the model input size back to the original image
+        keypoints = self.transform_keypoints_to_original(keypoints, bbox_center, bbox_scale, self.input_size)
+
+        return keypoints, scores
+
+    def transform_keypoints_to_original(self, keypoints, center, scale, output_size):
+        """
+        Map predicted keypoints from the model input size back to the original image
+
+        Args:
+            keypoints: predicted keypoints [N, 2]
+            center: bbox center [x, y]
+            scale: bbox scale [w, h]
+            output_size: model input size (w, h)
+
+        Returns:
+            np.ndarray: transformed keypoints [N, 2]
+        """
+        target_coords = keypoints.copy()
+
+        # Scale the 0-1 coordinates up to pixel coordinates (256x256)
+        target_coords[:, 0] = target_coords[:, 0] * output_size[0]
+        target_coords[:, 1] = target_coords[:, 1] * output_size[1]
+
+        # Build the inverse affine matrix
+        warp_mat = self.get_warp_size_with_input_size(center, scale, inv=True)
+
+        # Convert to homogeneous coordinates
+        ones = np.ones((len(target_coords), 1))
+        target_coords_homogeneous = np.hstack([target_coords, ones])
+
+        # Apply the inverse transform
+        original_keypoints = target_coords_homogeneous @ warp_mat.T
+
+        return original_keypoints
+
+    def draw_pred(self, img: cv2.UMat, keypoints: np.ndarray, scores: np.ndarray, is_rgb: bool = True) -> cv2.UMat:
+        """
+        Draw the keypoints results on the image.
+        """
+
+        if not is_rgb:
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+        # 34 random colors, one per keypoint
+        colors = np.random.randint(0, 256, (34, 3))
+
+        for i, (point, score) in enumerate(zip(keypoints, scores)):
+            if score > 0.3:  # confidence threshold
+                x, y = map(int, point)
+                # Use a distinct color per keypoint
+                color = colors[i]
+
+                cv2.circle(img, (x, y), 5, (int(color[0]), int(color[1]), int(color[2])), -1)
+                # Label the keypoint with its bone name
+                cv2.putText(img, self.bone_names[i], (x+5, y+5),
+                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, (int(color[0]), int(color[1]), int(color[2])), 1)
+        return img
requirements.txt ADDED
@@ -0,0 +1,5 @@
+gradio
+numpy
+onnxruntime
+opencv-python
+pandas