"""Gradio demo: Chinese chess (xiangqi) board detection and piece recognition.

The page takes an uploaded board photo, runs ``ChessboardDetector`` on it and
shows the keypoint overlay, the rectified board, the recognised layout as a
dataframe and as an HTML table with piece icons, plus the timing text returned
by the detector.
"""
import gradio as gr
# import cv2
import os
import base64
from html import escape
from pathlib import Path

from core.chessboard_detector import ChessboardDetector

detector = ChessboardDetector(
    pose_model_path="onnx/pose/4_v6-0301.onnx",
    full_classifier_model_path="onnx/layout_recognition/nano_v3-0319.onnx"
)

# Display name (as shown in the dataframe) -> one-character layout code.
dict_cate_names = {
    '.': '.',
    'x': 'x',
    '红帅': 'K',
    '红士': 'A',
    '红相': 'B',
    '红马': 'N',
    '红车': 'R',
    '红炮': 'C',
    '红兵': 'P',
    '黑将': 'k',
    '黑仕': 'a',
    '黑象': 'b',
    '黑傌': 'n',
    '黑車': 'r',
    '黑砲': 'c',
    '黑卒': 'p',
}

# One-character layout code -> icon file name under ``resources/``.
dict_cate_images = {
    'K': 'red_K.png',
    'A': 'red_A.png',
    'B': 'red_B.png',
    'N': 'red_N.png',
    'R': 'red_R.png',
    'C': 'red_C.png',
    'P': 'red_P.png',
    'k': 'black_k.png',
    'a': 'black_a.png',
    'b': 'black_b.png',
    'n': 'black_n.png',
    'r': 'black_r.png',
    'c': 'black_c.png',
    'p': 'black_p.png',
}

# One-character layout code -> display name.
dict_cate_names_reverse = {v: k for k, v in dict_cate_names.items()}

# Cache of icon file name -> base64 data URL, filled lazily.
image_base64_cache = {}


def get_image_base64(img_path: str) -> str:
    """Return a ``data:image/png;base64,...`` URL for an icon in ``resources/``.

    Args:
        img_path: File name relative to the ``resources`` directory.

    Returns:
        The data URL, or an empty string when ``img_path`` is empty, the file
        does not exist, or it cannot be read. Successful results are cached in
        ``image_base64_cache``.
    """
    if not img_path:
        # An empty name would resolve to the ``resources`` directory itself.
        return ""

    if img_path in image_base64_cache:
        return image_base64_cache[img_path]

    img_full_path = Path("resources") / img_path
    if not img_full_path.is_file():
        return ""

    try:
        encoded = base64.b64encode(img_full_path.read_bytes()).decode('utf-8')
    except OSError as e:
        print(f"Error loading image {img_path}: {e}")
        return ""

    data_url = f"data:image/png;base64,{encoded}"
    image_base64_cache[img_path] = data_url
    return data_url


def build_examples() -> list:
    """Collect the example images shipped in the ``examples`` directory.

    Returns:
        A list of single-element lists ``[image_path]`` (the row format
        ``gr.Examples`` takes for one input), sorted by file name so the
        example order is stable across runs. Empty when the directory is
        missing.
    """
    if not os.path.isdir("examples"):
        return []

    return [
        [os.path.join("examples", name)]
        for name in sorted(os.listdir("examples"))
        if name.lower().endswith((".jpg", ".png"))
    ]


full_examples = build_examples()


