Spaces:
Sleeping
Sleeping
import gradio as gr | |
from core.chessboard_detector import ChessboardDetector | |
# Locations of the three ONNX models handed to the detector. Judging by their
# file names: board detection, keypoint (pose) estimation, and whole-board
# layout / piece classification. Paths are relative — presumably resolved
# against the process working directory; confirm inside ChessboardDetector.
_model_paths = {
    "det_model_path": "onnx/det/v1.onnx",
    "pose_model_path": "onnx/pose/v1.onnx",
    "full_classifier_model_path": "onnx/layout_recognition/v1.onnx",
}
detector = ChessboardDetector(**_model_paths)
# Class-name table: dataset category name (Chinese piece name, or the special
# markers '.' and 'x') -> single-character label used in the detector's layout
# string. Red pieces ('红...') map to upper-case letters, black pieces ('黑...')
# to lower-case letters.
dict_cate_names = {
    '.': '.',
    'x': 'x',
    # red side
    '红帅': 'K',
    '红士': 'A',
    '红相': 'B',
    '红马': 'N',
    '红车': 'R',
    '红炮': 'C',
    '红兵': 'P',
    # black side
    '黑将': 'k',
    '黑仕': 'a',
    '黑象': 'b',
    '黑傌': 'n',
    '黑車': 'r',
    '黑砲': 'c',
    '黑卒': 'p',
}
# Inverse table: single-character label -> Chinese category name. All labels
# above are distinct, so the inversion is lossless.
dict_cate_names_reverse = dict(zip(dict_cate_names.values(), dict_cate_names.keys()))
def _labels_to_chinese_grid(cells_labels_str: str) -> list:
    """Convert the detector's text layout into a grid of Chinese piece names.

    Each line of ``cells_labels_str`` is one board row, and each character in
    a line is one cell label (one of the values of ``dict_cate_names``, e.g.
    ``'K'``, ``'k'``, ``'.'``, ``'x'``). Labels are mapped back to their Chinese
    category names via ``dict_cate_names_reverse``. A label missing from that
    table is shown as-is instead of raising ``KeyError``.

    Args:
        cells_labels_str: newline-separated rows of single-character labels.

    Returns:
        A list of rows, each a list of display strings, for ``gr.Dataframe``.
    """
    # splitlines() does not produce the spurious empty last row that
    # split("\n") yields when the string ends with a newline.
    return [
        [dict_cate_names_reverse.get(label, label) for label in row]
        for row in cells_labels_str.splitlines()
    ]


def detect_chessboard(image):
    """Run board detection and piece classification on one uploaded image.

    Args:
        image: the uploaded picture as a numpy array (``gr.Image(type="numpy")``),
            or ``None`` when the input image has been cleared.

    Returns:
        A 4-tuple matching the ``outputs`` wired below: the original image with
        keypoints, the stretched board image, the grid of Chinese piece names,
        and the timing info. All outputs are reset when ``image`` is ``None``.
    """
    # gr.Image also fires `change` with None when the user clears the input;
    # don't hand None to the detector, just clear the outputs.
    if image is None:
        return None, None, None, ""
    (
        keypoints_image,
        stretched_board,
        cells_labels_str,
        _scores,  # presumably per-cell confidences; not displayed in this UI
        time_info,
    ) = detector.pred_detect_board_and_classifier(image)
    return (
        keypoints_image,
        stretched_board,
        _labels_to_chinese_grid(cells_labels_str),
        time_info,
    )


with gr.Blocks(
    css="""
.image {
max-height: 512px;
}
"""
) as demo:
    gr.Markdown("""
## 棋盘检测, 棋子识别
步骤:
1. 流程分成两步,第一步检测边缘
2. 对整个棋盘画面进行棋子分类预测
"""
    )
    # Row 1: input image | original image annotated with detected keypoints.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")
        with gr.Column():
            original_image_with_keypoints = gr.Image(
                label="step1: 原图带关键点",
                interactive=False,
                visible=True,
                elem_classes="image",
            )
    # Row 2: stretched board image | elapsed time and per-cell piece names.
    with gr.Row():
        with gr.Column():
            transformed_image = gr.Image(
                label="step2: 拉伸棋盘",
                interactive=False,
                visible=True,
                elem_classes="image",
            )
        with gr.Column():
            use_time = gr.Textbox(
                label="用时",
                interactive=False,
                visible=True,
            )
            layout_pred_info = gr.Dataframe(
                label="棋子识别",
                interactive=False,
                visible=True,
            )

    # Re-run the pipeline whenever the input image changes (upload or clear).
    image_input.change(
        fn=detect_chessboard,
        inputs=[image_input],
        outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time],
    )
def _launch_demo():
    """Launch the Gradio app defined by ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _launch_demo()