yolo12138 commited on
Commit
085b115
·
1 Parent(s): 9316eb4

feat: kpt 4

Browse files
HISTORY.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ### 2024-12-28
2
+
3
+ 1. 使用 4 个关键点检测
4
+
app.py CHANGED
@@ -5,7 +5,7 @@ from core.chessboard_detector import ChessboardDetector
5
 
6
  detector = ChessboardDetector(
7
  det_model_path="onnx/det/v1.onnx",
8
- pose_model_path="onnx/pose/v1.onnx",
9
  full_classifier_model_path="onnx/layout_recognition/v1.onnx"
10
  )
11
 
 
5
 
6
  detector = ChessboardDetector(
7
  det_model_path="onnx/det/v1.onnx",
8
+ pose_model_path="onnx/pose/4_v2.onnx",
9
  full_classifier_model_path="onnx/layout_recognition/v1.onnx"
10
  )
11
 
core/chessboard_detector.py CHANGED
@@ -8,7 +8,7 @@ from .runonnx.rtmdet import RTMDET_ONNX
8
  from .runonnx.rtmpose import RTMPOSE_ONNX
9
  from .runonnx.full_classifier import FULL_CLASSIFIER_ONNX
10
 
11
- from core.helper_34 import extract_chessboard
12
 
13
  class ChessboardDetector:
14
  def __init__(self,
 
8
  from .runonnx.rtmpose import RTMPOSE_ONNX
9
  from .runonnx.full_classifier import FULL_CLASSIFIER_ONNX
10
 
11
+ from core.helper_4_kpt import extract_chessboard
12
 
13
  class ChessboardDetector:
14
  def __init__(self,
core/helper_34.py DELETED
@@ -1,316 +0,0 @@
1
- import numpy as np
2
- import cv2
3
- from typing import Tuple, List
4
-
5
-
6
- BONE_NAMES = [
7
- "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8",
8
- "J0", "J1", "J2", "J3", "J4", "J5", "J6", "J7", "J8",
9
- "B0", "C0", "D0", "E0", "F0", "G0", "H0", "I0",
10
- "B8", "C8", "D8", "E8", "F8", "G8", "H8", "I8",
11
- ]
12
-
13
- def check_keypoints(keypoints: np.ndarray):
14
- """
15
- 检查关键点坐标是否正确
16
- @param keypoints: 关键点坐标, shape 为 (34, 2)
17
- """
18
- if keypoints.shape != (34, 2):
19
- raise Exception(f"keypoints shape error: {keypoints.shape}")
20
-
21
-
22
- def build_cells_xywh_by_cronners(corner_points: np.ndarray, padding: int = 3) -> np.ndarray:
23
- """
24
- 根据 棋盘的 corner 点坐标 计算 每个位置的 xywh
25
- @param corner_points: 棋盘的 corner 点坐标, shape 为 (4, 2)
26
- @param padding: 棋盘边框 padding
27
-
28
- @return: 棋盘的 xywh, shape 为 (10, 9, 4), 4 为 center_x, center_y, w, h
29
- """
30
-
31
- if corner_points.shape != (4, 2):
32
- raise Exception(f"corner_points shape error: {corner_points.shape}")
33
-
34
- top_left_xy = corner_points[0]
35
- top_right_xy = corner_points[1]
36
- bottom_left_xy = corner_points[2]
37
- bottom_right_xy = corner_points[3]
38
-
39
- # 计算 每个框的 w 和 h
40
- item_w = (top_right_xy[0] - top_left_xy[0]) / (9 - 1)
41
- item_h = (bottom_left_xy[1] - top_left_xy[1]) / (10 - 1)
42
-
43
- item_w = item_w
44
- item_h = item_h
45
-
46
- item_w_with_padding = item_w - padding * 2
47
- item_h_with_padding = item_h - padding * 2
48
-
49
- # 计算 每个框的 center 坐标
50
- cells_xywh = np.zeros((10, 9, 4))
51
-
52
- for i in range(10):
53
- for j in range(9):
54
- center_x = top_left_xy[0] + item_w * j
55
- center_y = top_left_xy[1] + item_h * i
56
-
57
- cells_xywh[i, j] = [center_x, center_y, item_w_with_padding, item_h_with_padding]
58
-
59
- return cells_xywh
60
-
61
-
62
-
63
- # todo: 需要优化
64
- def build_cells_xywh(keypoints: np.ndarray, width: int = 450, height: int = 500, padding: int = 3) -> np.ndarray:
65
- """
66
- @param keypoints: 关键点坐标, shape 为 (34, 2)
67
- @param width: 棋盘宽度
68
- @param height: 棋盘高度
69
- @param padding: 棋盘边框 padding
70
- @return: 棋盘的 xywh, shape 为 (10, 9, 4), 4 为 center_x, center_y, w, h
71
- """
72
- check_keypoints(keypoints)
73
-
74
-
75
- # 生成 A0 到 J8 的坐标, 如 B1 坐标 为 A1-J1 与 B0-B8 的交集点
76
- cells_xywh = np.zeros((10, 9, 4), dtype=np.int16)
77
-
78
- # 遍历 full_points 的每个点,计算其坐标
79
- for i in range(10):
80
- for j in range(9):
81
- # 计算 第 i 行 第 j 列 的坐标
82
- row_name = chr(ord('A') + i)
83
- col_name = str(j)
84
- flag_name = f"{row_name}{col_name}"
85
- if flag_name in BONE_NAMES:
86
- # 计算 第 i 行 第 j 列 的坐标
87
- cur_xy = keypoints[BONE_NAMES.index(flag_name)]
88
- cells_xywh[i, j] = [cur_xy[0], cur_xy[1], 0, 0]
89
- else:
90
- # 计算 第 i 行 第 j 列 的坐标
91
- row_start_name = f"{row_name}0"
92
- row_end_name = f"{row_name}8"
93
-
94
- col_start_name = f"A{col_name}"
95
- col_end_name = f"J{col_name}"
96
-
97
- row_start_xy = keypoints[BONE_NAMES.index(row_start_name)]
98
- row_end_xy = keypoints[BONE_NAMES.index(row_end_name)]
99
-
100
- col_start_xy = keypoints[BONE_NAMES.index(col_start_name)]
101
- col_end_xy = keypoints[BONE_NAMES.index(col_end_name)]
102
-
103
- # 计算 row_start_xy 到 row_end_xy 的直线 与 col_start_xy 到 col_end_xy 的直线 的交点
104
- # 使用参数方程法计算交点
105
- x1, y1 = row_start_xy # 横向直线起点
106
- x2, y2 = row_end_xy # 横向直线终点
107
- x3, y3 = col_start_xy # 纵向直线起点
108
- x4, y4 = col_end_xy # 纵向直线终点
109
-
110
- # 计算交点坐标
111
- # 使用克莱姆法则求解
112
- denominator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
113
-
114
- # 计算交点的 x 坐标
115
- x = ((x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)) / denominator
116
- # 计算交点的 y 坐标
117
- y = ((x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)) / denominator
118
-
119
- cells_xywh[i, j] = [int(x), int(y), 0, 0]
120
-
121
- # 计算每个点位的 wh
122
- for i in range(10):
123
- for j in range(9):
124
- cur_xy = cells_xywh[i, j]
125
- # 获取上下左右 4 个点, 根据 4 个点计算 wh, 宽高为 4 个点 计算出来的 x1y1x2y2 的距离 的 1/2
126
- if i == 0:
127
- # [i+1, j] 的 反向点
128
- up_xy = 2 * cur_xy - cells_xywh[i+1, j]
129
- else:
130
- up_xy = cells_xywh[i - 1, j]
131
-
132
- if i == 9:
133
- # [i-1, j] 的 反向点
134
- down_xy = 2 * cur_xy - cells_xywh[i-1, j]
135
- else:
136
- down_xy = cells_xywh[i+1, j]
137
-
138
- if j == 0:
139
- left_xy = 2 * cur_xy - cells_xywh[i, j+1]
140
- else:
141
- left_xy = cells_xywh[i, j-1]
142
-
143
- if j == 8:
144
- right_xy = 2 * cur_xy - cells_xywh[i, j-1]
145
- else:
146
- right_xy = cells_xywh[i, j+1]
147
-
148
- min_x = min(up_xy[0].tolist(), down_xy[0].tolist(), left_xy[0].tolist(), right_xy[0].tolist())
149
- min_y = min(up_xy[1].tolist(), down_xy[1].tolist(), left_xy[1].tolist(), right_xy[1].tolist())
150
-
151
- min_x += padding
152
- min_y += padding
153
-
154
- # 防止 min_x 和 min_y 为 0
155
- min_x = max(min_x, 1)
156
- min_y = max(min_y, 1)
157
-
158
- max_x = max(up_xy[0].tolist(), down_xy[0].tolist(), left_xy[0].tolist(), right_xy[0].tolist())
159
- max_y = max(up_xy[1].tolist(), down_xy[1].tolist(), left_xy[1].tolist(), right_xy[1].tolist())
160
-
161
- max_x -= padding
162
- max_y -= padding
163
-
164
- # 防止 max_x 和 max_y 超出边界
165
- max_x = min(max_x, width - 1)
166
- max_y = min(max_y, height - 1)
167
-
168
- w = (max_x - min_x) / 2
169
- h = (max_y - min_y) / 2
170
-
171
- cells_xywh[i, j] = [int(cur_xy[0]), int(cur_xy[1]), int(w), int(h)]
172
-
173
- return cells_xywh
174
-
175
-
176
- def perspective_transform(
177
- image: cv2.UMat,
178
- src_points: np.ndarray,
179
- keypoints: np.ndarray,
180
- dst_size=(450, 500)) -> Tuple[cv2.UMat, np.ndarray, np.ndarray]:
181
- """
182
- 透视变换
183
- @param image: 图片
184
- @param src_points: 源点坐标
185
- @param keypoints: 关键点坐标
186
- @param dst_size: 目标尺寸 (width, height) 10 行 9 列
187
-
188
- @return:
189
- result: 透视变换后的图片
190
- transformed_keypoints: 透视变换后的关键点坐标
191
- corner_points: 棋盘的 corner 点坐标, shape 为 (4, 2) A0, A8, J0, J8
192
- """
193
-
194
- check_keypoints(keypoints)
195
-
196
-
197
- # 源点和目标点
198
- src = np.float32(src_points)
199
- padding = 50
200
- corner_points = np.float32([
201
- # 左上角
202
- [padding, padding],
203
- # 右上角
204
- [dst_size[0]-padding, padding],
205
- # 左下角
206
- [padding, dst_size[1]-padding],
207
- # 右下角
208
- [dst_size[0]-padding, dst_size[1]-padding]])
209
-
210
- # 计算透视变换矩阵
211
- matrix = cv2.getPerspectiveTransform(src, corner_points)
212
-
213
- # 执行透视变换
214
- result = cv2.warpPerspective(image, matrix, dst_size)
215
-
216
- # 重塑数组为要求的格式 (N,1,2)
217
- keypoints_reshaped = keypoints.reshape(-1, 1, 2).astype(np.float32)
218
- transformed_keypoints = cv2.perspectiveTransform(keypoints_reshaped, matrix)
219
- # 转回原来的形状
220
- transformed_keypoints = transformed_keypoints.reshape(-1, 2)
221
-
222
- return result, transformed_keypoints, corner_points
223
-
224
-
225
-
226
- def get_board_corner_points(keypoints: np.ndarray) -> np.ndarray:
227
- """
228
- 计算棋局四个边角的 points
229
- @param keypoints: 关键点坐标, shape 为 (34, 2)
230
- @return: 边角的坐标, shape 为 (4, 2)
231
- """
232
- check_keypoints(keypoints)
233
-
234
- # 找到 A0 A8 J0 J8 的坐标 以及 A4 和 J4 的坐标
235
- a0_index = BONE_NAMES.index("A0")
236
- a8_index = BONE_NAMES.index("A8")
237
- j0_index = BONE_NAMES.index("J0")
238
- j8_index = BONE_NAMES.index("J8")
239
-
240
- a0_xy = keypoints[a0_index]
241
- a8_xy = keypoints[a8_index]
242
- j0_xy = keypoints[j0_index]
243
- j8_xy = keypoints[j8_index]
244
-
245
- # 计算新的四个角点坐标
246
- dst_points = np.array([
247
- a0_xy,
248
- a8_xy,
249
- j0_xy,
250
- j8_xy
251
- ], dtype=np.float32)
252
-
253
- return dst_points
254
-
255
- def extract_chessboard(img: cv2.UMat, keypoints: np.ndarray) -> Tuple[cv2.UMat, np.ndarray, np.ndarray]:
256
- """
257
- 提取棋盘信息
258
- @param img: 图片
259
- @param keypoints: 关键点坐标, shape 为 (34, 2)
260
- @return:
261
- transformed_image: 透视变换后的图片
262
- transformed_keypoints: 透视变换后的关键点坐标
263
- transformed_corner_points: 棋盘的 corner 点坐标, shape 为 (4, 2) A0, A8, J0, J8
264
- """
265
-
266
- check_keypoints(keypoints)
267
-
268
- source_corner_points = get_board_corner_points(keypoints)
269
-
270
- transformed_image, transformed_keypoints, transformed_corner_points = perspective_transform(img, source_corner_points, keypoints)
271
-
272
- return transformed_image, transformed_keypoints, transformed_corner_points
273
-
274
-
275
- def collect_cells_images(image: cv2.UMat, cells_xywh: np.ndarray) -> List[List[np.ndarray]]:
276
- """
277
- 收集 棋盘的 cells_xywh 对应的图片集合
278
- """
279
- width = image.shape[1]
280
- height = image.shape[0]
281
- crop_cells: List[List[np.ndarray]] = []
282
-
283
- for i in range(10):
284
- row_cells = []
285
- for j in range(9):
286
- x, y, w, h = cells_xywh[i, j]
287
-
288
- x_0 = max(int(x-w/2), 0)
289
- y_0 = max(int(y-h/2), 0)
290
- x_1 = min(int(x+w/2), width-1)
291
- y_1 = min(int(y+h/2), height-1)
292
-
293
- crop_img = image[y_0:y_1, x_0:x_1]
294
- row_cells.append(crop_img)
295
- crop_cells.append(row_cells)
296
-
297
- return crop_cells
298
-
299
- def draw_cells_box(image: cv2.UMat, cells_xywh: np.ndarray) -> cv2.UMat:
300
- """
301
- 绘制 棋盘的 cells_xywh 对应的 矩形框
302
- """
303
- width = image.shape[1]
304
- height = image.shape[0]
305
- for i in range(10):
306
- for j in range(9):
307
- x, y, w, h = cells_xywh[i, j]
308
-
309
- x_0 = max(int(x-w/2), 0)
310
- y_0 = max(int(y-h/2), 0)
311
- x_1 = min(int(x+w/2), width-1)
312
- y_1 = min(int(y+h/2), height-1)
313
-
314
- cv2.rectangle(image,(x_0, y_0), (x_1, y_1), (0, 0, 255), 1)
315
-
316
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/helper_4_kpt.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from typing import Tuple, List
4
+
5
+
6
+ BONE_NAMES = [
7
+ "A0", "A8",
8
+ "J0", "J8",
9
+ ]
10
+
11
+ def check_keypoints(keypoints: np.ndarray):
12
+ """
13
+ 检查关键点坐标是否正确
14
+ @param keypoints: 关键点坐标, shape 为 (N, 2)
15
+ """
16
+ if keypoints.shape != (len(BONE_NAMES), 2):
17
+ raise Exception(f"keypoints shape error: {keypoints.shape}")
18
+ def perspective_transform(
19
+ image: cv2.UMat,
20
+ src_points: np.ndarray,
21
+ keypoints: np.ndarray,
22
+ dst_size=(450, 500)) -> Tuple[cv2.UMat, np.ndarray, np.ndarray]:
23
+ """
24
+ 透视变换
25
+ @param image: 图片
26
+ @param src_points: 源点坐标
27
+ @param keypoints: 关键点坐标
28
+ @param dst_size: 目标尺寸 (width, height) 10 行 9 列
29
+
30
+ @return:
31
+ result: 透视变换后的图片
32
+ transformed_keypoints: 透视变换后的关键点坐标
33
+ corner_points: 棋盘的 corner 点坐标, shape 为 (4, 2) A0, A8, J0, J8
34
+ """
35
+
36
+ check_keypoints(keypoints)
37
+
38
+
39
+ # 源点和目标点
40
+ src = np.float32(src_points)
41
+ padding = 50
42
+ corner_points = np.float32([
43
+ # 左上角
44
+ [padding, padding],
45
+ # 右上角
46
+ [dst_size[0]-padding, padding],
47
+ # 左下角
48
+ [padding, dst_size[1]-padding],
49
+ # 右下角
50
+ [dst_size[0]-padding, dst_size[1]-padding]])
51
+
52
+ # 计算透视变换矩阵
53
+ matrix = cv2.getPerspectiveTransform(src, corner_points)
54
+
55
+ # 执行透视变换
56
+ result = cv2.warpPerspective(image, matrix, dst_size)
57
+
58
+ # 重塑数组为要求的格式 (N,1,2)
59
+ keypoints_reshaped = keypoints.reshape(-1, 1, 2).astype(np.float32)
60
+ transformed_keypoints = cv2.perspectiveTransform(keypoints_reshaped, matrix)
61
+ # 转回原来的形状
62
+ transformed_keypoints = transformed_keypoints.reshape(-1, 2)
63
+
64
+ return result, transformed_keypoints, corner_points
65
+
66
+
67
+
68
+ def get_board_corner_points(keypoints: np.ndarray) -> np.ndarray:
69
+ """
70
+ 计算棋局四个边角的 points
71
+ @param keypoints: 关键点坐标, shape 为 (N, 2)
72
+ @return: 边角的坐标, shape 为 (4, 2)
73
+ """
74
+ check_keypoints(keypoints)
75
+
76
+ # 找到 A0 A8 J0 J8 的坐标 以及 A4 和 J4 的坐标
77
+ a0_index = BONE_NAMES.index("A0")
78
+ a8_index = BONE_NAMES.index("A8")
79
+ j0_index = BONE_NAMES.index("J0")
80
+ j8_index = BONE_NAMES.index("J8")
81
+
82
+ a0_xy = keypoints[a0_index]
83
+ a8_xy = keypoints[a8_index]
84
+ j0_xy = keypoints[j0_index]
85
+ j8_xy = keypoints[j8_index]
86
+
87
+ # 计算新的四个角点坐标
88
+ dst_points = np.array([
89
+ a0_xy,
90
+ a8_xy,
91
+ j0_xy,
92
+ j8_xy
93
+ ], dtype=np.float32)
94
+
95
+ return dst_points
96
+
97
+ def extract_chessboard(img: cv2.UMat, keypoints: np.ndarray) -> Tuple[cv2.UMat, np.ndarray, np.ndarray]:
98
+ """
99
+ 提取棋盘信息
100
+ @param img: 图片
101
+ @param keypoints: 关键点坐标, shape 为 (N, 2)
102
+ @return:
103
+ transformed_image: 透视变换后的图片
104
+ transformed_keypoints: 透视变换后的关键点坐标
105
+ transformed_corner_points: 棋盘的 corner 点坐标, shape 为 (4, 2) A0, A8, J0, J8
106
+ """
107
+
108
+ check_keypoints(keypoints)
109
+
110
+ source_corner_points = get_board_corner_points(keypoints)
111
+
112
+ transformed_image, transformed_keypoints, transformed_corner_points = perspective_transform(img, source_corner_points, keypoints)
113
+
114
+ return transformed_image, transformed_keypoints, transformed_corner_points
core/{kpt_34_with_xanything.py → kpt_4_with_xanything.py} RENAMED
@@ -3,7 +3,7 @@ import os
3
  import json
4
  import numpy as np
5
 
6
- from .helper_34 import BONE_NAMES
7
 
8
  class Shape:
9
 
@@ -45,7 +45,7 @@ class KeyPoint(Shape):
45
  super().__init__(label, [point_xy], group_id, "point")
46
 
47
  class Rectangle(Shape):
48
- def __init__(self, label="A1", xyxy=list[float, float, float, float], group_id=1):
49
 
50
  if len(xyxy) != 4:
51
  raise ValueError("xyxy 必须是一个包含 4 个元素的列表")
@@ -114,9 +114,9 @@ class Annotation:
114
  }
115
 
116
 
117
- def save_kpt_34_with_xanything(image_input: np.ndarray, image_ann_path, bbox: list[float, float, float, float], kpt_34: list[tuple[str, float, float]], save_dir: str):
118
  """
119
- 保存 34 个关键点 和 一个 bbox 到 xanything 的 json 文件
120
  """
121
  x1, y1, x2, y2 = bbox
122
  x1, y1, x2, y2 = float(x1), float(y1), float(x2), float(y2)
@@ -137,12 +137,12 @@ def save_kpt_34_with_xanything(image_input: np.ndarray, image_ann_path, bbox: li
137
 
138
  annotation = Annotation(file_name, image_width, image_height)
139
 
140
- kpt_34_dict = {}
141
- for bone_name, x, y in kpt_34:
142
- kpt_34_dict[bone_name] = [float(x), float(y)]
143
 
144
  for bone_name in BONE_NAMES:
145
- x, y = kpt_34_dict[bone_name]
146
  annotation.add_shape(KeyPoint(bone_name, [x, y]))
147
 
148
  # 添加 bbox
@@ -171,22 +171,22 @@ def read_xanything_to_json(json_path) -> tuple[list[tuple[str, float, float]], l
171
  # data
172
  annotation = Annotation.init_from_dict(data)
173
 
174
- keypoints_34_dict: dict[str, list[float, float]] = {}
175
  # x1, y1, x2, y2
176
  bbox: list[float, float, float, float] = []
177
 
178
  for shape in annotation.shapes:
179
  if shape["shape_type"] == "point":
180
- keypoints_34_dict[shape["label"]] = [shape["points"][0][0], shape["points"][0][1]]
181
  elif shape["shape_type"] == "rectangle":
182
  bbox = [shape["points"][0][0], shape["points"][0][1], shape["points"][2][0], shape["points"][2][1]]
183
 
184
- keypoints_34: list[tuple[str, float, float]] = []
185
 
186
  for item in BONE_NAMES:
187
- keypoints_34.append((item, keypoints_34_dict[item][0], keypoints_34_dict[item][1]))
188
 
189
- return keypoints_34, bbox
190
 
191
 
192
 
 
3
  import json
4
  import numpy as np
5
 
6
+ from .helper_4_kpt import BONE_NAMES
7
 
8
  class Shape:
9
 
 
45
  super().__init__(label, [point_xy], group_id, "point")
46
 
47
  class Rectangle(Shape):
48
+ def __init__(self, label="A0", xyxy=list[float, float, float, float], group_id=1):
49
 
50
  if len(xyxy) != 4:
51
  raise ValueError("xyxy 必须是一个包含 4 个元素的列表")
 
114
  }
115
 
116
 
117
+ def save_kpt_4_with_xanything(image_input: np.ndarray, image_ann_path, bbox: list[float, float, float, float], kpt_4: list[tuple[str, float, float]], save_dir: str):
118
  """
119
+ 保存 4 个关键点 和 一个 bbox 到 xanything 的 json 文件
120
  """
121
  x1, y1, x2, y2 = bbox
122
  x1, y1, x2, y2 = float(x1), float(y1), float(x2), float(y2)
 
137
 
138
  annotation = Annotation(file_name, image_width, image_height)
139
 
140
+ kpt_4_dict = {}
141
+ for bone_name, x, y in kpt_4:
142
+ kpt_4_dict[bone_name] = [float(x), float(y)]
143
 
144
  for bone_name in BONE_NAMES:
145
+ x, y = kpt_4_dict[bone_name]
146
  annotation.add_shape(KeyPoint(bone_name, [x, y]))
147
 
148
  # 添加 bbox
 
171
  # data
172
  annotation = Annotation.init_from_dict(data)
173
 
174
+ keypoints_4_dict: dict[str, list[float, float]] = {}
175
  # x1, y1, x2, y2
176
  bbox: list[float, float, float, float] = []
177
 
178
  for shape in annotation.shapes:
179
  if shape["shape_type"] == "point":
180
+ keypoints_4_dict[shape["label"]] = [shape["points"][0][0], shape["points"][0][1]]
181
  elif shape["shape_type"] == "rectangle":
182
  bbox = [shape["points"][0][0], shape["points"][0][1], shape["points"][2][0], shape["points"][2][1]]
183
 
184
+ keypoints_4: list[tuple[str, float, float]] = []
185
 
186
  for item in BONE_NAMES:
187
+ keypoints_4.append((item, keypoints_4_dict[item][0], keypoints_4_dict[item][1]))
188
 
189
+ return keypoints_4, bbox
190
 
191
 
192
 
core/runonnx/rtmpose.py CHANGED
@@ -6,16 +6,35 @@ from .base_onnx import BaseONNX
6
  class RTMPOSE_ONNX(BaseONNX):
7
 
8
  bone_names = [
9
- "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8",
10
- "J0", "J1", "J2", "J3", "J4", "J5", "J6", "J7", "J8",
11
- "B0", "C0", "D0", "E0", "F0", "G0", "H0", "I0",
12
- "B8", "C8", "D8", "E8", "F8", "G8", "H8", "I8",
13
  ]
14
 
15
- def __init__(self, model_path, input_size=(256, 256), padding=1.25):
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  super().__init__(model_path, input_size)
17
  self.padding = padding
18
 
 
 
 
 
 
 
 
 
19
 
20
  def get_bbox_center_scale(self, bbox: List[int]):
21
  """Convert bounding box to center and scale.
@@ -202,8 +221,8 @@ class RTMPOSE_ONNX(BaseONNX):
202
  def get_simcc_maximum(self, simcc_x, simcc_y):
203
 
204
  # 在最后一维上找到最大值的索引
205
- x_indices = np.argmax(simcc_x[0], axis=1) # (34,)
206
- y_indices = np.argmax(simcc_y[0], axis=1) # (34,)
207
 
208
 
209
  input_w, input_h = self.input_size
@@ -213,7 +232,7 @@ class RTMPOSE_ONNX(BaseONNX):
213
  y_coords = y_indices / (input_h * 2)
214
 
215
  # 组合成坐标对
216
- keypoints = np.stack([x_coords, y_coords], axis=1) # (34, 2)
217
 
218
  # 获取每个点的置信度分数
219
  scores = np.max(simcc_x[0], axis=1) * np.max(simcc_y[0], axis=1)
@@ -339,8 +358,7 @@ class RTMPOSE_ONNX(BaseONNX):
339
  if not is_rgb:
340
  img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
341
 
342
- # 获取 随机的 34 中颜色
343
- colors = np.random.randint(0, 256, (34, 3))
344
 
345
  for i, (point, score) in enumerate(zip(keypoints, scores)):
346
  if score > 0.3: # 设置置信度阈值
@@ -352,5 +370,25 @@ class RTMPOSE_ONNX(BaseONNX):
352
  # 添加关键点索引标注
353
  cv2.putText(img, self.bone_names[i], (x+5, y+5),
354
  cv2.FONT_HERSHEY_SIMPLEX, 1.0, (int(color[0]), int(color[1]), int(color[2])), 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
  return img
356
 
 
6
  class RTMPOSE_ONNX(BaseONNX):
7
 
8
  bone_names = [
9
+ "A0", "A8",
10
+ "J0", "J8",
 
 
11
  ]
12
 
13
+
14
+ skeleton_links = [
15
+ "A0-A8",
16
+ "A8-J8",
17
+ "J8-J0",
18
+ "J0-A0",
19
+ ]
20
+
21
+ def __init__(self,
22
+ model_path, input_size=(256, 256),
23
+ padding=1.25,
24
+ bone_names=None,
25
+ skeleton_links=None,
26
+ ):
27
  super().__init__(model_path, input_size)
28
  self.padding = padding
29
 
30
+ if bone_names is not None:
31
+ self.bone_names = bone_names
32
+
33
+ if skeleton_links is not None:
34
+ self.skeleton_links = skeleton_links
35
+
36
+ self.bone_colors = np.random.randint(0, 256, (len(self.bone_names), 3))
37
+
38
 
39
  def get_bbox_center_scale(self, bbox: List[int]):
40
  """Convert bounding box to center and scale.
 
221
  def get_simcc_maximum(self, simcc_x, simcc_y):
222
 
223
  # 在最后一维上找到最大值的索引
224
+ x_indices = np.argmax(simcc_x[0], axis=1) # (N,)
225
+ y_indices = np.argmax(simcc_y[0], axis=1) # (N,)
226
 
227
 
228
  input_w, input_h = self.input_size
 
232
  y_coords = y_indices / (input_h * 2)
233
 
234
  # 组合成坐标对
235
+ keypoints = np.stack([x_coords, y_coords], axis=1) # (N, 2)
236
 
237
  # 获取每个点的置信度分数
238
  scores = np.max(simcc_x[0], axis=1) * np.max(simcc_y[0], axis=1)
 
358
  if not is_rgb:
359
  img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
360
 
361
+ colors = self.bone_colors
 
362
 
363
  for i, (point, score) in enumerate(zip(keypoints, scores)):
364
  if score > 0.3: # 设置置信度阈值
 
370
  # 添加关键点索引标注
371
  cv2.putText(img, self.bone_names[i], (x+5, y+5),
372
  cv2.FONT_HERSHEY_SIMPLEX, 1.0, (int(color[0]), int(color[1]), int(color[2])), 1)
373
+
374
+ # 绘制 关节连接线
375
+ for link in self.skeleton_links:
376
+ start_bone, end_bone = link.split("-")
377
+
378
+ start_index = self.bone_names.index(start_bone)
379
+ end_index = self.bone_names.index(end_bone)
380
+
381
+ start_keypoint = keypoints[start_index]
382
+ end_keypoint = keypoints[end_index]
383
+ link_color = colors[start_index]
384
+
385
+ # 绘制连线
386
+ if scores[start_index] > 0.3 and scores[end_index] > 0.3:
387
+ start_point = tuple(map(int, start_keypoint))
388
+ end_point = tuple(map(int, end_keypoint))
389
+ cv2.line(img, start_point, end_point,
390
+ (int(link_color[0]), int(link_color[1]), int(link_color[2])),
391
+ thickness=2)
392
+
393
  return img
394