# yolo12138's picture
# Add video file using Git LFS
# 9316eb4
# raw · history · blame · 6.61 kB
import gradio as gr
import cv2
import os
from core.chessboard_detector import ChessboardDetector
# Instantiate the chessboard detector once at import time with its three
# ONNX models: board/corner detection, keypoint (pose) estimation, and
# full-board piece-layout classification. Paths are relative to the repo root.
detector = ChessboardDetector(
det_model_path="onnx/det/v1.onnx",
pose_model_path="onnx/pose/v1.onnx",
full_classifier_model_path="onnx/layout_recognition/v1.onnx"
)
# Mapping from Chinese piece names (as emitted by the layout classifier)
# to FEN-style letters: uppercase = red pieces, lowercase = black pieces.
# '.' is an empty board intersection; 'x' marks an occluded position.
dict_cate_names = {
'.': '.',
'x': 'x',
'红帅': 'K',
'红士': 'A',
'红相': 'B',
'红马': 'N',
'红车': 'R',
'红炮': 'C',
'红兵': 'P',
'黑将': 'k',
'黑仕': 'a',
'黑象': 'b',
'黑傌': 'n',
'黑車': 'r',
'黑砲': 'c',
'黑卒': 'p',
}
# Inverse map: FEN-style letter -> Chinese piece name, used to render the
# predicted layout back as human-readable labels in the results table.
dict_cate_names_reverse = {v: k for k, v in dict_cate_names.items()}
### Build the rows for the Gradio Examples gallery
def build_examples(examples_dir: str = "examples") -> list:
    """Collect example inputs for the Gradio ``Examples`` component.

    Scans *examples_dir* for ``.jpg`` images and ``.mp4`` videos and
    returns rows shaped ``[image_path, video_path]`` to match the
    ``inputs=[image_input, video_input]`` wiring: an image fills the
    first slot, a video the second, and the unused slot is ``None``.

    Args:
        examples_dir: Directory to scan. Defaults to ``"examples"``,
            preserving the original call-site behavior.

    Returns:
        A list of two-element lists, sorted by filename so the gallery
        order is deterministic (``os.listdir`` order is platform-dependent).
    """
    examples = []
    # sorted(): os.listdir returns entries in arbitrary, OS-dependent order.
    for file in sorted(os.listdir(examples_dir)):
        path = os.path.join(examples_dir, file)
        if file.endswith(".jpg"):
            examples.append([path, None])
        elif file.endswith(".mp4"):
            examples.append([None, path])
    return examples
# Materialize the example rows once at import time for the UI below.
full_examples = build_examples()
def get_video_frame_with_processs(video_data, process: str = '00:00') -> cv2.UMat:
    """Extract the frame of *video_data* at playback position *process*.

    Args:
        video_data: Path to the video file (as supplied by ``gr.Video``).
        process: Timestamp in ``MM:SS`` form; the seconds part may be
            fractional (e.g. ``"01:23.5"``).

    Returns:
        The requested frame as an RGB array, or ``None`` (after emitting a
        Gradio warning) when the video cannot be opened, the timestamp is
        malformed, or the frame cannot be read.
    """
    # Parse the timestamp first so a malformed value cannot raise after the
    # capture handle is already open (which would leak it).
    try:
        minutes_part, seconds_part = process.split(":")
        target_seconds = int(minutes_part) * 60 + float(seconds_part)
    except (AttributeError, ValueError):
        gr.Warning("时间格式错误")
        return None
    cap = cv2.VideoCapture(video_data)
    try:
        if not cap.isOpened():
            gr.Warning("无法打开视频")
            return None
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report fps == 0; fall back to frame 0 rather than
        # computing a meaningless frame index.
        current_frame = int(target_seconds * fps) if fps > 0 else 0
        # Seek to the target frame, then decode it.
        cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
        ret, frame = cap.read()
    finally:
        # Release the capture on every path, including OpenCV exceptions.
        cap.release()
    if not ret:
        gr.Warning("无法读取视频帧")
        return None
    # OpenCV decodes to BGR; Gradio's Image component expects RGB.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Top-level Gradio UI. The custom `js` hook listens (in the capture phase)