with gr.Blocks(css="""
    .image img {
        max-height: 512px;
    }
"""
) as demo:
    gr.Markdown("""
    ## 棋盘检测, 棋子识别

    features: 轻量化模型

    x 表示 有遮挡位置

    . 表示 棋盘上的普通交叉点

    步骤:
    1. 流程分成两步,第一步 keypoints 检测
    2. 拉伸棋盘,并预测棋子

    log:
    1. 优化棋子识别,增加对游戏棋盘的识别
    """
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")

        with gr.Column():
            original_image_with_keypoints = gr.Image(
                label="step1: 原图带关键点",
                interactive=False,
                visible=True,
                elem_classes="image"
            )

    with gr.Row():
        with gr.Column():
            transformed_image = gr.Image(
                label="step2: 拉伸棋盘",
                interactive=False,
                visible=True,
                elem_classes="image"
            )

        with gr.Column():
            use_time = gr.Textbox(
                label="用时",
                interactive=False,
                visible=True,
            )

            # Collapsible panel holding the textual layout.
            with gr.Accordion("文字识别", open=False):
                layout_pred_info = gr.Dataframe(
                    label="棋子识别",
                    interactive=False,
                    visible=True,
                )

            with gr.Accordion("棋子识别", open=True):
                # Placeholder; replaced by the table from gen_table_html.
                table_html = gr.HTML(
                    """
                    """
                )

    with gr.Row():
        with gr.Column():
            # Up to three example panels; an empty slice is skipped so no
            # panel is created without rows.
            example_groups = (
                ("示例图片1", full_examples[:10]),
                ("示例图片2", full_examples[10:20]),
                ("示例图片3", full_examples[20:]),
            )
            for group_label, group_rows in example_groups:
                if group_rows:
                    gr.Examples(
                        group_rows,
                        inputs=[image_input],
                        label=group_label,
                        examples_per_page=10,
                    )

    def gen_table_html(annotation_arr_10_9):
        """Render the recognised layout as an HTML table.

        Args:
            annotation_arr_10_9: Rows of one-character layout codes
                (``'.'`` plain intersection, ``'x'`` occluded position,
                otherwise a key of ``dict_cate_images``).

        Returns:
            An HTML ``<table>`` string with one ``<tr>`` per row and one
            ``<td>`` per cell. Piece cells contain the icon as an inline
            data-URL image; if the icon is unavailable the code itself is
            shown as text.
        """
        cell_style = (
            "width: 40px; height: 40px; border: 1px solid #ccc; "
            "text-align: center; vertical-align: middle; padding: 0;"
        )

        rendered_rows = []
        for row in annotation_arr_10_9:
            rendered_cells = []
            for cell in row:
                if cell == '.':
                    # Plain intersection: empty cell.
                    content = ""
                elif cell == 'x':
                    # Occluded position.
                    content = "x"
                else:
                    # Piece: look up its icon and embed it inline.
                    img_file = dict_cate_images.get(cell, '')
                    img_data_base64 = get_image_base64(img_file)
                    if img_data_base64:
                        content = (
                            f"<img src='{img_data_base64}' alt='{escape(cell)}' "
                            "style='width: 36px; height: 36px;'/>"
                        )
                    else:
                        # No icon available: fall back to the raw code.
                        content = escape(cell)
                rendered_cells.append(f"<td style='{cell_style}'>{content}</td>")
            rendered_rows.append("<tr>" + "".join(rendered_cells) + "</tr>")

        return (
            "<table style='border-collapse: collapse;'>"
            + "".join(rendered_rows)
            + "</table>"
        )

    def detect_chessboard(image):
        """Run detection on the uploaded image and build all five outputs.

        Args:
            image: The value of ``image_input`` (``None`` when cleared).

        Returns:
            A 5-tuple matching the ``outputs`` wired below: keypoint overlay
            image, rectified board image, layout rows as display names,
            layout HTML table, timing text. All ``None`` when there is no
            image or detection fails (a warning toast is shown in that case).
        """
        empty_outputs = (None, None, None, None, None)
        if image is None:
            return empty_outputs

        try:
            (
                keypoints_image,
                warped_image,
                cells_labels_str,
                _scores,
                time_info,
            ) = detector.pred_detect_board_and_classifier(image)

            # The layout string has one board row per line, one code per char.
            short_rows = [list(line) for line in cells_labels_str.split("\n")]

            # Same grid with codes replaced by their display names.
            named_rows = [
                [dict_cate_names_reverse[code] for code in row]
                for row in short_rows
            ]
        except Exception as e:
            # UI boundary: report the cause on the console, warn the user.
            print(f"Chessboard detection failed: {e}")
            gr.Warning("检测失败 图片或者视频布局错误")
            return empty_outputs

        layout_table = gen_table_html(short_rows)

        return keypoints_image, warped_image, named_rows, layout_table, time_info

    image_input.change(
        fn=detect_chessboard,
        inputs=[image_input],
        outputs=[
            original_image_with_keypoints,
            transformed_image,
            layout_pred_info,
            table_html,
            use_time,
        ],
    )

if __name__ == "__main__":
    demo.launch()