yolo12138's picture
feat: 模型优化
ec71a6c
raw
history blame
3.93 kB
import logging
import os

import gradio as gr
# import cv2

from core.chessboard_detector import ChessboardDetector
# Shared detector instance used by the UI callback.
# The steps shown in the page text are: (1) keypoint detection on the board,
# then (2) rectifying the board and predicting each piece.
# Presumably the pose model handles step 1 and the layout classifier handles
# step 2 (inferred from the parameter names — confirm in core.chessboard_detector).
# Both ONNX paths are relative, so they resolve against the working directory
# (NOTE(review): assumes ChessboardDetector opens them as given).
detector = ChessboardDetector(
    full_classifier_model_path="onnx/layout_recognition/nano_v1.onnx",
    pose_model_path="onnx/pose/4_v3.onnx",
)
# Chinese piece name -> single-character code.
# Red pieces map to uppercase letters and black pieces to lowercase.
# '.' (a plain board intersection) and 'x' (an occluded point) map to themselves.
dict_cate_names = {
    '.': '.',
    'x': 'x',
    '红帅': 'K',
    '红士': 'A',
    '红相': 'B',
    '红马': 'N',
    '红车': 'R',
    '红炮': 'C',
    '红兵': 'P',
    '黑将': 'k',
    '黑仕': 'a',
    '黑象': 'b',
    '黑傌': 'n',
    '黑車': 'r',
    '黑砲': 'c',
    '黑卒': 'p',
}

# Inverse lookup (code -> Chinese name), used to render the detector's
# per-cell codes. Every code above is unique, so no entry is lost.
dict_cate_names_reverse = dict(zip(dict_cate_names.values(), dict_cate_names.keys()))
### Build the example gallery
def build_examples(examples_dir="examples", extensions=(".jpg", ".jpeg", ".png")):
    """Collect image files from *examples_dir* as rows for ``gr.Examples``.

    Each row is a one-element list ``[image_path]``, matching an examples
    component that feeds a single input.

    Args:
        examples_dir: Directory to scan (non-recursive). Defaults to
            ``"examples"``, resolved against the current working directory.
        extensions: Accepted file suffixes, matched case-insensitively
            (so ``photo.JPG`` is included).

    Returns:
        A list of ``[path]`` rows sorted by file name. ``os.listdir`` order is
        arbitrary, and sorting keeps the gallery stable across runs and
        platforms. Returns an empty list if *examples_dir* does not exist,
        so the app can still start without examples.
    """
    if not os.path.isdir(examples_dir):
        return []
    suffixes = tuple(ext.lower() for ext in extensions)
    rows = []
    for name in sorted(os.listdir(examples_dir)):
        path = os.path.join(examples_dir, name)
        # Skip directories that merely have an image-like name.
        if name.lower().endswith(suffixes) and os.path.isfile(path):
            rows.append([path])
    return rows
full_examples = build_examples()

# Page layout:
# - Row 1: the upload, and the step-1 image with keypoints.
# - Row 2: the step-2 rectified board, and the timing plus recognized layout.
# - Row 3: the example gallery.
# The CSS caps the displayed height of components tagged with the "image" class.
with gr.Blocks(css="""
.image img {
max-height: 512px;
}
"""
) as demo:
    gr.Markdown("""
## 棋盘检测, 棋子识别
features: 轻量化模型
x 表示 有遮挡位置
. 表示 棋盘上的普通交叉点
步骤:
1. 流程分成两步,第一步 keypoints 检测
2. 拉伸棋盘,并预测棋子
"""
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")
        with gr.Column():
            original_image_with_keypoints = gr.Image(
                label="step1: 原图带关键点",
                interactive=False,
                visible=True,
                elem_classes="image",
            )
    with gr.Row():
        with gr.Column():
            transformed_image = gr.Image(
                label="step2: 拉伸棋盘",
                interactive=False,
                visible=True,
                elem_classes="image",
            )
        with gr.Column():
            use_time = gr.Textbox(
                label="用时",
                interactive=False,
                visible=True,
            )
            layout_pred_info = gr.Dataframe(
                label="棋子识别",
                interactive=False,
                visible=True,
            )
    with gr.Row():
        with gr.Column():
            gr.Examples(
                full_examples, inputs=[image_input], label="示例图片", examples_per_page=15,
            )

    def detect_chessboard(image):
        """Run board detection and piece classification on one uploaded image.

        Args:
            image: The uploaded image as a numpy array (the input component uses
                ``type="numpy"``), or ``None`` when the input was cleared.

        Returns:
            A 4-tuple matching the ``outputs`` wired below:
            (image with keypoints, rectified board image, board layout as rows
            of Chinese piece names, timing info).
            All four items are ``None`` when there is no image or detection fails.
        """
        if image is None:
            return None, None, None, None
        try:
            # Local names deliberately differ from the Gradio components above
            # to avoid shadowing them. The detector's `scores` output is not
            # displayed.
            keypoint_image, warped_image, cells_labels_str, _scores, time_info = (
                detector.pred_detect_board_and_classifier(image)
            )
            # The labels come back as newline-separated rows of single-character
            # codes (expected: 10 rows x 9 columns). Split each row into cells
            # and map every code to its Chinese name. Blank lines, e.g. from a
            # trailing "\n", are skipped so they don't become empty table rows.
            # An unknown code raises KeyError and is handled below.
            layout_rows = [
                [dict_cate_names_reverse[code] for code in line]
                for line in cells_labels_str.split("\n")
                if line
            ]
        except Exception:
            # UI boundary: show a warning to the user, but keep the traceback in
            # the server log instead of silently discarding it.
            logging.getLogger(__name__).exception("chessboard detection failed")
            gr.Warning("检测失败 图片或者视频布局错误")
            return None, None, None, None
        return keypoint_image, warped_image, layout_rows, time_info

    image_input.change(
        fn=detect_chessboard,
        inputs=[image_input],
        outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time],
    )

if __name__ == "__main__":
    demo.launch()