# for the video player's `timeupdate` events and mirrors the current
# playback time into the "video_process" textbox as MM:SS, so the
# "extract frame" button knows which frame to grab. The `css` caps the
# displayed height of preview images.
with gr.Blocks(
js="""
async () => {
document.addEventListener('timeupdate', function(e) {
// 检查事件源是否是视频元素
if (e.target.matches('#video_player video')) {
const video = e.target;
const currentTime = video.currentTime;
// 转换成 00:00 格式
let minutes = Math.floor(currentTime / 60);
let seconds = Math.floor(currentTime % 60);
let formattedTime = `${minutes.toString().padStart(2,'0')}:${seconds.toString().padStart(2,'0')}`;
// 更新输入框值
let processInput = document.querySelector("#video_process textarea");
if(processInput) {
processInput.value = formattedTime;
processInput.text = formattedTime;
processInput.dispatchEvent(new Event("input"));
}
}
}, true); // 使用捕获阶段
}
""",
css="""
.image img {
max-height: 512px;
}
"""
) as demo:
# Intro text: 'x' = occluded position, '.' = ordinary empty intersection;
# pipeline is (1) detect board edges, (2) classify pieces on the warped board.
gr.Markdown("""
## 棋盘检测, 棋子识别
x 表示 有遮挡位置
. 表示 棋盘上的普通交叉点
步骤:
1. 流程分成两步,第一步检测边缘
2. 对整个棋盘画面进行棋子分类预测
"""
)
# Row 1: video upload, the JS-driven timestamp box, and the extract button.
with gr.Row():
with gr.Column():
video_input = gr.Video(label="上传视频", interactive=True, elem_id="video_player", height=356)
video_process = gr.Textbox(label="当前时间", interactive=True, elem_id="video_process", value="00:00")
extract_frame_btn = gr.Button("从视频提取当前帧")
# Row 2: the input image (uploaded directly or extracted from the video)
# next to the step-1 output (original image with detected keypoints).
with gr.Row():
with gr.Column():
image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")
with gr.Column():
original_image_with_keypoints = gr.Image(
label="step1: 原图带关键点",
interactive=False,
visible=True,
elem_classes="image"
)
# Row 3: step-2 outputs — the perspective-warped board, elapsed time,
# and the 10x9 table of recognized pieces.
with gr.Row():
with gr.Column():
transformed_image = gr.Image(
label="step2: 拉伸棋盘",
interactive=False,
visible=True,
elem_classes="image"
)
with gr.Column():
use_time = gr.Textbox(
label="用时",
interactive=False,
visible=True,
)
layout_pred_info = gr.Dataframe(
label="棋子识别",
interactive=False,
visible=True,
)
# Row 4: clickable example images/videos built at import time.
with gr.Row():
with gr.Column():
gr.Examples(full_examples, inputs=[image_input, video_input], label="示例视频、图片")
def detect_chessboard(image):
    """Run board detection and piece classification on *image*.

    Args:
        image: RGB image array from the Gradio Image component, or
            ``None`` when the component is cleared.

    Returns:
        Tuple of (original image with keypoints, warped board image,
        10x9 grid of Chinese piece labels, timing string); all four are
        ``None`` on empty input or any detection failure.
    """
    if image is None:
        return None, None, None, None
    try:
        original_image_with_keypoints, transformed_image, cells_labels_str, scores, time_info = detector.pred_detect_board_and_classifier(image)
        # One board row per line; drop empty lines so a trailing newline
        # cannot produce a spurious extra row.
        label_rows = [line for line in cells_labels_str.split("\n") if line]
        # Each row string is one single-character label per column
        # (10 rows x 9 columns for a xiangqi board); map the FEN-style
        # letters back to Chinese piece names for display.
        annotation_arr_10_9 = [
            [dict_cate_names_reverse[ch] for ch in row] for row in label_rows
        ]
    except Exception:
        # Any failure (bad board geometry, unknown label, detector error)
        # surfaces as a UI warning rather than crashing the app.
        gr.Warning("检测失败 图片或者视频布局错误")
        return None, None, None, None
    return original_image_with_keypoints, transformed_image, annotation_arr_10_9, time_info
# Re-run detection whenever the image changes — both direct uploads and
# frames extracted from the video land here.
image_input.change(fn=detect_chessboard,
inputs=[image_input],
outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time])
# Grab the frame at the current playback time and feed it into the image
# input, which in turn triggers detect_chessboard via its .change handler.
extract_frame_btn.click(fn=get_video_frame_with_processs,
inputs=[video_input, video_process],
outputs=[image_input])
if __name__ == "__main__":
demo.launch()