|
import logging
import os

import cv2
import gradio as gr

from core.chessboard_detector import ChessboardDetector
|
|
|
# One shared detector, built once at import time and reused by every request.
# The three ONNX model files are given as relative paths (presumably resolved
# against the process working directory -- confirm if launching elsewhere).
_MODEL_PATHS = {
    "det_model_path": "onnx/det/v1.onnx",
    "pose_model_path": "onnx/pose/v1.onnx",
    "full_classifier_model_path": "onnx/layout_recognition/v1.onnx",
}
detector = ChessboardDetector(**_MODEL_PATHS)
|
|
|
|
|
# Classifier category name -> one-character board code.
# '.' and 'x' map to themselves (plain intersection / occluded position, as the
# page description explains); red pieces (红…) use uppercase codes and black
# pieces (黑…) use lowercase codes.
dict_cate_names = dict(
    [
        ('.', '.'),
        ('x', 'x'),
        # red side
        ('红帅', 'K'),
        ('红士', 'A'),
        ('红相', 'B'),
        ('红马', 'N'),
        ('红车', 'R'),
        ('红炮', 'C'),
        ('红兵', 'P'),
        # black side
        ('黑将', 'k'),
        ('黑仕', 'a'),
        ('黑象', 'b'),
        ('黑傌', 'n'),
        ('黑車', 'r'),
        ('黑砲', 'c'),
        ('黑卒', 'p'),
    ]
)

# Inverse lookup: board code -> category name.
dict_cate_names_reverse = dict(
    (code, name) for name, code in dict_cate_names.items()
)
|
|
|
|
|
|
|
|
|
def build_examples(examples_dir="examples"):
    """Collect example rows for the ``gr.Examples`` widget.

    Scans *examples_dir* (non-recursively) and returns one row per example in
    the ``[image_path, video_path]`` layout matching the widget's inputs
    ``[image_input, video_input]``:

    - ``.jpg`` files become ``[path, None]``
    - ``.mp4`` files become ``[None, path]``
    - everything else (other extensions, sub-directories) is ignored.

    Entries are visited in sorted name order so the gallery is stable across
    runs (``os.listdir`` order is arbitrary), and extensions are matched
    case-insensitively. A missing *examples_dir* yields an empty list, so
    the demo can still start without example assets.
    """
    if not os.path.isdir(examples_dir):
        return []

    examples = []
    for file in sorted(os.listdir(examples_dir)):
        path = os.path.join(examples_dir, file)
        if not os.path.isfile(path):
            continue
        suffix = os.path.splitext(file)[1].lower()
        if suffix == ".jpg":
            examples.append([path, None])
        elif suffix == ".mp4":
            examples.append([None, path])
    return examples


full_examples = build_examples()
|
|
|
|
|
def _parse_process_time(process) -> float:
    """Convert a playback timestamp string to a number of seconds.

    Accepts ``"SS"``, ``"MM:SS"`` or ``"HH:MM:SS"``. Every field may be
    fractional (e.g. ``"01:02.5"``), and the leading field may exceed 59 (the
    page script formats 75 minutes as ``"75:03"``).

    Raises:
        ValueError: if *process* is None/empty, has more than three fields,
            contains an empty or non-numeric field, or a negative/non-finite
            value.
    """
    text = "" if process is None else str(process).strip()
    fields = text.split(":")
    if len(fields) > 3 or any(not field.strip() for field in fields):
        raise ValueError(f"invalid timestamp: {process!r}")

    total = 0.0
    for field in fields:
        value = float(field)
        # The chained comparison is False for NaN as well, so NaN is rejected too.
        if not 0 <= value < float("inf"):
            raise ValueError(f"invalid timestamp: {process!r}")
        total = total * 60 + value
    return total


def get_video_frame_with_processs(video_data, process: str = '00:00') -> "numpy.ndarray | None":
    """Return the video frame at playback position *process* as an RGB array.

    Args:
        video_data: source passed to ``cv2.VideoCapture`` (the value of the
            ``gr.Video`` input, typically a file path); falsy when no video
            has been uploaded.
        process: playback position, ``"MM:SS"`` by default; ``"SS"`` and
            ``"HH:MM:SS"`` are accepted too.

    Returns:
        The frame converted from OpenCV's BGR channel order to RGB, or
        ``None`` after showing a ``gr.Warning`` when no video is given, the
        timestamp is malformed, the video cannot be opened, or no frame can
        be read at that position.
    """
    if not video_data:
        gr.Warning("请先上传视频")
        return None

    # Validate the timestamp before touching the video so bad input never
    # opens (and possibly leaks) a capture handle.
    try:
        target_seconds = _parse_process_time(process)
    except ValueError:
        gr.Warning(f"时间格式错误: {process}")
        return None

    cap = cv2.VideoCapture(video_data)
    try:
        if not cap.isOpened():
            gr.Warning("无法打开视频")
            return None

        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(target_seconds * fps))
        else:
            # CAP_PROP_FPS can be 0 when the backend cannot determine it;
            # seeking by frame index would then always land on frame 0, so
            # seek by timestamp instead.
            cap.set(cv2.CAP_PROP_POS_MSEC, target_seconds * 1000)

        ret, frame = cap.read()
    finally:
        # Release on every path, including unexpected OpenCV exceptions.
        cap.release()

    if not ret:
        gr.Warning("无法读取视频帧")
        return None

    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
|
|
|
|
|
|
|
with gr.Blocks(
    js="""
async () => {
    document.addEventListener('timeupdate', function(e) {
        // 检查事件源是否是视频元素
        if (e.target.matches('#video_player video')) {
            const video = e.target;
            const currentTime = video.currentTime;
            // 转换成 00:00 格式
            let minutes = Math.floor(currentTime / 60);
            let seconds = Math.floor(currentTime % 60);
            let formattedTime = `${minutes.toString().padStart(2,'0')}:${seconds.toString().padStart(2,'0')}`;

            // 更新输入框值
            let processInput = document.querySelector("#video_process textarea");
            if(processInput) {
                processInput.value = formattedTime;
                processInput.text = formattedTime;

                processInput.dispatchEvent(new Event("input"));
            }

        }
    }, true); // 使用捕获阶段
}
""",
    css="""
.image img {
    max-height: 512px;
}
""",
) as demo:
    # The page script above mirrors the video's current time into the
    # "当前时间" textbox (elem_id video_process) on every `timeupdate`,
    # so the extract button can grab the frame currently on screen.

    # Legend items are a Markdown list so they render on separate lines
    # instead of being merged into one paragraph.
    gr.Markdown("""
## 棋盘检测, 棋子识别

- `x` 表示 有遮挡位置
- `.` 表示 棋盘上的普通交叉点

步骤:

1. 流程分成两步,第一步检测边缘
2. 对整个棋盘画面进行棋子分类预测
""")

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="上传视频", interactive=True, elem_id="video_player", height=356)
            video_process = gr.Textbox(label="当前时间", interactive=True, elem_id="video_process", value="00:00")
            extract_frame_btn = gr.Button("从视频提取当前帧")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")

        with gr.Column():
            original_image_with_keypoints = gr.Image(
                label="step1: 原图带关键点",
                interactive=False,
                visible=True,
                elem_classes="image",
            )

    with gr.Row():
        with gr.Column():
            transformed_image = gr.Image(
                label="step2: 拉伸棋盘",
                interactive=False,
                visible=True,
                elem_classes="image",
            )

        with gr.Column():
            use_time = gr.Textbox(
                label="用时",
                interactive=False,
                visible=True,
            )
            layout_pred_info = gr.Dataframe(
                label="棋子识别",
                interactive=False,
                visible=True,
            )

    # Skip the gallery entirely when no example assets were found.
    if full_examples:
        with gr.Row():
            with gr.Column():
                gr.Examples(full_examples, inputs=[image_input, video_input], label="示例视频、图片")

    def detect_chessboard(image):
        """Detect the board in *image* and classify every intersection.

        *image* is the numpy array delivered by ``image_input`` (or None when
        the input is cleared).

        Returns a 4-tuple matching the outputs wired below: the original
        image with keypoints, the rectified board image, the layout table
        (one list of category names per text row produced by the detector),
        and the timing info. On missing input or any failure, returns four
        ``None`` values so the output components are cleared.
        """
        if image is None:
            return None, None, None, None

        try:
            (
                image_with_keypoints,
                warped_board,
                cells_labels_str,
                _scores,
                time_info,
            ) = detector.pred_detect_board_and_classifier(image)

            # One text line per board row, one code character per
            # intersection. Blank lines (e.g. from a trailing newline) are
            # skipped so the table gets no empty row.
            rows = [line for line in cells_labels_str.splitlines() if line]
            layout_table = [[dict_cate_names_reverse[code] for code in row] for row in rows]
        except Exception:
            # UI boundary: show a friendly warning, but keep the traceback
            # in the server logs instead of discarding it.
            logging.getLogger(__name__).exception("chessboard detection failed")
            gr.Warning("检测失败 图片或者视频布局错误")
            return None, None, None, None

        return image_with_keypoints, warped_board, layout_table, time_info

    image_input.change(
        fn=detect_chessboard,
        inputs=[image_input],
        outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time],
    )

    # Writing the extracted frame into image_input fires its .change event,
    # which in turn runs detect_chessboard on that frame.
    extract_frame_btn.click(
        fn=get_video_frame_with_processs,
        inputs=[video_input, video_process],
        outputs=[image_input],
    )


if __name__ == "__main__":
    demo.launch()
|
|