yolo12138 committed
Commit ec71a6c · Parent(s): 1e43daa

feat: model optimization

HISTORY.md CHANGED
@@ -11,3 +11,8 @@
 
 1. Use 4-keypoint detection
 
+
+### 2025-01-25
+
+1. Updated the pose model; a bbox input is no longer required
+
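As a minimal sketch of what the 2025-01-25 entry means for callers (the model path comes from app.py in this commit; the image path is a placeholder), the pose model can now be run on the whole frame instead of on a detector-supplied crop:

```python
import cv2
from core.runonnx.rtmpose import RTMPOSE_ONNX

pose = RTMPOSE_ONNX(model_path="onnx/pose/4_v3.onnx")

image_bgr = cv2.imread("board.jpg")  # placeholder path
h, w = image_bgr.shape[:2]

# Previously the bbox came from a separate RTMDet model; after this commit
# the full image is passed as the bbox, so no detector is needed.
keypoints, scores = pose.pred(image=image_bgr, bbox=[0, 0, w, h])
```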
app.py CHANGED
@@ -4,8 +4,7 @@ import os
 from core.chessboard_detector import ChessboardDetector
 
 detector = ChessboardDetector(
-    det_model_path="onnx/det/v2.onnx",
-    pose_model_path="onnx/pose/4_v2.onnx",
+    pose_model_path="onnx/pose/4_v3.onnx",
     full_classifier_model_path="onnx/layout_recognition/nano_v1.onnx"
 )
 
@@ -58,15 +57,14 @@ with gr.Blocks(css="""
     gr.Markdown("""
                 ## Chessboard detection and piece recognition
 
+                features: lightweight model
+
                 x marks an occluded position
                 . marks an ordinary board intersection
 
                 Steps:
-                1. The pipeline has two steps; step one detects the board edges
-                2. Classify the pieces over the whole board image
-
-                ### log
-                2025-01-24 model optimization 200M -> 30M
+                1. The pipeline has two steps; step one is keypoint detection
+                2. Rectify the board, then predict the pieces
                 """
                 )
     with gr.Row():
@@ -105,7 +103,8 @@
 
     with gr.Row():
         with gr.Column():
-            gr.Examples(full_examples, inputs=[image_input], label="Example videos / images")
+            gr.Examples(
+                full_examples, inputs=[image_input], label="Example images", examples_per_page=15,)
 
 
    def detect_chessboard(image):
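The hunks above show only fragments of the Gradio layout. Below is a hedged sketch of how those fragments plausibly fit together; `full_examples`, the output widget, the button wiring, and the `detect_chessboard` body are assumptions, and only the `gr.Examples(..., examples_per_page=15)` call and the Markdown header come from the diff:

```python
import gradio as gr

full_examples = ["examples/board_1.jpg", "examples/board_2.jpg"]  # placeholder paths

def detect_chessboard(image):
    # Placeholder: the real app delegates to ChessboardDetector
    # (see core/chessboard_detector.py in this commit).
    return image

with gr.Blocks() as demo:
    gr.Markdown("## Chessboard detection and piece recognition")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Input image")
            gr.Examples(full_examples, inputs=[image_input],
                        label="Example images", examples_per_page=15)
        with gr.Column():
            image_output = gr.Image(label="Detection result")
    run_btn = gr.Button("Detect")
    run_btn.click(fn=detect_chessboard, inputs=[image_input], outputs=[image_output])

if __name__ == "__main__":
    demo.launch()
```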
core/chessboard_detector.py CHANGED
@@ -4,32 +4,24 @@ import numpy as np
 import cv2
 from typing import List, Tuple, Union
 from pandas import DataFrame
-from .runonnx.rtmdet import RTMDET_ONNX
 from .runonnx.rtmpose import RTMPOSE_ONNX
 from .runonnx.full_classifier import FULL_CLASSIFIER_ONNX
 
 from core.helper_4_kpt import extract_chessboard
 
 class ChessboardDetector:
-    def __init__(self,
-                 det_model_path: str,
+    def __init__(self,
                  pose_model_path: str,
                  full_classifier_model_path: str = None
                  ):
 
-        self.det = RTMDET_ONNX(
-            model_path=det_model_path,
-        )
-
-
         self.pose = RTMPOSE_ONNX(
             model_path=pose_model_path,
         )
 
-        if full_classifier_model_path is not None:
-            self.full_classifier = FULL_CLASSIFIER_ONNX(
-                model_path=full_classifier_model_path,
-            )
+        self.full_classifier = FULL_CLASSIFIER_ONNX(
+            model_path=full_classifier_model_path,
+        )
 
         self.board_positions = []  # store board position coordinates
         self.current_image = None
@@ -37,19 +29,21 @@ class ChessboardDetector:
 
 
     # Detect the Chinese chess (xiangqi) board
-    def pred_detect_and_keypoints(self, image_bgr: Union[np.ndarray, None] = None) -> Tuple[List[int], float, List[List[int]], List[float]]:
-
-        xyxy, conf = self.det.pred(image_bgr)
+    def pred_keypoints(self, image_bgr: Union[np.ndarray, None] = None) -> Tuple[List[List[int]], List[float]]:
 
         # Predict keypoints, draw keypoints
-        keypoints, scores = self.pose.pred(image=image_bgr, bbox=xyxy)
 
-        return xyxy, conf, keypoints, scores
+        height, width = image_bgr.shape[:2]
+        bbox = [0, 0, width, height]
+
+        keypoints, scores = self.pose.pred(image=image_bgr, bbox=bbox)
+
+        return keypoints, scores
 
 
     def draw_pred_with_keypoints(self, image_rgb: Union[np.ndarray, None] = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         if image_rgb is None:
-            return None, None, None, None
+            return None, None, None
 
         image_rgb = image_rgb.copy()
 
@@ -57,13 +51,10 @@ class ChessboardDetector:
 
         image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
 
-        xyxy, conf, keypoints, scores = self.pred_detect_and_keypoints(image_bgr)
+        keypoints, scores = self.pred_keypoints(image_bgr)
 
         # Draw the board frame
-        draw_image = self.det.draw_pred(image_rgb, xyxy, conf)
-
-        # Draw keypoints
-        draw_image = self.pose.draw_pred(img=draw_image, keypoints=keypoints, scores=scores)
+        draw_image = self.pose.draw_pred(img=image_rgb, keypoints=keypoints, scores=scores)
 
         # Merge self.pose.bone_names with keypoints, then convert to a DataFrame
         keypoint_list = []
@@ -72,7 +63,7 @@ class ChessboardDetector:
 
         keypoint_df = DataFrame(keypoint_list)
 
-        return draw_image, original_image, [xyxy], keypoint_df
+        return draw_image, original_image, keypoint_df
 
     # Rectify (warp) the board, then predict
     def extract_chessboard_and_classifier_layout(self,
@@ -111,22 +102,24 @@ class ChessboardDetector:
             return None, None, [], [], ""
 
         image_rgb_for_extract = image_rgb.copy()
+        image_rgb_for_draw = image_rgb.copy()
 
         start_time = time.time()
 
-        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
-
-        xyxy, conf, keypoints, scores = self.pred_detect_and_keypoints(image_bgr)
+        try:
+            image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
 
-        # Draw the board frame
-        draw_image = self.det.draw_pred(image_rgb, xyxy, conf)
+            keypoints, scores = self.pred_keypoints(image_bgr)
 
-        """
-        Draw keypoints on the original image
-        """
-        original_image_with_keypoints = self.pose.draw_pred(img=draw_image, keypoints=keypoints, scores=scores)
+            """
+            Draw keypoints on the original image
+            """
+            original_image_with_keypoints = self.pose.draw_pred(img=image_rgb_for_draw, keypoints=keypoints, scores=scores)
 
-        transformed_image, cells_labels, scores = self.extract_chessboard_and_classifier_layout(image_rgb=image_rgb_for_extract, keypoints=keypoints)
+            transformed_image, cells_labels, scores = self.extract_chessboard_and_classifier_layout(image_rgb=image_rgb_for_extract, keypoints=keypoints)
+        except Exception as e:
+            print("Chessboard detection failed:", e)
+            return None, None, None, None, ""
 
 
         use_time = time.time() - start_time
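core.helper_4_kpt.extract_chessboard is not part of this diff; as a rough illustration of what a 4-keypoint rectification step typically does before piece classification (the function name, corner order, and output size below are assumptions, not the project's actual implementation):

```python
import cv2
import numpy as np

def rectify_board(image_rgb: np.ndarray, keypoints: np.ndarray,
                  out_w: int = 450, out_h: int = 500) -> np.ndarray:
    """Illustrative sketch: warp the region spanned by the 4 corner keypoints
    into a fronto-parallel image that a layout classifier can consume."""
    # Corner order is assumed to be top-left, top-right, bottom-right, bottom-left.
    src = np.asarray(keypoints[:4], dtype=np.float32)
    dst = np.array([[0, 0], [out_w - 1, 0],
                    [out_w - 1, out_h - 1], [0, out_h - 1]], dtype=np.float32)

    matrix = cv2.getPerspectiveTransform(src, dst)
    return cv2.warpPerspective(image_rgb, matrix, (out_w, out_h))
```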
core/runonnx/rtmdet.py DELETED
@@ -1,117 +0,0 @@
-
-
-import numpy as np
-import cv2
-from typing import Tuple, List, Union
-from .base_onnx import BaseONNX
-
-
-class RTMDET_ONNX(BaseONNX):
-
-    def __init__(self, model_path, input_size=(640, 640)):
-        super().__init__(model_path, input_size)
-
-    def preprocess_image(self, img_bgr: cv2.UMat):
-        # Resize the image
-        img_bgr = cv2.resize(img_bgr, self.input_size)
-
-        # Normalize with mean and std
-        img = (img_bgr - np.array([103.53, 116.28, 123.675])) / np.array([57.375, 57.12, 58.395])
-
-        img = img.astype(np.float32)
-        # Convert to float and normalize
-        # img = img.astype(np.float32) / 255.0
-
-        # Reorder dimensions (H,W,C) -> (C,H,W)
-        img = np.transpose(img, (2, 0, 1))
-
-        # Add a batch dimension
-        img = np.expand_dims(img, axis=0)
-
-        return img
-
-
-    def run_inference(self, image: np.ndarray):
-        """
-        Run inference on the image.
-
-        Args:
-            image (np.ndarray): The image to run inference on.
-
-        Returns:
-            tuple: A tuple containing the detection results and labels.
-        """
-        # Run inference
-        outputs = self.session.run(None, {self.input_name: image})
-
-        """
-        dets: detection boxes [batch, num_dets, [x1, y1, x2, y2, conf]] ([batch, num_dets, Reshape(dets_dim_2)])
-        labels: labels [batch, num_dets]
-        """
-        dets, labels = outputs
-
-        return dets, labels
-
-    def pred(self, image: List[Union[cv2.UMat, str]]) -> Tuple[List[int], float]:
-        """
-        Predict the detection results of the image.
-
-        Args:
-            image (cv2.UMat, str): The image to predict.
-
-        Returns:
-            xyxy (list[int, int, int, int]): The detection results.
-            conf (float): The confidence of the detection results.
-        """
-        if isinstance(image, str):
-            img_bgr = cv2.imread(image)
-        else:
-            img_bgr = image.copy()
-
-        original_w, original_h = img_bgr.shape[1], img_bgr.shape[0]
-
-        image = self.preprocess_image(img_bgr)
-        dets, labels = self.run_inference(image)
-
-        # Take the highest-confidence detection box
-        # dets = dets[0][0]
-        # labels = labels[0][0]
-
-        x1, y1, x2, y2, conf = dets[0][0]
-
-        xyxy = [x1, y1, x2, y2]
-
-        xyxy = self.transform_xyxy_to_original(xyxy, original_w, original_h)
-
-        return xyxy, conf
-
-    def transform_xyxy_to_original(self, xyxy, original_w, original_h) -> List[int]:
-        """
-        Convert a detection box from the network input size back to the original image size.
-        """
-        x1, y1, x2, y2 = xyxy
-
-        input_w, input_h = self.input_size
-        ratio_w, ratio_h = original_w / input_w, original_h / input_h
-
-        x1, y1, x2, y2 = x1 * ratio_w, y1 * ratio_h, x2 * ratio_w, y2 * ratio_h
-        # Convert to integers
-        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
-
-        return [x1, y1, x2, y2]
-
-    def draw_pred(self, img: cv2.UMat, xyxy: List[int], conf: float, is_rgb: bool = True) -> cv2.UMat:
-        """
-        Draw the detection results on the image.
-        """
-
-        if not is_rgb:
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-
-        x1, y1, x2, y2 = xyxy
-
-        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
-        cv2.putText(img, f"{conf:.2f}", (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
-
-        return img
-
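The deleted transform_xyxy_to_original simply rescales boxes from the 640x640 network input back to the source resolution. A small standalone illustration of that mapping, with made-up numbers:

```python
# Standalone illustration of the deleted box-rescaling logic.
def to_original(xyxy, original_w, original_h, input_w=640, input_h=640):
    x1, y1, x2, y2 = xyxy
    rw, rh = original_w / input_w, original_h / input_h
    return [int(x1 * rw), int(y1 * rh), int(x2 * rw), int(y2 * rh)]

# A box predicted on the 640x640 input maps back to a 1280x720 frame like so:
print(to_original([100, 200, 500, 600], 1280, 720))  # -> [200, 225, 1000, 675]
```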
core/runonnx/rtmpose.py CHANGED
@@ -378,7 +378,7 @@ class RTMPOSE_ONNX(BaseONNX):
         else:
             text = f"{self.bone_names[i]}"
         cv2.putText(img, text, (x+5, y+5),
-                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (int(color[0]), int(color[1]), int(color[2])), 1)
+                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (int(color[0]), int(color[1]), int(color[2])), 2)
 
         # Draw skeleton connection lines
         for link in self.skeleton_links: