|
import gradio as gr |
|
|
|
import os |
|
import base64 |
|
from pathlib import Path |
|
|
|
from core.chessboard_detector import ChessboardDetector |
|
|
|
# Single module-level detector shared by every request this app handles.
# It is constructed once at import time, so whatever model loading
# ChessboardDetector performs happens here, before the UI is built.
# Model paths are relative to the current working directory.
# NOTE(review): loading details live in core.chessboard_detector — confirm there.
detector = ChessboardDetector(
    # Presumably the keypoint model for step 1 ("keypoints 检测") described in the UI text.
    pose_model_path="onnx/pose/4_v6-0301.onnx",
    # Presumably the piece/layout classifier for step 2 (predicting pieces on the rectified board).
    full_classifier_model_path="onnx/layout_recognition/nano_v3-0319.onnx"
)
|
|
|
|
|
# Classifier label (Chinese piece name) -> one-character layout code.
# Red pieces use upper-case codes and black pieces use lower-case codes.
# The two board markers map to themselves:
#   '.' is an ordinary empty intersection and 'x' is an occluded position.
dict_cate_names = dict(
    [
        # Board markers.
        ('.', '.'),
        ('x', 'x'),
        # Red side.
        ('红帅', 'K'),
        ('红士', 'A'),
        ('红相', 'B'),
        ('红马', 'N'),
        ('红车', 'R'),
        ('红炮', 'C'),
        ('红兵', 'P'),
        # Black side.
        ('黑将', 'k'),
        ('黑仕', 'a'),
        ('黑象', 'b'),
        ('黑傌', 'n'),
        ('黑車', 'r'),
        ('黑砲', 'c'),
        ('黑卒', 'p'),
    ]
)

# Piece code -> image file name under resources/.
# Upper-case (red) codes use the "red_" prefix and lower-case (black) codes use "black_".
dict_cate_images = {
    code: ("red_" if code.isupper() else "black_") + code + ".png"
    for code in "KABNRCPkabnrcp"
}

# Inverse of dict_cate_names: layout code -> classifier label.
dict_cate_names_reverse = {code: name for name, code in dict_cate_names.items()}
|
|
|
|
|
|
|
# Resource file name -> "data:image/png;base64,..." URL.
# Each image is read and encoded at most once per process.
image_base64_cache = {}


def get_image_base64(img_path):
    """Return ``resources/<img_path>`` as a base64 PNG data URL.

    Successful results are memoised in ``image_base64_cache``. On any failure
    an empty string is returned, so callers can always embed the result
    directly in HTML. Failures include an empty name, a missing path, a path
    that is not a regular file, or a read error.

    Args:
        img_path: File name inside the ``resources`` directory,
            e.g. ``"red_K.png"``.

    Returns:
        str: ``"data:image/png;base64,..."`` on success, ``""`` otherwise.
    """
    if img_path in image_base64_cache:
        return image_base64_cache[img_path]

    # An empty name would resolve to the "resources" directory itself, which
    # "exists" but cannot be opened as a file.
    if not img_path:
        return ""

    try:
        img_full_path = Path("resources") / img_path
        if not img_full_path.is_file():
            return ""
        encoded = base64.b64encode(img_full_path.read_bytes()).decode('utf-8')
    except (OSError, TypeError, ValueError) as e:
        print(f"Error loading image {img_path}: {e}")
        # Same sentinel as the missing-file case (previously a bare `return`,
        # i.e. None, which ended up rendered as src='None').
        return ""

    data_url = f"data:image/png;base64,{encoded}"
    image_base64_cache[img_path] = data_url
    return data_url
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_examples(examples_dir="examples", extensions=(".jpg", ".jpeg", ".png")):
    """Collect example images for ``gr.Examples``.

    Args:
        examples_dir: Directory to scan (non-recursive). Defaults to
            ``"examples"`` relative to the current working directory.
        extensions: File-name suffixes to accept, compared case-insensitively,
            so ``.PNG`` and ``.JPG`` are picked up as well.

    Returns:
        list[list[str]]: One ``[path]`` row per matching file, in sorted file
        name order. ``os.listdir`` order is arbitrary, so sorting keeps the
        sliced example galleries stable across platforms.

    Raises:
        OSError: If ``examples_dir`` cannot be listed (e.g. it does not exist).
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    return [
        [os.path.join(examples_dir, name)]
        for name in sorted(os.listdir(examples_dir))
        if name.lower().endswith(suffixes)
    ]
|
|
|
|
|
# Example image rows, built once at import time. The UI slices this list
# into three gr.Examples galleries of up to 10 entries each.
full_examples = build_examples()
|
|
|
|
|
# Gradio UI for the two-step pipeline.
# Step 1 detects the board keypoints; step 2 rectifies the board and
# classifies each cell. Event wiring must stay inside this Blocks context.
with gr.Blocks(css="""
.image img {
    max-height: 512px;
}
"""
) as demo:
    gr.Markdown("""
