feat: kpt 4
Browse files- HISTORY.md +4 -0
- app.py +1 -1
- core/chessboard_detector.py +1 -1
- core/helper_34.py +0 -316
- core/helper_4_kpt.py +114 -0
- core/{kpt_34_with_xanything.py → kpt_4_with_xanything.py} +13 -13
- core/runonnx/rtmpose.py +48 -10
HISTORY.md
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 2024-12-28
|
2 |
+
|
3 |
+
1. 使用 4 个关键点检测
|
4 |
+
|
app.py
CHANGED
@@ -5,7 +5,7 @@ from core.chessboard_detector import ChessboardDetector
|
|
5 |
|
6 |
detector = ChessboardDetector(
|
7 |
det_model_path="onnx/det/v1.onnx",
|
8 |
-
pose_model_path="onnx/pose/
|
9 |
full_classifier_model_path="onnx/layout_recognition/v1.onnx"
|
10 |
)
|
11 |
|
|
|
5 |
|
6 |
detector = ChessboardDetector(
|
7 |
det_model_path="onnx/det/v1.onnx",
|
8 |
+
pose_model_path="onnx/pose/4_v2.onnx",
|
9 |
full_classifier_model_path="onnx/layout_recognition/v1.onnx"
|
10 |
)
|
11 |
|
core/chessboard_detector.py
CHANGED
@@ -8,7 +8,7 @@ from .runonnx.rtmdet import RTMDET_ONNX
|
|
8 |
from .runonnx.rtmpose import RTMPOSE_ONNX
|
9 |
from .runonnx.full_classifier import FULL_CLASSIFIER_ONNX
|
10 |
|
11 |
-
from core.
|
12 |
|
13 |
class ChessboardDetector:
|
14 |
def __init__(self,
|
|
|
8 |
from .runonnx.rtmpose import RTMPOSE_ONNX
|
9 |
from .runonnx.full_classifier import FULL_CLASSIFIER_ONNX
|
10 |
|
11 |
+
from core.helper_4_kpt import extract_chessboard
|
12 |
|
13 |
class ChessboardDetector:
|
14 |
def __init__(self,
|
core/helper_34.py
DELETED
@@ -1,316 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import cv2
|
3 |
-
from typing import Tuple, List
|
4 |
-
|
5 |
-
|
6 |
-
BONE_NAMES = [
|
7 |
-
"A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8",
|
8 |
-
"J0", "J1", "J2", "J3", "J4", "J5", "J6", "J7", "J8",
|
9 |
-
"B0", "C0", "D0", "E0", "F0", "G0", "H0", "I0",
|
10 |
-
"B8", "C8", "D8", "E8", "F8", "G8", "H8", "I8",
|
11 |
-
]
|
12 |
-
|
13 |
-
def check_keypoints(keypoints: np.ndarray):
|
14 |
-
"""
|
15 |
-
检查关键点坐标是否正确
|
16 |
-
@param keypoints: 关键点坐标, shape 为 (34, 2)
|
17 |
-
"""
|
18 |
-
if keypoints.shape != (34, 2):
|
19 |
-
raise Exception(f"keypoints shape error: {keypoints.shape}")
|
20 |
-
|
21 |
-
|
22 |
-
def build_cells_xywh_by_cronners(corner_points: np.ndarray, padding: int = 3) -> np.ndarray:
|
23 |
-
"""
|
24 |
-
根据 棋盘的 corner 点坐标 计算 每个位置的 xywh
|
25 |
-
@param corner_points: 棋盘的 corner 点坐标, shape 为 (4, 2)
|
26 |
-
@param padding: 棋盘边框 padding
|
27 |
-
|
28 |
-
@return: 棋盘的 xywh, shape 为 (10, 9, 4), 4 为 center_x, center_y, w, h
|
29 |
-
"""
|
30 |
-
|
31 |
-
if corner_points.shape != (4, 2):
|
32 |
-
raise Exception(f"corner_points shape error: {corner_points.shape}")
|
33 |
-
|
34 |
-
top_left_xy = corner_points[0]
|
35 |
-
top_right_xy = corner_points[1]
|
36 |
-
bottom_left_xy = corner_points[2]
|
37 |
-
bottom_right_xy = corner_points[3]
|
38 |
-
|
39 |
-
# 计算 每个框的 w 和 h
|
40 |
-
item_w = (top_right_xy[0] - top_left_xy[0]) / (9 - 1)
|
41 |
-
item_h = (bottom_left_xy[1] - top_left_xy[1]) / (10 - 1)
|
42 |
-
|
43 |
-
item_w = item_w
|
44 |
-
item_h = item_h
|
45 |
-
|
46 |
-
item_w_with_padding = item_w - padding * 2
|
47 |
-
item_h_with_padding = item_h - padding * 2
|
48 |
-
|
49 |
-
# 计算 每个框的 center 坐标
|
50 |
-
cells_xywh = np.zeros((10, 9, 4))
|
51 |
-
|
52 |
-
for i in range(10):
|
53 |
-
for j in range(9):
|
54 |
-
center_x = top_left_xy[0] + item_w * j
|
55 |
-
center_y = top_left_xy[1] + item_h * i
|
56 |
-
|
57 |
-
cells_xywh[i, j] = [center_x, center_y, item_w_with_padding, item_h_with_padding]
|
58 |
-
|
59 |
-
return cells_xywh
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
# todo: 需要优化
|
64 |
-
def build_cells_xywh(keypoints: np.ndarray, width: int = 450, height: int = 500, padding: int = 3) -> np.ndarray:
|
65 |
-
"""
|
66 |
-
@param keypoints: 关键点坐标, shape 为 (34, 2)
|
67 |
-
@param width: 棋盘宽度
|
68 |
-
@param height: 棋盘高度
|
69 |
-
@param padding: 棋盘边框 padding
|
70 |
-
@return: 棋盘的 xywh, shape 为 (10, 9, 4), 4 为 center_x, center_y, w, h
|
71 |
-
"""
|
72 |
-
check_keypoints(keypoints)
|
73 |
-
|
74 |
-
|
75 |
-
# 生成 A0 到 J8 的坐标, 如 B1 坐标 为 A1-J1 与 B0-B8 的交集点
|
76 |
-
cells_xywh = np.zeros((10, 9, 4), dtype=np.int16)
|
77 |
-
|
78 |
-
# 遍历 full_points 的每个点,计算其坐标
|
79 |
-
for i in range(10):
|
80 |
-
for j in range(9):
|
81 |
-
# 计算 第 i 行 第 j 列 的坐标
|
82 |
-
row_name = chr(ord('A') + i)
|
83 |
-
col_name = str(j)
|
84 |
-
flag_name = f"{row_name}{col_name}"
|
85 |
-
if flag_name in BONE_NAMES:
|
86 |
-
# 计算 第 i 行 第 j 列 的坐标
|
87 |
-
cur_xy = keypoints[BONE_NAMES.index(flag_name)]
|
88 |
-
cells_xywh[i, j] = [cur_xy[0], cur_xy[1], 0, 0]
|
89 |
-
else:
|
90 |
-
# 计算 第 i 行 第 j 列 的坐标
|
91 |
-
row_start_name = f"{row_name}0"
|
92 |
-
row_end_name = f"{row_name}8"
|
93 |
-
|
94 |
-
col_start_name = f"A{col_name}"
|
95 |
-
col_end_name = f"J{col_name}"
|
96 |
-
|
97 |
-
row_start_xy = keypoints[BONE_NAMES.index(row_start_name)]
|
98 |
-
row_end_xy = keypoints[BONE_NAMES.index(row_end_name)]
|
99 |
-
|
100 |
-
col_start_xy = keypoints[BONE_NAMES.index(col_start_name)]
|
101 |
-
col_end_xy = keypoints[BONE_NAMES.index(col_end_name)]
|
102 |
-
|
103 |
-
# 计算 row_start_xy 到 row_end_xy 的直线 与 col_start_xy 到 col_end_xy 的直线 的交点
|
104 |
-
# 使用参数方程法计算交点
|
105 |
-
x1, y1 = row_start_xy # 横向直线起点
|
106 |
-
x2, y2 = row_end_xy # 横向直线终点
|
107 |
-
x3, y3 = col_start_xy # 纵向直线起点
|
108 |
-
x4, y4 = col_end_xy # 纵向直线终点
|
109 |
-
|
110 |
-
# 计算交点坐标
|
111 |
-
# 使用克莱姆法则求解
|
112 |
-
denominator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
|
113 |
-
|
114 |
-
# 计算交点的 x 坐标
|
115 |
-
x = ((x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)) / denominator
|
116 |
-
# 计算交点的 y 坐标
|
117 |
-
y = ((x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)) / denominator
|
118 |
-
|
119 |
-
cells_xywh[i, j] = [int(x), int(y), 0, 0]
|
120 |
-
|
121 |
-
# 计算每个点位的 wh
|
122 |
-
for i in range(10):
|
123 |
-
for j in range(9):
|
124 |
-
cur_xy = cells_xywh[i, j]
|
125 |
-
# 获取上下左右 4 个点, 根据 4 个点计算 wh, 宽高为 4 个点 计算出来的 x1y1x2y2 的距离 的 1/2
|
126 |
-
if i == 0:
|
127 |
-
# [i+1, j] 的 反向点
|
128 |
-
up_xy = 2 * cur_xy - cells_xywh[i+1, j]
|
129 |
-
else:
|
130 |
-
up_xy = cells_xywh[i - 1, j]
|
131 |
-
|
132 |
-
if i == 9:
|
133 |
-
# [i-1, j] 的 反向点
|
134 |
-
down_xy = 2 * cur_xy - cells_xywh[i-1, j]
|
135 |
-
else:
|
136 |
-
down_xy = cells_xywh[i+1, j]
|
137 |
-
|
138 |
-
if j == 0:
|
139 |
-
left_xy = 2 * cur_xy - cells_xywh[i, j+1]
|
140 |
-
else:
|
141 |
-
left_xy = cells_xywh[i, j-1]
|
142 |
-
|
143 |
-
if j == 8:
|
144 |
-
right_xy = 2 * cur_xy - cells_xywh[i, j-1]
|
145 |
-
else:
|
146 |
-
right_xy = cells_xywh[i, j+1]
|
147 |
-
|
148 |
-
min_x = min(up_xy[0].tolist(), down_xy[0].tolist(), left_xy[0].tolist(), right_xy[0].tolist())
|
149 |
-
min_y = min(up_xy[1].tolist(), down_xy[1].tolist(), left_xy[1].tolist(), right_xy[1].tolist())
|
150 |
-
|
151 |
-
min_x += padding
|
152 |
-
min_y += padding
|
153 |
-
|
154 |
-
# 防止 min_x 和 min_y 为 0
|
155 |
-
min_x = max(min_x, 1)
|
156 |
-
min_y = max(min_y, 1)
|
157 |
-
|
158 |
-
max_x = max(up_xy[0].tolist(), down_xy[0].tolist(), left_xy[0].tolist(), right_xy[0].tolist())
|
159 |
-
max_y = max(up_xy[1].tolist(), down_xy[1].tolist(), left_xy[1].tolist(), right_xy[1].tolist())
|
160 |
-
|
161 |
-
max_x -= padding
|
162 |
-
max_y -= padding
|
163 |
-
|
164 |
-
# 防止 max_x 和 max_y 超出边界
|
165 |
-
max_x = min(max_x, width - 1)
|
166 |
-
max_y = min(max_y, height - 1)
|
167 |
-
|
168 |
-
w = (max_x - min_x) / 2
|
169 |
-
h = (max_y - min_y) / 2
|
170 |
-
|
171 |
-
cells_xywh[i, j] = [int(cur_xy[0]), int(cur_xy[1]), int(w), int(h)]
|
172 |
-
|
173 |
-
return cells_xywh
|
174 |
-
|
175 |
-
|
176 |
-
def perspective_transform(
|
177 |
-
image: cv2.UMat,
|
178 |
-
src_points: np.ndarray,
|
179 |
-
keypoints: np.ndarray,
|
180 |
-
dst_size=(450, 500)) -> Tuple[cv2.UMat, np.ndarray, np.ndarray]:
|
181 |
-
"""
|
182 |
-
透视变换
|
183 |
-
@param image: 图片
|
184 |
-
@param src_points: 源点坐标
|
185 |
-
@param keypoints: 关键点坐标
|
186 |
-
@param dst_size: 目标尺寸 (width, height) 10 行 9 列
|
187 |
-
|
188 |
-
@return:
|
189 |
-
result: 透视变换后的图片
|
190 |
-
transformed_keypoints: 透视变换后的关键点坐标
|
191 |
-
corner_points: 棋盘的 corner 点坐标, shape 为 (4, 2) A0, A8, J0, J8
|
192 |
-
"""
|
193 |
-
|
194 |
-
check_keypoints(keypoints)
|
195 |
-
|
196 |
-
|
197 |
-
# 源点和目标点
|
198 |
-
src = np.float32(src_points)
|
199 |
-
padding = 50
|
200 |
-
corner_points = np.float32([
|
201 |
-
# 左上角
|
202 |
-
[padding, padding],
|
203 |
-
# 右上角
|
204 |
-
[dst_size[0]-padding, padding],
|
205 |
-
# 左下角
|
206 |
-
[padding, dst_size[1]-padding],
|
207 |
-
# 右下角
|
208 |
-
[dst_size[0]-padding, dst_size[1]-padding]])
|
209 |
-
|
210 |
-
# 计算透视变换矩阵
|
211 |
-
matrix = cv2.getPerspectiveTransform(src, corner_points)
|
212 |
-
|
213 |
-
# 执行透视变换
|
214 |
-
result = cv2.warpPerspective(image, matrix, dst_size)
|
215 |
-
|
216 |
-
# 重塑数组为要求的格式 (N,1,2)
|
217 |
-
keypoints_reshaped = keypoints.reshape(-1, 1, 2).astype(np.float32)
|
218 |
-
transformed_keypoints = cv2.perspectiveTransform(keypoints_reshaped, matrix)
|
219 |
-
# 转回原来的形状
|
220 |
-
transformed_keypoints = transformed_keypoints.reshape(-1, 2)
|
221 |
-
|
222 |
-
return result, transformed_keypoints, corner_points
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
def get_board_corner_points(keypoints: np.ndarray) -> np.ndarray:
|
227 |
-
"""
|
228 |
-
计算棋局四个边角的 points
|
229 |
-
@param keypoints: 关键点坐标, shape 为 (34, 2)
|
230 |
-
@return: 边角的坐标, shape 为 (4, 2)
|
231 |
-
"""
|
232 |
-
check_keypoints(keypoints)
|
233 |
-
|
234 |
-
# 找到 A0 A8 J0 J8 的坐标 以及 A4 和 J4 的坐标
|
235 |
-
a0_index = BONE_NAMES.index("A0")
|
236 |
-
a8_index = BONE_NAMES.index("A8")
|
237 |
-
j0_index = BONE_NAMES.index("J0")
|
238 |
-
j8_index = BONE_NAMES.index("J8")
|
239 |
-
|
240 |
-
a0_xy = keypoints[a0_index]
|
241 |
-
a8_xy = keypoints[a8_index]
|
242 |
-
j0_xy = keypoints[j0_index]
|
243 |
-
j8_xy = keypoints[j8_index]
|
244 |
-
|
245 |
-
# 计算新的四个角点坐标
|
246 |
-
dst_points = np.array([
|
247 |
-
a0_xy,
|
248 |
-
a8_xy,
|
249 |
-
j0_xy,
|
250 |
-
j8_xy
|
251 |
-
], dtype=np.float32)
|
252 |
-
|
253 |
-
return dst_points
|
254 |
-
|
255 |
-
def extract_chessboard(img: cv2.UMat, keypoints: np.ndarray) -> Tuple[cv2.UMat, np.ndarray, np.ndarray]:
|
256 |
-
"""
|
257 |
-
提取棋盘信息
|
258 |
-
@param img: 图片
|
259 |
-
@param keypoints: 关键点坐标, shape 为 (34, 2)
|
260 |
-
@return:
|
261 |
-
transformed_image: 透视变换后的图片
|
262 |
-
transformed_keypoints: 透视变换后的关键点坐标
|
263 |
-
transformed_corner_points: 棋盘的 corner 点坐标, shape 为 (4, 2) A0, A8, J0, J8
|
264 |
-
"""
|
265 |
-
|
266 |
-
check_keypoints(keypoints)
|
267 |
-
|
268 |
-
source_corner_points = get_board_corner_points(keypoints)
|
269 |
-
|
270 |
-
transformed_image, transformed_keypoints, transformed_corner_points = perspective_transform(img, source_corner_points, keypoints)
|
271 |
-
|
272 |
-
return transformed_image, transformed_keypoints, transformed_corner_points
|
273 |
-
|
274 |
-
|
275 |
-
def collect_cells_images(image: cv2.UMat, cells_xywh: np.ndarray) -> List[List[np.ndarray]]:
|
276 |
-
"""
|
277 |
-
收集 棋盘的 cells_xywh 对应的图片集合
|
278 |
-
"""
|
279 |
-
width = image.shape[1]
|
280 |
-
height = image.shape[0]
|
281 |
-
crop_cells: List[List[np.ndarray]] = []
|
282 |
-
|
283 |
-
for i in range(10):
|
284 |
-
row_cells = []
|
285 |
-
for j in range(9):
|
286 |
-
x, y, w, h = cells_xywh[i, j]
|
287 |
-
|
288 |
-
x_0 = max(int(x-w/2), 0)
|
289 |
-
y_0 = max(int(y-h/2), 0)
|
290 |
-
x_1 = min(int(x+w/2), width-1)
|
291 |
-
y_1 = min(int(y+h/2), height-1)
|
292 |
-
|
293 |
-
crop_img = image[y_0:y_1, x_0:x_1]
|
294 |
-
row_cells.append(crop_img)
|
295 |
-
crop_cells.append(row_cells)
|
296 |
-
|
297 |
-
return crop_cells
|
298 |
-
|
299 |
-
def draw_cells_box(image: cv2.UMat, cells_xywh: np.ndarray) -> cv2.UMat:
|
300 |
-
"""
|
301 |
-
绘制 棋盘的 cells_xywh 对应的 矩形框
|
302 |
-
"""
|
303 |
-
width = image.shape[1]
|
304 |
-
height = image.shape[0]
|
305 |
-
for i in range(10):
|
306 |
-
for j in range(9):
|
307 |
-
x, y, w, h = cells_xywh[i, j]
|
308 |
-
|
309 |
-
x_0 = max(int(x-w/2), 0)
|
310 |
-
y_0 = max(int(y-h/2), 0)
|
311 |
-
x_1 = min(int(x+w/2), width-1)
|
312 |
-
y_1 = min(int(y+h/2), height-1)
|
313 |
-
|
314 |
-
cv2.rectangle(image,(x_0, y_0), (x_1, y_1), (0, 0, 255), 1)
|
315 |
-
|
316 |
-
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
core/helper_4_kpt.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
from typing import Tuple, List
|
4 |
+
|
5 |
+
|
6 |
+
BONE_NAMES = [
|
7 |
+
"A0", "A8",
|
8 |
+
"J0", "J8",
|
9 |
+
]
|
10 |
+
|
11 |
+
def check_keypoints(keypoints: np.ndarray):
|
12 |
+
"""
|
13 |
+
检查关键点坐标是否正确
|
14 |
+
@param keypoints: 关键点坐标, shape 为 (N, 2)
|
15 |
+
"""
|
16 |
+
if keypoints.shape != (len(BONE_NAMES), 2):
|
17 |
+
raise Exception(f"keypoints shape error: {keypoints.shape}")
|
18 |
+
def perspective_transform(
|
19 |
+
image: cv2.UMat,
|
20 |
+
src_points: np.ndarray,
|
21 |
+
keypoints: np.ndarray,
|
22 |
+
dst_size=(450, 500)) -> Tuple[cv2.UMat, np.ndarray, np.ndarray]:
|
23 |
+
"""
|
24 |
+
透视变换
|
25 |
+
@param image: 图片
|
26 |
+
@param src_points: 源点坐标
|
27 |
+
@param keypoints: 关键点坐标
|
28 |
+
@param dst_size: 目标尺寸 (width, height) 10 行 9 列
|
29 |
+
|
30 |
+
@return:
|
31 |
+
result: 透视变换后的图片
|
32 |
+
transformed_keypoints: 透视变换后的关键点坐标
|
33 |
+
corner_points: 棋盘的 corner 点坐标, shape 为 (4, 2) A0, A8, J0, J8
|
34 |
+
"""
|
35 |
+
|
36 |
+
check_keypoints(keypoints)
|
37 |
+
|
38 |
+
|
39 |
+
# 源点和目标点
|
40 |
+
src = np.float32(src_points)
|
41 |
+
padding = 50
|
42 |
+
corner_points = np.float32([
|
43 |
+
# 左上角
|
44 |
+
[padding, padding],
|
45 |
+
# 右上角
|
46 |
+
[dst_size[0]-padding, padding],
|
47 |
+
# 左下角
|
48 |
+
[padding, dst_size[1]-padding],
|
49 |
+
# 右下角
|
50 |
+
[dst_size[0]-padding, dst_size[1]-padding]])
|
51 |
+
|
52 |
+
# 计算透视变换矩阵
|
53 |
+
matrix = cv2.getPerspectiveTransform(src, corner_points)
|
54 |
+
|
55 |
+
# 执行透视变换
|
56 |
+
result = cv2.warpPerspective(image, matrix, dst_size)
|
57 |
+
|
58 |
+
# 重塑数组为要求的格式 (N,1,2)
|
59 |
+
keypoints_reshaped = keypoints.reshape(-1, 1, 2).astype(np.float32)
|
60 |
+
transformed_keypoints = cv2.perspectiveTransform(keypoints_reshaped, matrix)
|
61 |
+
# 转回原来的形状
|
62 |
+
transformed_keypoints = transformed_keypoints.reshape(-1, 2)
|
63 |
+
|
64 |
+
return result, transformed_keypoints, corner_points
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
def get_board_corner_points(keypoints: np.ndarray) -> np.ndarray:
|
69 |
+
"""
|
70 |
+
计算棋局四个边角的 points
|
71 |
+
@param keypoints: 关键点坐标, shape 为 (N, 2)
|
72 |
+
@return: 边角的坐标, shape 为 (4, 2)
|
73 |
+
"""
|
74 |
+
check_keypoints(keypoints)
|
75 |
+
|
76 |
+
# 找到 A0 A8 J0 J8 的坐标 以及 A4 和 J4 的坐标
|
77 |
+
a0_index = BONE_NAMES.index("A0")
|
78 |
+
a8_index = BONE_NAMES.index("A8")
|
79 |
+
j0_index = BONE_NAMES.index("J0")
|
80 |
+
j8_index = BONE_NAMES.index("J8")
|
81 |
+
|
82 |
+
a0_xy = keypoints[a0_index]
|
83 |
+
a8_xy = keypoints[a8_index]
|
84 |
+
j0_xy = keypoints[j0_index]
|
85 |
+
j8_xy = keypoints[j8_index]
|
86 |
+
|
87 |
+
# 计算新的四个角点坐标
|
88 |
+
dst_points = np.array([
|
89 |
+
a0_xy,
|
90 |
+
a8_xy,
|
91 |
+
j0_xy,
|
92 |
+
j8_xy
|
93 |
+
], dtype=np.float32)
|
94 |
+
|
95 |
+
return dst_points
|
96 |
+
|
97 |
+
def extract_chessboard(img: cv2.UMat, keypoints: np.ndarray) -> Tuple[cv2.UMat, np.ndarray, np.ndarray]:
|
98 |
+
"""
|
99 |
+
提取棋盘信息
|
100 |
+
@param img: 图片
|
101 |
+
@param keypoints: 关键点坐标, shape 为 (N, 2)
|
102 |
+
@return:
|
103 |
+
transformed_image: 透视变换后的图片
|
104 |
+
transformed_keypoints: 透视变换后的关键点坐标
|
105 |
+
transformed_corner_points: 棋盘的 corner 点坐标, shape 为 (4, 2) A0, A8, J0, J8
|
106 |
+
"""
|
107 |
+
|
108 |
+
check_keypoints(keypoints)
|
109 |
+
|
110 |
+
source_corner_points = get_board_corner_points(keypoints)
|
111 |
+
|
112 |
+
transformed_image, transformed_keypoints, transformed_corner_points = perspective_transform(img, source_corner_points, keypoints)
|
113 |
+
|
114 |
+
return transformed_image, transformed_keypoints, transformed_corner_points
|
core/{kpt_34_with_xanything.py → kpt_4_with_xanything.py}
RENAMED
@@ -3,7 +3,7 @@ import os
|
|
3 |
import json
|
4 |
import numpy as np
|
5 |
|
6 |
-
from .
|
7 |
|
8 |
class Shape:
|
9 |
|
@@ -45,7 +45,7 @@ class KeyPoint(Shape):
|
|
45 |
super().__init__(label, [point_xy], group_id, "point")
|
46 |
|
47 |
class Rectangle(Shape):
|
48 |
-
def __init__(self, label="
|
49 |
|
50 |
if len(xyxy) != 4:
|
51 |
raise ValueError("xyxy 必须是一个包含 4 个元素的列表")
|
@@ -114,9 +114,9 @@ class Annotation:
|
|
114 |
}
|
115 |
|
116 |
|
117 |
-
def
|
118 |
"""
|
119 |
-
保存
|
120 |
"""
|
121 |
x1, y1, x2, y2 = bbox
|
122 |
x1, y1, x2, y2 = float(x1), float(y1), float(x2), float(y2)
|
@@ -137,12 +137,12 @@ def save_kpt_34_with_xanything(image_input: np.ndarray, image_ann_path, bbox: li
|
|
137 |
|
138 |
annotation = Annotation(file_name, image_width, image_height)
|
139 |
|
140 |
-
|
141 |
-
for bone_name, x, y in
|
142 |
-
|
143 |
|
144 |
for bone_name in BONE_NAMES:
|
145 |
-
x, y =
|
146 |
annotation.add_shape(KeyPoint(bone_name, [x, y]))
|
147 |
|
148 |
# 添加 bbox
|
@@ -171,22 +171,22 @@ def read_xanything_to_json(json_path) -> tuple[list[tuple[str, float, float]], l
|
|
171 |
# data
|
172 |
annotation = Annotation.init_from_dict(data)
|
173 |
|
174 |
-
|
175 |
# x1, y1, x2, y2
|
176 |
bbox: list[float, float, float, float] = []
|
177 |
|
178 |
for shape in annotation.shapes:
|
179 |
if shape["shape_type"] == "point":
|
180 |
-
|
181 |
elif shape["shape_type"] == "rectangle":
|
182 |
bbox = [shape["points"][0][0], shape["points"][0][1], shape["points"][2][0], shape["points"][2][1]]
|
183 |
|
184 |
-
|
185 |
|
186 |
for item in BONE_NAMES:
|
187 |
-
|
188 |
|
189 |
-
return
|
190 |
|
191 |
|
192 |
|
|
|
3 |
import json
|
4 |
import numpy as np
|
5 |
|
6 |
+
from .helper_4_kpt import BONE_NAMES
|
7 |
|
8 |
class Shape:
|
9 |
|
|
|
45 |
super().__init__(label, [point_xy], group_id, "point")
|
46 |
|
47 |
class Rectangle(Shape):
|
48 |
+
def __init__(self, label="A0", xyxy=list[float, float, float, float], group_id=1):
|
49 |
|
50 |
if len(xyxy) != 4:
|
51 |
raise ValueError("xyxy 必须是一个包含 4 个元素的列表")
|
|
|
114 |
}
|
115 |
|
116 |
|
117 |
+
def save_kpt_4_with_xanything(image_input: np.ndarray, image_ann_path, bbox: list[float, float, float, float], kpt_4: list[tuple[str, float, float]], save_dir: str):
|
118 |
"""
|
119 |
+
保存 4 个关键点 和 一个 bbox 到 xanything 的 json 文件
|
120 |
"""
|
121 |
x1, y1, x2, y2 = bbox
|
122 |
x1, y1, x2, y2 = float(x1), float(y1), float(x2), float(y2)
|
|
|
137 |
|
138 |
annotation = Annotation(file_name, image_width, image_height)
|
139 |
|
140 |
+
kpt_4_dict = {}
|
141 |
+
for bone_name, x, y in kpt_4:
|
142 |
+
kpt_4_dict[bone_name] = [float(x), float(y)]
|
143 |
|
144 |
for bone_name in BONE_NAMES:
|
145 |
+
x, y = kpt_4_dict[bone_name]
|
146 |
annotation.add_shape(KeyPoint(bone_name, [x, y]))
|
147 |
|
148 |
# 添加 bbox
|
|
|
171 |
# data
|
172 |
annotation = Annotation.init_from_dict(data)
|
173 |
|
174 |
+
keypoints_4_dict: dict[str, list[float, float]] = {}
|
175 |
# x1, y1, x2, y2
|
176 |
bbox: list[float, float, float, float] = []
|
177 |
|
178 |
for shape in annotation.shapes:
|
179 |
if shape["shape_type"] == "point":
|
180 |
+
keypoints_4_dict[shape["label"]] = [shape["points"][0][0], shape["points"][0][1]]
|
181 |
elif shape["shape_type"] == "rectangle":
|
182 |
bbox = [shape["points"][0][0], shape["points"][0][1], shape["points"][2][0], shape["points"][2][1]]
|
183 |
|
184 |
+
keypoints_4: list[tuple[str, float, float]] = []
|
185 |
|
186 |
for item in BONE_NAMES:
|
187 |
+
keypoints_4.append((item, keypoints_4_dict[item][0], keypoints_4_dict[item][1]))
|
188 |
|
189 |
+
return keypoints_4, bbox
|
190 |
|
191 |
|
192 |
|
core/runonnx/rtmpose.py
CHANGED
@@ -6,16 +6,35 @@ from .base_onnx import BaseONNX
|
|
6 |
class RTMPOSE_ONNX(BaseONNX):
|
7 |
|
8 |
bone_names = [
|
9 |
-
"A0", "
|
10 |
-
"J0", "
|
11 |
-
"B0", "C0", "D0", "E0", "F0", "G0", "H0", "I0",
|
12 |
-
"B8", "C8", "D8", "E8", "F8", "G8", "H8", "I8",
|
13 |
]
|
14 |
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
super().__init__(model_path, input_size)
|
17 |
self.padding = padding
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
def get_bbox_center_scale(self, bbox: List[int]):
|
21 |
"""Convert bounding box to center and scale.
|
@@ -202,8 +221,8 @@ class RTMPOSE_ONNX(BaseONNX):
|
|
202 |
def get_simcc_maximum(self, simcc_x, simcc_y):
|
203 |
|
204 |
# 在最后一维上找到最大值的索引
|
205 |
-
x_indices = np.argmax(simcc_x[0], axis=1) # (
|
206 |
-
y_indices = np.argmax(simcc_y[0], axis=1) # (
|
207 |
|
208 |
|
209 |
input_w, input_h = self.input_size
|
@@ -213,7 +232,7 @@ class RTMPOSE_ONNX(BaseONNX):
|
|
213 |
y_coords = y_indices / (input_h * 2)
|
214 |
|
215 |
# 组合成坐标对
|
216 |
-
keypoints = np.stack([x_coords, y_coords], axis=1) # (
|
217 |
|
218 |
# 获取每个点的置信度分数
|
219 |
scores = np.max(simcc_x[0], axis=1) * np.max(simcc_y[0], axis=1)
|
@@ -339,8 +358,7 @@ class RTMPOSE_ONNX(BaseONNX):
|
|
339 |
if not is_rgb:
|
340 |
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
341 |
|
342 |
-
|
343 |
-
colors = np.random.randint(0, 256, (34, 3))
|
344 |
|
345 |
for i, (point, score) in enumerate(zip(keypoints, scores)):
|
346 |
if score > 0.3: # 设置置信度阈值
|
@@ -352,5 +370,25 @@ class RTMPOSE_ONNX(BaseONNX):
|
|
352 |
# 添加关键点索引标注
|
353 |
cv2.putText(img, self.bone_names[i], (x+5, y+5),
|
354 |
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (int(color[0]), int(color[1]), int(color[2])), 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
return img
|
356 |
|
|
|
6 |
class RTMPOSE_ONNX(BaseONNX):
|
7 |
|
8 |
bone_names = [
|
9 |
+
"A0", "A8",
|
10 |
+
"J0", "J8",
|
|
|
|
|
11 |
]
|
12 |
|
13 |
+
|
14 |
+
skeleton_links = [
|
15 |
+
"A0-A8",
|
16 |
+
"A8-J8",
|
17 |
+
"J8-J0",
|
18 |
+
"J0-A0",
|
19 |
+
]
|
20 |
+
|
21 |
+
def __init__(self,
|
22 |
+
model_path, input_size=(256, 256),
|
23 |
+
padding=1.25,
|
24 |
+
bone_names=None,
|
25 |
+
skeleton_links=None,
|
26 |
+
):
|
27 |
super().__init__(model_path, input_size)
|
28 |
self.padding = padding
|
29 |
|
30 |
+
if bone_names is not None:
|
31 |
+
self.bone_names = bone_names
|
32 |
+
|
33 |
+
if skeleton_links is not None:
|
34 |
+
self.skeleton_links = skeleton_links
|
35 |
+
|
36 |
+
self.bone_colors = np.random.randint(0, 256, (len(self.bone_names), 3))
|
37 |
+
|
38 |
|
39 |
def get_bbox_center_scale(self, bbox: List[int]):
|
40 |
"""Convert bounding box to center and scale.
|
|
|
221 |
def get_simcc_maximum(self, simcc_x, simcc_y):
|
222 |
|
223 |
# 在最后一维上找到最大值的索引
|
224 |
+
x_indices = np.argmax(simcc_x[0], axis=1) # (N,)
|
225 |
+
y_indices = np.argmax(simcc_y[0], axis=1) # (N,)
|
226 |
|
227 |
|
228 |
input_w, input_h = self.input_size
|
|
|
232 |
y_coords = y_indices / (input_h * 2)
|
233 |
|
234 |
# 组合成坐标对
|
235 |
+
keypoints = np.stack([x_coords, y_coords], axis=1) # (N, 2)
|
236 |
|
237 |
# 获取每个点的置信度分数
|
238 |
scores = np.max(simcc_x[0], axis=1) * np.max(simcc_y[0], axis=1)
|
|
|
358 |
if not is_rgb:
|
359 |
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
360 |
|
361 |
+
colors = self.bone_colors
|
|
|
362 |
|
363 |
for i, (point, score) in enumerate(zip(keypoints, scores)):
|
364 |
if score > 0.3: # 设置置信度阈值
|
|
|
370 |
# 添加关键点索引标注
|
371 |
cv2.putText(img, self.bone_names[i], (x+5, y+5),
|
372 |
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (int(color[0]), int(color[1]), int(color[2])), 1)
|
373 |
+
|
374 |
+
# 绘制 关节连接线
|
375 |
+
for link in self.skeleton_links:
|
376 |
+
start_bone, end_bone = link.split("-")
|
377 |
+
|
378 |
+
start_index = self.bone_names.index(start_bone)
|
379 |
+
end_index = self.bone_names.index(end_bone)
|
380 |
+
|
381 |
+
start_keypoint = keypoints[start_index]
|
382 |
+
end_keypoint = keypoints[end_index]
|
383 |
+
link_color = colors[start_index]
|
384 |
+
|
385 |
+
# 绘制连线
|
386 |
+
if scores[start_index] > 0.3 and scores[end_index] > 0.3:
|
387 |
+
start_point = tuple(map(int, start_keypoint))
|
388 |
+
end_point = tuple(map(int, end_keypoint))
|
389 |
+
cv2.line(img, start_point, end_point,
|
390 |
+
(int(link_color[0]), int(link_color[1]), int(link_color[2])),
|
391 |
+
thickness=2)
|
392 |
+
|
393 |
return img
|
394 |
|