feat: 性能更新
Browse files
- HISTORY.md +5 -0
- app.py +7 -91
- core/runonnx/rtmpose.py +13 -4
- examples/demo001.png +0 -0
- examples/demo002.png +0 -0
- examples/demo003.png +0 -0
- examples/demo004.png +0 -0
- examples/demo005.png +0 -0
HISTORY.md
CHANGED
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
### 2025-01-05
|
2 |
|
3 |
1. 使用 swinv2 tiny 模型
|
|
|
1 |
+
### 2025-01-10
|
2 |
+
|
3 |
+
1. 移除视频
|
4 |
+
2. 增加图片
|
5 |
+
|
6 |
### 2025-01-05
|
7 |
|
8 |
1. 使用 swinv2 tiny 模型
|
app.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
import gradio as gr
|
2 |
-
import cv2
|
3 |
import os
|
4 |
from core.chessboard_detector import ChessboardDetector
|
5 |
|
6 |
detector = ChessboardDetector(
|
7 |
-
det_model_path="onnx/det/
|
8 |
pose_model_path="onnx/pose/4_v2.onnx",
|
9 |
-
full_classifier_model_path="onnx/layout_recognition/
|
10 |
)
|
11 |
|
12 |
# 数据集路径
|
@@ -39,13 +39,9 @@ def build_examples():
|
|
39 |
examples = []
|
40 |
# 读取 examples 目录下的所有图片
|
41 |
for file in os.listdir("examples"):
|
42 |
-
if file.endswith(".jpg"):
|
43 |
image_path = os.path.join("examples", file)
|
44 |
-
examples.append([image_path
|
45 |
-
|
46 |
-
elif file.endswith(".mp4"):
|
47 |
-
video_path = os.path.join("examples", file)
|
48 |
-
examples.append([None, video_path])
|
49 |
|
50 |
return examples
|
51 |
|
@@ -53,76 +49,7 @@ def build_examples():
|
|
53 |
full_examples = build_examples()
|
54 |
|
55 |
|
56 |
-
|
57 |
-
"""
|
58 |
-
获取视频指定位置的帧
|
59 |
-
"""
|
60 |
-
|
61 |
-
# 读取视频
|
62 |
-
cap = cv2.VideoCapture(video_data)
|
63 |
-
if not cap.isOpened():
|
64 |
-
gr.Warning("无法打开视频")
|
65 |
-
return None
|
66 |
-
|
67 |
-
# 获取视频的帧率
|
68 |
-
fps = cap.get(cv2.CAP_PROP_FPS)
|
69 |
-
|
70 |
-
# process 是 00:00
|
71 |
-
process_time = process.split(":")
|
72 |
-
minutes = int(process_time[0])
|
73 |
-
seconds = float(process_time[1])
|
74 |
-
|
75 |
-
# 计算总秒数
|
76 |
-
target_seconds = minutes * 60 + seconds
|
77 |
-
|
78 |
-
# 计算当前帧
|
79 |
-
current_frame = int(target_seconds * fps)
|
80 |
-
|
81 |
-
# 设置到指定帧
|
82 |
-
cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
|
83 |
-
|
84 |
-
# 读取当前帧
|
85 |
-
ret, frame = cap.read()
|
86 |
-
cap.release()
|
87 |
-
|
88 |
-
if not ret:
|
89 |
-
gr.Warning("无法读取视频帧")
|
90 |
-
return None
|
91 |
-
|
92 |
-
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
93 |
-
|
94 |
-
return frame_rgb
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
with gr.Blocks(
|
100 |
-
js="""
|
101 |
-
async () => {
|
102 |
-
document.addEventListener('timeupdate', function(e) {
|
103 |
-
// 检查事件源是否是视频元素
|
104 |
-
if (e.target.matches('#video_player video')) {
|
105 |
-
const video = e.target;
|
106 |
-
const currentTime = video.currentTime;
|
107 |
-
// 转换成 00:00 格式
|
108 |
-
let minutes = Math.floor(currentTime / 60);
|
109 |
-
let seconds = Math.floor(currentTime % 60);
|
110 |
-
let formattedTime = `${minutes.toString().padStart(2,'0')}:${seconds.toString().padStart(2,'0')}`;
|
111 |
-
|
112 |
-
// 更新输入框值
|
113 |
-
let processInput = document.querySelector("#video_process textarea");
|
114 |
-
if(processInput) {
|
115 |
-
processInput.value = formattedTime;
|
116 |
-
processInput.text = formattedTime;
|
117 |
-
|
118 |
-
processInput.dispatchEvent(new Event("input"));
|
119 |
-
}
|
120 |
-
|
121 |
-
}
|
122 |
-
}, true); // 使用捕获阶段
|
123 |
-
}
|
124 |
-
""",
|
125 |
-
css="""
|
126 |
.image img {
|
127 |
max-height: 512px;
|
128 |
}
|
@@ -139,13 +66,6 @@ with gr.Blocks(
|
|
139 |
2. 对整个棋盘画面进行棋子分类预测
|
140 |
"""
|
141 |
)
|
142 |
-
|
143 |
-
with gr.Row():
|
144 |
-
with gr.Column():
|
145 |
-
video_input = gr.Video(label="上传视频", interactive=True, elem_id="video_player", height=356)
|
146 |
-
video_process = gr.Textbox(label="当前时间", interactive=True, elem_id="video_process", value="00:00")
|
147 |
-
extract_frame_btn = gr.Button("从视频提取当前帧")
|
148 |
-
|
149 |
with gr.Row():
|
150 |
with gr.Column():
|
151 |
image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")
|
@@ -182,7 +102,7 @@ with gr.Blocks(
|
|
182 |
|
183 |
with gr.Row():
|
184 |
with gr.Column():
|
185 |
-
gr.Examples(full_examples, inputs=[image_input
|
186 |
|
187 |
|
188 |
def detect_chessboard(image):
|
@@ -212,9 +132,5 @@ with gr.Blocks(
|
|
212 |
inputs=[image_input],
|
213 |
outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time])
|
214 |
|
215 |
-
extract_frame_btn.click(fn=get_video_frame_with_processs,
|
216 |
-
inputs=[video_input, video_process],
|
217 |
-
outputs=[image_input])
|
218 |
-
|
219 |
if __name__ == "__main__":
|
220 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
# import cv2
|
3 |
import os
|
4 |
from core.chessboard_detector import ChessboardDetector
|
5 |
|
6 |
detector = ChessboardDetector(
|
7 |
+
det_model_path="onnx/det/v2.onnx",
|
8 |
pose_model_path="onnx/pose/4_v2.onnx",
|
9 |
+
full_classifier_model_path="onnx/layout_recognition/v5.onnx"
|
10 |
)
|
11 |
|
12 |
# 数据集路径
|
|
|
39 |
examples = []
|
40 |
# 读取 examples 目录下的所有图片
|
41 |
for file in os.listdir("examples"):
|
42 |
+
if file.endswith(".jpg") or file.endswith(".png"):
|
43 |
image_path = os.path.join("examples", file)
|
44 |
+
examples.append([image_path])
|
|
|
|
|
|
|
|
|
45 |
|
46 |
return examples
|
47 |
|
|
|
49 |
full_examples = build_examples()
|
50 |
|
51 |
|
52 |
+
with gr.Blocks(css="""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
.image img {
|
54 |
max-height: 512px;
|
55 |
}
|
|
|
66 |
2. 对整个棋盘画面进行棋子分类预测
|
67 |
"""
|
68 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
with gr.Row():
|
70 |
with gr.Column():
|
71 |
image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")
|
|
|
102 |
|
103 |
with gr.Row():
|
104 |
with gr.Column():
|
105 |
+
gr.Examples(full_examples, inputs=[image_input], label="示例视频、图片")
|
106 |
|
107 |
|
108 |
def detect_chessboard(image):
|
|
|
132 |
inputs=[image_input],
|
133 |
outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time])
|
134 |
|
|
|
|
|
|
|
|
|
135 |
if __name__ == "__main__":
|
136 |
demo.launch()
|
core/runonnx/rtmpose.py
CHANGED
@@ -350,7 +350,12 @@ class RTMPOSE_ONNX(BaseONNX):
|
|
350 |
|
351 |
return original_keypoints
|
352 |
|
353 |
-
def draw_pred(self,
|
|
|
|
|
|
|
|
|
|
|
354 |
"""
|
355 |
Draw the keypoints results on the image.
|
356 |
"""
|
@@ -361,14 +366,18 @@ class RTMPOSE_ONNX(BaseONNX):
|
|
361 |
colors = self.bone_colors
|
362 |
|
363 |
for i, (point, score) in enumerate(zip(keypoints, scores)):
|
364 |
-
|
365 |
x, y = map(int, point)
|
366 |
# 使用不同颜色标注不同的关键点
|
367 |
color = colors[i]
|
368 |
|
369 |
cv2.circle(img, (x, y), 5, (int(color[0]), int(color[1]), int(color[2])), -1)
|
370 |
# 添加关键点索引标注
|
371 |
-
|
|
|
|
|
|
|
|
|
372 |
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (int(color[0]), int(color[1]), int(color[2])), 1)
|
373 |
|
374 |
# 绘制 关节连接线
|
@@ -383,7 +392,7 @@ class RTMPOSE_ONNX(BaseONNX):
|
|
383 |
link_color = colors[start_index]
|
384 |
|
385 |
# 绘制连线
|
386 |
-
if scores[start_index] >
|
387 |
start_point = tuple(map(int, start_keypoint))
|
388 |
end_point = tuple(map(int, end_keypoint))
|
389 |
cv2.line(img, start_point, end_point,
|
|
|
350 |
|
351 |
return original_keypoints
|
352 |
|
353 |
+
def draw_pred(self,
|
354 |
+
img: cv2.UMat,
|
355 |
+
keypoints: np.ndarray,
|
356 |
+
scores: np.ndarray,
|
357 |
+
is_rgb: bool = True,
|
358 |
+
score_threshold: float = 0.6) -> cv2.UMat:
|
359 |
"""
|
360 |
Draw the keypoints results on the image.
|
361 |
"""
|
|
|
366 |
colors = self.bone_colors
|
367 |
|
368 |
for i, (point, score) in enumerate(zip(keypoints, scores)):
|
369 |
+
|
370 |
x, y = map(int, point)
|
371 |
# 使用不同颜色标注不同的关键点
|
372 |
color = colors[i]
|
373 |
|
374 |
cv2.circle(img, (x, y), 5, (int(color[0]), int(color[1]), int(color[2])), -1)
|
375 |
# 添加关键点索引标注
|
376 |
+
if score < score_threshold: # 设置置信度阈值
|
377 |
+
text = f"{self.bone_names[i]}: {score:.2f}"
|
378 |
+
else:
|
379 |
+
text = f"{self.bone_names[i]}"
|
380 |
+
cv2.putText(img, text, (x+5, y+5),
|
381 |
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (int(color[0]), int(color[1]), int(color[2])), 1)
|
382 |
|
383 |
# 绘制 关节连接线
|
|
|
392 |
link_color = colors[start_index]
|
393 |
|
394 |
# 绘制连线
|
395 |
+
if scores[start_index] > score_threshold and scores[end_index] > score_threshold:
|
396 |
start_point = tuple(map(int, start_keypoint))
|
397 |
end_point = tuple(map(int, end_keypoint))
|
398 |
cv2.line(img, start_point, end_point,
|
examples/demo001.png
ADDED
![demo001.png](examples/demo001.png) |
examples/demo002.png
ADDED
![demo002.png](examples/demo002.png) |
examples/demo003.png
ADDED
![demo003.png](examples/demo003.png) |
examples/demo004.png
ADDED
![demo004.png](examples/demo004.png) |
examples/demo005.png
ADDED
![demo005.png](examples/demo005.png) |