yolo12138 committed
Commit 5873e33 · 1 Parent(s): e42b155

feat: performance update

HISTORY.md CHANGED
@@ -1,3 +1,8 @@
+### 2025-01-10
+
+1. Removed video input
+2. Added image examples
+
 ### 2025-01-05
 
 1. Use the swinv2 tiny model
app.py CHANGED
@@ -1,12 +1,12 @@
 import gradio as gr
-import cv2
+# import cv2
 import os
 from core.chessboard_detector import ChessboardDetector
 
 detector = ChessboardDetector(
-    det_model_path="onnx/det/v1.onnx",
+    det_model_path="onnx/det/v2.onnx",
     pose_model_path="onnx/pose/4_v2.onnx",
-    full_classifier_model_path="onnx/layout_recognition/v3.onnx"
+    full_classifier_model_path="onnx/layout_recognition/v5.onnx"
 )
 
 # dataset path
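This hunk bumps two models by path only (det v1 → v2, layout_recognition v3 → v5), so a standalone check that the new ONNX files load and expose the expected inputs can catch a bad upload early. A minimal sketch, assuming onnxruntime is installed (implied by the .onnx models, not shown in this diff):

import onnxruntime as ort

# Paths taken from the hunk above; provider pinned to CPU for a portable check.
for path in ("onnx/det/v2.onnx", "onnx/pose/4_v2.onnx", "onnx/layout_recognition/v5.onnx"):
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    # Print each model's input names and static shapes to spot version mismatches.
    print(path, [(inp.name, inp.shape) for inp in sess.get_inputs()])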
@@ -39,13 +39,9 @@ def build_examples():
     examples = []
     # read all images in the examples directory
     for file in os.listdir("examples"):
-        if file.endswith(".jpg"):
+        if file.endswith(".jpg") or file.endswith(".png"):
             image_path = os.path.join("examples", file)
-            examples.append([image_path, None])
-
-        elif file.endswith(".mp4"):
-            video_path = os.path.join("examples", file)
-            examples.append([None, video_path])
+            examples.append([image_path])
 
     return examples
 
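Each example row is now a single-element list, matching the single component passed as inputs=[image_input] further down. One side effect worth noting: os.listdir returns entries in arbitrary order, so the example gallery may shuffle between runs. A sorted variant (an editor's suggestion, not part of this commit) keeps demo001.png through demo005.png in sequence:

import os

def build_examples():
    examples = []
    # sorted() makes the gallery order deterministic across platforms
    for file in sorted(os.listdir("examples")):
        if file.endswith((".jpg", ".png")):
            examples.append([os.path.join("examples", file)])
    return examples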
@@ -53,76 +49,7 @@ def build_examples():
 full_examples = build_examples()
 
 
-def get_video_frame_with_processs(video_data, process: str = '00:00') -> cv2.UMat:
-    """
-    Get the frame at the specified position in the video.
-    """
-
-    # open the video
-    cap = cv2.VideoCapture(video_data)
-    if not cap.isOpened():
-        gr.Warning("Unable to open the video")
-        return None
-
-    # get the video frame rate
-    fps = cap.get(cv2.CAP_PROP_FPS)
-
-    # process is formatted as 00:00
-    process_time = process.split(":")
-    minutes = int(process_time[0])
-    seconds = float(process_time[1])
-
-    # total seconds
-    target_seconds = minutes * 60 + seconds
-
-    # target frame index
-    current_frame = int(target_seconds * fps)
-
-    # seek to the target frame
-    cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
-
-    # read the frame
-    ret, frame = cap.read()
-    cap.release()
-
-    if not ret:
-        gr.Warning("Unable to read the video frame")
-        return None
-
-    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-
-    return frame_rgb
-
-
-
-
-with gr.Blocks(
-    js="""
-    async () => {
-        document.addEventListener('timeupdate', function(e) {
-            // check whether the event source is the video element
-            if (e.target.matches('#video_player video')) {
-                const video = e.target;
-                const currentTime = video.currentTime;
-                // format as 00:00
-                let minutes = Math.floor(currentTime / 60);
-                let seconds = Math.floor(currentTime % 60);
-                let formattedTime = `${minutes.toString().padStart(2,'0')}:${seconds.toString().padStart(2,'0')}`;
-
-                // update the textbox value
-                let processInput = document.querySelector("#video_process textarea");
-                if(processInput) {
-                    processInput.value = formattedTime;
-                    processInput.text = formattedTime;
-
-                    processInput.dispatchEvent(new Event("input"));
-                }
-
-            }
-        }, true); // listen in the capture phase
-    }
-    """,
-    css="""
+with gr.Blocks(css="""
     .image img {
         max-height: 512px;
     }
@@ -139,13 +66,6 @@ with gr.Blocks(
     2. Run piece classification over the whole board image
     """
     )
-
-    with gr.Row():
-        with gr.Column():
-            video_input = gr.Video(label="Upload video", interactive=True, elem_id="video_player", height=356)
-            video_process = gr.Textbox(label="Current time", interactive=True, elem_id="video_process", value="00:00")
-            extract_frame_btn = gr.Button("Extract current frame from video")
-
     with gr.Row():
         with gr.Column():
             image_input = gr.Image(label="Upload board image", type="numpy", elem_classes="image")
@@ -182,7 +102,7 @@
 
     with gr.Row():
         with gr.Column():
-            gr.Examples(full_examples, inputs=[image_input, video_input], label="Sample videos and images")
+            gr.Examples(full_examples, inputs=[image_input], label="Sample videos and images")
 
 
     def detect_chessboard(image):
@@ -212,9 +132,5 @@ with gr.Blocks(
         inputs=[image_input],
         outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time])
 
-    extract_frame_btn.click(fn=get_video_frame_with_processs,
-                            inputs=[video_input, video_process],
-                            outputs=[image_input])
-
 if __name__ == "__main__":
     demo.launch()
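With the video branch gone, app.py is a pure image pipeline. A headless smoke test could call the click handler directly; this is a sketch only, since detect_chessboard's body lies outside this diff, and it assumes the handler returns the same four values wired to the outputs above:

import cv2

# demo001.png is one of the images added in this commit.
img = cv2.cvtColor(cv2.imread("examples/demo001.png"), cv2.COLOR_BGR2RGB)

# Assumed return order, mirroring the outputs list in the click binding.
keypoints_img, transformed_img, layout_pred, elapsed = detect_chessboard(img)
print(layout_pred, elapsed)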
core/runonnx/rtmpose.py CHANGED
@@ -350,7 +350,12 @@ class RTMPOSE_ONNX(BaseONNX):
 
         return original_keypoints
 
-    def draw_pred(self, img: cv2.UMat, keypoints: np.ndarray, scores: np.ndarray, is_rgb: bool = True) -> cv2.UMat:
+    def draw_pred(self,
+                  img: cv2.UMat,
+                  keypoints: np.ndarray,
+                  scores: np.ndarray,
+                  is_rgb: bool = True,
+                  score_threshold: float = 0.6) -> cv2.UMat:
         """
         Draw the keypoints results on the image.
         """
@@ -361,14 +366,18 @@
         colors = self.bone_colors
 
         for i, (point, score) in enumerate(zip(keypoints, scores)):
-            if score > 0.3:  # confidence threshold
+
                 x, y = map(int, point)
                 # use a distinct color for each keypoint
                 color = colors[i]
 
                 cv2.circle(img, (x, y), 5, (int(color[0]), int(color[1]), int(color[2])), -1)
                 # label the keypoint with its name
-                cv2.putText(img, self.bone_names[i], (x+5, y+5),
+                if score < score_threshold:  # below the confidence threshold
+                    text = f"{self.bone_names[i]}: {score:.2f}"
+                else:
+                    text = f"{self.bone_names[i]}"
+                cv2.putText(img, text, (x+5, y+5),
                             cv2.FONT_HERSHEY_SIMPLEX, 1.0, (int(color[0]), int(color[1]), int(color[2])), 1)
 
         # draw the joint connection lines
@@ -383,7 +392,7 @@
             link_color = colors[start_index]
 
             # draw the connecting line
-            if scores[start_index] > 0.3 and scores[end_index] > 0.3:
+            if scores[start_index] > score_threshold and scores[end_index] > score_threshold:
                 start_point = tuple(map(int, start_keypoint))
                 end_point = tuple(map(int, end_keypoint))
                 cv2.line(img, start_point, end_point,
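Net effect of the rtmpose.py change: every keypoint is now drawn unconditionally, keypoints under the threshold are labeled with their confidence value, and skeleton links require both endpoints to clear score_threshold (default 0.6, up from the hard-coded 0.3). A usage sketch; the pose instance and the keypoints/scores values are assumptions, since only draw_pred appears in this diff:

import cv2
import numpy as np

img = cv2.cvtColor(cv2.imread("examples/demo001.png"), cv2.COLOR_BGR2RGB)

# Assumed: pose is an RTMPOSE_ONNX instance loaded from onnx/pose/4_v2.onnx,
# and (keypoints, scores) come from its inference step; dummy values shown.
keypoints = np.array([[120, 80], [200, 160], [310, 240], [150, 300]], dtype=np.float32)
scores = np.array([0.95, 0.42, 0.88, 0.71], dtype=np.float32)

# The 0.42 point is still drawn, but its label carries the score ("name: 0.42"),
# and any link touching it is skipped under the 0.6 default.
vis = pose.draw_pred(img, keypoints, scores, is_rgb=True, score_threshold=0.6)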
examples/demo001.png ADDED
examples/demo002.png ADDED
examples/demo003.png ADDED
examples/demo004.png ADDED
examples/demo005.png ADDED