File size: 7,031 Bytes
37170d6
5873e33
9316eb4
19a60be
 
 
37170d6
 
 
19a60be
 
37170d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19a60be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37170d6
 
 
19a60be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9316eb4
 
 
 
 
 
5873e33
9316eb4
5873e33
9316eb4
 
 
 
 
 
 
5873e33
2a190c0
37170d6
 
 
 
 
 
 
ec71a6c
19a60be
ec71a6c
9316eb4
 
 
37170d6
ec71a6c
 
19a60be
 
 
37170d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19a60be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37170d6
9316eb4
 
ec71a6c
19a60be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9316eb4
37170d6
 
9316eb4
4bcaf91
9316eb4
 
 
37170d6
9316eb4
 
 
 
19a60be
37170d6
9316eb4
19a60be
9316eb4
 
 
4bcaf91
37170d6
 
19a60be
 
 
37170d6
 
 
19a60be
37170d6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import gradio as gr
# import cv2
import os
import base64
from pathlib import Path
    
from core.chessboard_detector import ChessboardDetector

# Shared detector instance, created once at import time from two ONNX model
# files: one passed as the pose model, one as the full-board classifier.
detector = ChessboardDetector(
    pose_model_path="onnx/pose/4_v6-0301.onnx",
    full_classifier_model_path="onnx/layout_recognition/nano_v3-0319.onnx"
)

# Label name -> one-character code.
# '.' and 'x' map to themselves; labels starting with 红 (red) map to
# uppercase codes and labels starting with 黑 (black) map to lowercase codes.
dict_cate_names = {
    '.': '.',
    'x': 'x',
    '红帅': 'K',
    '红士': 'A',
    '红相': 'B',
    '红马': 'N',
    '红车': 'R',
    '红炮': 'C',
    '红兵': 'P',

    '黑将': 'k',
    '黑仕': 'a',
    '黑象': 'b',
    '黑傌': 'n',
    '黑車': 'r',
    '黑砲': 'c',
    '黑卒': 'p',
}

# One-character piece code -> image file name. Uppercase codes use the
# "red_" prefix and lowercase codes use the "black_" prefix.
dict_cate_images = {
    code: f"{'red' if code.isupper() else 'black'}_{code}.png"
    for code in "KABNRCPkabnrcp"
}

# Inverse of dict_cate_names: one-character code -> label name.
dict_cate_names_reverse = {v: k for k, v in dict_cate_names.items()}


 # 缓存图片的 base64 编码
image_base64_cache = {}


def get_image_base64(img_path):
    """Return an image under ``resources/`` as a base64 PNG data URL.

    Successful results are memoised in ``image_base64_cache`` keyed by
    ``img_path``. Failures are not cached, so a file that appears later is
    picked up on the next call.

    Args:
        img_path: File name (or relative path) inside the ``resources``
            directory.

    Returns:
        A ``data:image/png;base64,...`` string, or ``""`` when ``img_path``
        is empty, does not name a regular file, or cannot be read. The
        result is always a string so callers can embed it directly.
    """
    if not img_path:
        return ""

    cached = image_base64_cache.get(img_path)
    if cached is not None:
        return cached

    img_full_path = Path("resources") / img_path
    # is_file() (rather than exists()) also rejects directories, which
    # would otherwise pass the check and then fail when opened.
    if not img_full_path.is_file():
        return ""

    try:
        encoded = base64.b64encode(img_full_path.read_bytes()).decode('utf-8')
    except OSError as e:
        print(f"Error loading image {img_path}: {e}")
        return ""

    data_url = f"data:image/png;base64,{encoded}"
    image_base64_cache[img_path] = data_url
    return data_url




### 构建 examples 

def build_examples():
    """Collect the example images found in the ``examples`` directory.

    Returns:
        A list of single-element lists ``[path]``, one per file whose name
        ends in ``.jpg`` or ``.png``, sorted by file name.
    """
    # os.listdir() yields entries in arbitrary order; sort so that callers
    # slicing the result see the same examples on every run and platform.
    return [
        [os.path.join("examples", name)]
        for name in sorted(os.listdir("examples"))
        if name.endswith((".jpg", ".png"))
    ]


# Built once at import time; sliced into the example galleries in the UI below.
full_examples = build_examples()


# Build the Gradio UI. The custom CSS caps images inside elements carrying
# the "image" class at 512px tall.
with gr.Blocks(css="""
        .image img {
            max-height: 512px;
        }
    """
) as demo:
    # Introductory Markdown, the first component added to the page
    # (user-facing text, intentionally left in Chinese).
    gr.Markdown("""
                ## 棋盘检测, 棋子识别

                features: 轻量化模型
                

                x 表示 有遮挡位置   
                . 表示 棋盘上的普通交叉点  

                步骤:  
                    1. 流程分成两步,第一步 keypoints 检测
                    2. 拉伸棋盘,并预测棋子
                    
                log:
                    1. 优化棋子识别,增加对游戏棋盘的识别
                """
    )
    # Row 1: image upload (delivered to the handler as a numpy array) next to
    # the read-only step-1 preview image.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="上传棋盘图片", type="numpy", elem_classes="image")

        with gr.Column():
            original_image_with_keypoints = gr.Image(
                label="step1: 原图带关键点",
                interactive=False,
                visible=True,
                elem_classes="image"
            )


    # Row 2: read-only step-2 image next to the timing text and the two
    # recognition views.
    with gr.Row():
        with gr.Column():   
            transformed_image = gr.Image(
                label="step2: 拉伸棋盘",
                interactive=False,
                visible=True,
                elem_classes="image"
            )

        with gr.Column():
            use_time = gr.Textbox(
                label="用时",
                interactive=False,
                visible=True,
            )
            
            # Accordion (closed by default) holding the label dataframe.
            with gr.Accordion("文字识别", open=False):
                layout_pred_info = gr.Dataframe(
                    label="棋子识别",
                    interactive=False,
                    visible=True,
                )
            
            # Accordion (open by default) holding the rendered board table.
            with gr.Accordion("棋子识别", open=True):
                # Placeholder for the 10-row by 9-column table; starts empty.
                table_html = gr.HTML(
                    """
                    <table>
                    </table>
                    """
                )

    # Row 3: three example galleries fed from consecutive slices of
    # full_examples (first 10, next 10, remainder), each bound to image_input.
    with gr.Row():
        with gr.Column():
            gr.Examples(
                full_examples[:10], inputs=[image_input], label="示例图片1",  examples_per_page=10,)
            
            gr.Examples(
                full_examples[10:20], inputs=[image_input], label="示例图片2",  examples_per_page=10,)
            
            gr.Examples(
                full_examples[20:], inputs=[image_input], label="示例图片3",  examples_per_page=10,)
            

    def gen_table_html(annotation_arr_10_9):
        """Render the recognised board as an HTML table.

        Args:
            annotation_arr_10_9: Iterable of rows, each an iterable of
                one-character cell codes: '.', 'x', or a piece code that is
                a key of ``dict_cate_images``.

        Returns:
            The complete ``<table>`` markup as one string. '.' becomes an
            empty cell, 'x' a cell containing "x", and a piece code a cell
            containing its image. A code with no image available is shown
            as (escaped) text instead of a broken ``<img>``.
        """
        from html import escape

        td_style = "width: 60px; height: 60px; text-align: center;"

        # Collect fragments and join once instead of growing a string.
        parts = ["<table border='1' style='margin: auto;'>"]
        for row in annotation_arr_10_9:
            parts.append("<tr>")
            for cell in row:
                if cell == '.':
                    # Plain intersection: empty cell.
                    parts.append(f"<td style='{td_style}'></td>")
                    continue
                if cell == 'x':
                    # Occluded position.
                    parts.append(f"<td style='{td_style}'>x</td>")
                    continue

                # Piece code: look up its image and embed it as a data URL.
                img_file = dict_cate_images.get(cell)
                img_data_base64 = get_image_base64(img_file) if img_file else ""
                if img_data_base64:
                    parts.append(
                        f"<td style='{td_style} padding: 0;'>"
                        f"<img src='{img_data_base64}' width='58' height='58'></td>"
                    )
                else:
                    # Unknown code or unreadable image file: fall back to
                    # the code itself so the cell is still informative.
                    parts.append(f"<td style='{td_style}'>{escape(str(cell))}</td>")
            parts.append("</tr>")

        parts.append("</table>")
        return "".join(parts)
        


    def detect_chessboard(image):
        """Run detection on the input image and build every UI output.

        Args:
            image: Value of the image input component, or ``None`` when it
                has been cleared.

        Returns:
            A 5-tuple ``(image_with_keypoints, transformed_image, label_rows,
            table_html, time_info)`` in the order expected by the change
            handler's ``outputs``. All five are ``None`` when ``image`` is
            ``None`` or when detection / label conversion raises.
        """
        no_result = (None, None, None, None, None)
        if image is None:
            return no_result

        try:
            keypoints_image, warped_image, cells_labels_str, _scores, time_info = (
                detector.pred_detect_board_and_classifier(image)
            )

            # One text line per board row, one character per cell.
            code_rows = [list(line) for line in cells_labels_str.split("\n")]

            # Translate each code to its label name for the dataframe view.
            name_rows = [
                [dict_cate_names_reverse[code] for code in row]
                for row in code_rows
            ]
        except Exception as e:
            # The UI message stays generic; print the cause so the failure
            # can be diagnosed from the server output.
            print(f"Error detecting chessboard: {e!r}")
            gr.Warning("检测失败 图片或者视频布局错误")
            return no_result

        board_html = gen_table_html(code_rows)

        return keypoints_image, warped_image, name_rows, board_html, time_info

    # Run detect_chessboard whenever the input image changes; its five return
    # values are assigned to the output components in the order listed.
    image_input.change(fn=detect_chessboard, 
                       inputs=[image_input], 
                       outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, table_html, use_time])

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()