# yolo12138 · commit 4bcaf91: "fix none result"
import gradio as gr
# import cv2
import os
import base64
from pathlib import Path
from core.chessboard_detector import ChessboardDetector
# Single detector instance, created once when this module is imported and
# reused by every request handled below. Both model paths are relative paths.
# NOTE(review): presumably the constructor loads the ONNX models eagerly —
# confirm in core.chessboard_detector.
detector = ChessboardDetector(
    full_classifier_model_path="onnx/layout_recognition/nano_v3-0319.onnx",
    pose_model_path="onnx/pose/4_v6-0301.onnx",
)
# Chinese piece name -> single-character board code.
# '.' is a plain board intersection and 'x' an occluded position; both map to
# themselves. Red (红) pieces use uppercase letters, black (黑) pieces lowercase.
dict_cate_names = {
    # board markers
    '.': '.', 'x': 'x',
    # red pieces
    '红帅': 'K', '红士': 'A', '红相': 'B', '红马': 'N',
    '红车': 'R', '红炮': 'C', '红兵': 'P',
    # black pieces
    '黑将': 'k', '黑仕': 'a', '黑象': 'b', '黑傌': 'n',
    '黑車': 'r', '黑砲': 'c', '黑卒': 'p',
}

# Board code -> icon file name, looked up by get_image_base64 when the HTML
# board is rendered. Markers ('.', 'x') have no icon.
dict_cate_images = {
    # red pieces
    'K': 'red_K.png', 'A': 'red_A.png', 'B': 'red_B.png', 'N': 'red_N.png',
    'R': 'red_R.png', 'C': 'red_C.png', 'P': 'red_P.png',
    # black pieces
    'k': 'black_k.png', 'a': 'black_a.png', 'b': 'black_b.png', 'n': 'black_n.png',
    'r': 'black_r.png', 'c': 'black_c.png', 'p': 'black_p.png',
}

# Board code -> Chinese piece name (inverse of dict_cate_names).
dict_cate_names_reverse = dict(zip(dict_cate_names.values(), dict_cate_names))
# Cache of PNG data URLs, keyed by the file name passed to get_image_base64.
image_base64_cache = {}


def get_image_base64(img_path):
    """Return a ``data:image/png;base64,...`` URL for a file under ``resources/``.

    Args:
        img_path: File name relative to the ``resources`` directory.

    Returns:
        The data URL, or ``""`` when ``img_path`` is empty, does not name a
        regular file, or cannot be read. The result is always a string (never
        ``None``) so callers can embed it directly in HTML. Successful lookups
        are cached; misses are not, so a file added later is still picked up.
        The MIME type is always ``image/png``.
    """
    if img_path in image_base64_cache:
        return image_base64_cache[img_path]
    if not img_path:
        # An empty name would resolve to the "resources" directory itself.
        return ""
    try:
        img_full_path = Path("resources") / img_path
        # is_file() rather than exists(): a directory must not reach read_bytes().
        if not img_full_path.is_file():
            return ""
        raw = img_full_path.read_bytes()
    except (OSError, TypeError) as e:
        print(f"Error loading image {img_path}: {e}")
        return ""
    encoded = base64.b64encode(raw).decode('utf-8')
    data_url = f"data:image/png;base64,{encoded}"
    image_base64_cache[img_path] = data_url
    return data_url
### Build the example list
def build_examples(examples_dir="examples"):
    """Collect the example images offered in the demo.

    Args:
        examples_dir: Directory scanned (non-recursively) for ``.jpg`` / ``.png``
            files, matched case-insensitively. Defaults to ``"examples"``,
            a relative path.

    Returns:
        A list of ``[image_path]`` rows, one per example, sorted by file name.
        Each row has a single element, matching the single ``inputs=[image_input]``
        component the list is passed to.
    """
    image_exts = (".jpg", ".png")
    # os.listdir() order is arbitrary; the UI slices this list into fixed
    # groups of ten, so sort to make that grouping deterministic.
    return [
        [os.path.join(examples_dir, name)]
        for name in sorted(os.listdir(examples_dir))
        if name.lower().endswith(image_exts)
    ]
# Example images listed under the upload widget.
full_examples = build_examples()

# Keyword arguments shared by the two read-only result images.
_result_image_kwargs = dict(interactive=False, visible=True, elem_classes="image")

with gr.Blocks(css="""
.image img {
max-height: 512px;
}
"""
) as demo:
    gr.Markdown("""
## 棋盘检测, 棋子识别
features: 轻量化模型
x 表示 有遮挡位置
. 表示 棋盘上的普通交叉点
步骤:
1. 流程分成两步,第一步 keypoints 检测
2. 拉伸棋盘,并预测棋子
log:
1. 优化棋子识别,增加对游戏棋盘的识别
""")

    # Row 1: upload on the left, image with detected keypoints on the right.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")
        with gr.Column():
            original_image_with_keypoints = gr.Image(label="step1: 原图带关键点", **_result_image_kwargs)

    # Row 2: rectified board on the left; timing and recognition results on the right.
    with gr.Row():
        with gr.Column():
            transformed_image = gr.Image(label="step2: 拉伸棋盘", **_result_image_kwargs)
        with gr.Column():
            use_time = gr.Textbox(label="用时", interactive=False, visible=True)
            # Collapsible panels holding the recognition results.
            with gr.Accordion("文字识别", open=False):
                layout_pred_info = gr.Dataframe(label="棋子识别", interactive=False, visible=True)
            with gr.Accordion("棋子识别", open=True):
                # Placeholder; replaced by the rendered 10-row x 9-column board table.
                table_html = gr.HTML("""
<table>
</table>
""")

    # Row 3: example images, split into three groups of (up to) ten.
    with gr.Row():
        with gr.Column():
            for _examples_chunk, _examples_label in (
                (full_examples[:10], "示例图片1"),
                (full_examples[10:20], "示例图片2"),
                (full_examples[20:], "示例图片3"),
            ):
                gr.Examples(
                    _examples_chunk,
                    inputs=[image_input],
                    label=_examples_label,
                    examples_per_page=10,
                )
def gen_table_html(annotation_arr_10_9):
    """Render a board layout as an HTML ``<table>``.

    Args:
        annotation_arr_10_9: Rows of single-character cell codes (presumably
            10 rows of 9, per the name). ``'.'`` renders as an empty cell,
            ``'x'`` as the letter x, and any other code as its icon from
            ``dict_cate_images``.

    Returns:
        The HTML string. A code whose icon cannot be resolved (unknown code,
        or an empty/``None`` data URL) is rendered as its HTML-escaped text
        instead of a broken ``<img>``.
    """
    from html import escape

    cell_style = "width: 60px; height: 60px; text-align: center;"
    parts = ["<table border='1' style='margin: auto;'>"]
    for row in annotation_arr_10_9:
        parts.append("<tr>")
        for cell in row:
            if cell == '.':
                # Plain intersection: empty cell.
                parts.append(f"<td style='{cell_style}'></td>")
            elif cell == 'x':
                # Occluded position.
                parts.append(f"<td style='{cell_style}'>x</td>")
            else:
                img_file = dict_cate_images.get(cell, '')
                # Skip the lookup for unknown codes: an empty name is never a valid icon.
                img_data_base64 = get_image_base64(img_file) if img_file else ""
                if img_data_base64:
                    parts.append(
                        f"<td style='{cell_style} padding: 0;'>"
                        f"<img src='{img_data_base64}' width='58' height='58'></td>"
                    )
                else:
                    parts.append(f"<td style='{cell_style}'>{escape(str(cell))}</td>")
        parts.append("</tr>")
    parts.append("</table>")
    return "".join(parts)
def detect_chessboard(image):
    """Run board detection and piece classification on an uploaded image.

    Args:
        image: The value of the ``gr.Image(type="numpy")`` input, or ``None``
            when the input has been cleared.

    Returns:
        A 5-tuple in the order of the outputs wired to ``image_input.change``:
        (image with keypoints, rectified board image, grid of Chinese piece
        names for the Dataframe, HTML board table, timing info). All five are
        ``None`` when there is no image or detection fails.
    """
    if image is None:
        return None, None, None, None, None
    try:
        # Local names differ from the module-level components
        # (original_image_with_keypoints, transformed_image, table_html)
        # so they are not shadowed here.
        (keypoints_image, warped_image, cells_labels_str,
         _scores, time_info) = detector.pred_detect_board_and_classifier(image)
        if cells_labels_str is None:
            # NOTE(review): presumably returned when no board is found; fail
            # explicitly rather than via an AttributeError on the string ops.
            raise ValueError("detector returned no layout")
        # One text line per board row; each character is a cell code.
        rows_short = [list(line) for line in cells_labels_str.splitlines()]
        # Map codes back to Chinese piece names for the Dataframe. An
        # unrecognised code is kept as-is rather than discarding the result.
        rows_named = [
            [dict_cate_names_reverse.get(code, code) for code in row]
            for row in rows_short
        ]
    except Exception as e:
        # UI boundary: log the cause and show a warning instead of
        # swallowing the error silently.
        print(f"Error detecting chessboard: {e!r}")
        gr.Warning("检测失败 图片或者视频布局错误")
        return None, None, None, None, None
    board_html = gen_table_html(rows_short)
    return keypoints_image, warped_image, rows_named, board_html, time_info
# Wire the upload widget to the detection pipeline. The listener is attached
# inside the `demo` Blocks context, which the layout block above has already
# exited, so it is re-entered here.
with demo:
    image_input.change(
        detect_chessboard,
        inputs=[image_input],
        outputs=[
            original_image_with_keypoints,
            transformed_image,
            layout_pred_info,
            table_html,
            use_time,
        ],
    )

if __name__ == "__main__":
    demo.launch()