## 棋盘检测, 棋子识别

features: 轻量化模型


x 表示 有遮挡位置
. 表示 棋盘上的普通交叉点

步骤:
1. 流程分成两步,第一步 keypoints 检测
2. 拉伸棋盘,并预测棋子

log:
1. 优化棋子识别,增加对游戏棋盘的识别
"""
    )

    # Input image and step-1 output (original image with keypoints drawn).
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")

        with gr.Column():
            original_image_with_keypoints = gr.Image(
                label="step1: 原图带关键点",
                interactive=False,
                visible=True,
                elem_classes="image"
            )

    # Step-2 output (rectified board) and elapsed-time info.
    with gr.Row():
        with gr.Column():
            transformed_image = gr.Image(
                label="step2: 拉伸棋盘",
                interactive=False,
                visible=True,
                elem_classes="image"
            )

        with gr.Column():
            use_time = gr.Textbox(
                label="用时",
                interactive=False,
                visible=True,
            )

    # Per-cell labels as text (collapsed by default).
    with gr.Accordion("文字识别", open=False):
        layout_pred_info = gr.Dataframe(
            label="棋子识别",
            interactive=False,
            visible=True,
        )

    # Per-cell labels rendered as a table of piece images.
    with gr.Accordion("棋子识别", open=True):
        table_html = gr.HTML(
            """
<table>
</table>
"""
        )

    with gr.Row():
        with gr.Column():
            gr.Examples(
                full_examples[:10], inputs=[image_input], label="示例图片1", examples_per_page=10,)
            gr.Examples(
                full_examples[10:20], inputs=[image_input], label="示例图片2", examples_per_page=10,)
            gr.Examples(
                full_examples[20:], inputs=[image_input], label="示例图片3", examples_per_page=10,)

    def gen_table_html(annotation_arr_10_9):
        """Render a board layout as an HTML table.

        Args:
            annotation_arr_10_9: Rows of single-character cell codes.
                '.' is an empty intersection (blank cell) and 'x' is an
                occluded position (rendered as "x"). Anything else is a piece
                code looked up in ``dict_cate_images``. Presumably 10 rows x
                9 columns, as the name suggests.

        Returns:
            str: An HTML ``<table>`` string.
        """
        cell_style = "width: 60px; height: 60px; text-align: center;"
        # Collect fragments and join once instead of repeated `+=`.
        parts = ["<table border='1' style='margin: auto;'>"]
        for row in annotation_arr_10_9:
            parts.append("<tr>")
            for cell in row:
                if cell == '.':
                    parts.append(f"<td style='{cell_style}'></td>")
                elif cell == 'x':
                    parts.append(f"<td style='{cell_style}'>x</td>")
                else:
                    img_data_base64 = get_image_base64(dict_cate_images.get(cell, ''))
                    if img_data_base64:
                        parts.append(
                            f"<td style='{cell_style} padding: 0;'>"
                            f"<img src='{img_data_base64}' width='58' height='58'></td>"
                        )
                    else:
                        # Image missing or unreadable: show the piece code
                        # rather than a broken <img>.
                        parts.append(f"<td style='{cell_style}'>{cell}</td>")
            parts.append("</tr>")
        parts.append("</table>")
        return "".join(parts)

    def detect_chessboard(image):
        """Run board detection and piece classification on an uploaded image.

        Returns:
            A 5-tuple matching the ``outputs`` wired below: the image with
            keypoints, the rectified board image, rows of classifier labels
            (for the Dataframe), the HTML table, and the timing info. All
            five are ``None`` when no image is given or detection fails.
        """
        if image is None:
            return None, None, None, None, None

        try:
            (
                keypoints_image,
                board_image,
                cells_labels_str,
                _scores,
                time_info,
            ) = detector.pred_detect_board_and_classifier(image)

            # One line per board row, one character per cell. Blank lines
            # (e.g. from a trailing newline) would become empty table rows.
            annotation_arr_10_9_short = [
                list(row) for row in cells_labels_str.split("\n") if row
            ]

            # Map one-character codes back to classifier labels; an unknown
            # code raises KeyError and is reported below.
            annotation_arr_10_9 = [
                [dict_cate_names_reverse[code] for code in row]
                for row in annotation_arr_10_9_short
            ]
        except Exception as e:
            # UI boundary: keep the app responsive, but keep the cause in the log.
            print(f"Error detecting chessboard: {e!r}")
            gr.Warning("检测失败 图片或者视频布局错误")
            return None, None, None, None, None

        board_table_html = gen_table_html(annotation_arr_10_9_short)

        return keypoints_image, board_image, annotation_arr_10_9, board_table_html, time_info

    image_input.change(fn=detect_chessboard,
                       inputs=[image_input],
                       outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, table_html, use_time])
|
|
|
# Launch the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()
|
|