yolo12138 commited on
Commit
19a60be
·
1 Parent(s): 10cae01
app.py CHANGED
@@ -1,11 +1,14 @@
1
  import gradio as gr
2
  # import cv2
3
  import os
 
 
 
4
  from core.chessboard_detector import ChessboardDetector
5
 
6
  detector = ChessboardDetector(
7
- pose_model_path="onnx/pose/4_v3.onnx",
8
- full_classifier_model_path="onnx/layout_recognition/nano_v1.onnx"
9
  )
10
 
11
  # 数据集路径
@@ -29,9 +32,51 @@ dict_cate_names = {
29
  '黑卒': 'p',
30
  }
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  dict_cate_names_reverse = {v: k for k, v in dict_cate_names.items()}
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  ### 构建 examples
36
 
37
  def build_examples():
@@ -58,6 +103,7 @@ with gr.Blocks(css="""
58
  ## 棋盘检测, 棋子识别
59
 
60
  features: 轻量化模型
 
61
 
62
  x 表示 有遮挡位置
63
  . 表示 棋盘上的普通交叉点
@@ -65,6 +111,9 @@ with gr.Blocks(css="""
65
  步骤:
66
  1. 流程分成两步,第一步 keypoints 检测
67
  2. 拉伸棋盘,并预测棋子
 
 
 
68
  """
69
  )
70
  with gr.Row():
@@ -95,16 +144,60 @@ with gr.Blocks(css="""
95
  interactive=False,
96
  visible=True,
97
  )
98
- layout_pred_info = gr.Dataframe(
99
- label="棋子识别",
100
- interactive=False,
101
- visible=True,
102
- )
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  with gr.Row():
105
  with gr.Column():
106
  gr.Examples(
107
- full_examples, inputs=[image_input], label="示例图片", examples_per_page=15,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
 
110
  def detect_chessboard(image):
@@ -118,21 +211,23 @@ with gr.Blocks(css="""
118
  # cells_labels 通过 \n 分割
119
  annotation_10_rows = [item for item in cells_labels_str.split("\n")]
120
  # 将 annotation_10_rows 转换成为 10 行 9 列的二维数组
121
- annotation_arr_10_9 = [list(item) for item in annotation_10_rows]
122
 
123
  # 将 棋子类别 转换为 中文
124
- annotation_arr_10_9 = [[dict_cate_names_reverse[item] for item in row] for row in annotation_arr_10_9]
125
 
126
  except Exception as e:
127
  gr.Warning(f"检测失败 图片或者视频布局错误")
128
  return None, None, None, None
129
 
130
 
131
- return original_image_with_keypoints, transformed_image, annotation_arr_10_9, time_info
 
 
132
 
133
  image_input.change(fn=detect_chessboard,
134
  inputs=[image_input],
135
- outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, use_time])
136
 
137
  if __name__ == "__main__":
138
  demo.launch()
 
1
  import gradio as gr
2
  # import cv2
3
  import os
4
+ import base64
5
+ from pathlib import Path
6
+
7
  from core.chessboard_detector import ChessboardDetector
8
 
9
  detector = ChessboardDetector(
10
+ pose_model_path="onnx/pose/4_v6-0301.onnx",
11
+ full_classifier_model_path="onnx/layout_recognition/nano_v3-0319.onnx"
12
  )
13
 
14
  # 数据集路径
 
32
  '黑卒': 'p',
33
  }
34
 
35
+ # 数据集路径
36
+ dict_cate_images = {
37
+ 'K': 'red_K.png',
38
+ 'A': 'red_A.png',
39
+ 'B': 'red_B.png',
40
+ 'N': 'red_N.png',
41
+ 'R': 'red_R.png',
42
+ 'C': 'red_C.png',
43
+ 'P': 'red_P.png',
44
+ 'k': 'black_k.png',
45
+ 'a': 'black_a.png',
46
+ 'b': 'black_b.png',
47
+ 'n': 'black_n.png',
48
+ 'r': 'black_r.png',
49
+ 'c': 'black_c.png',
50
+ 'p': 'black_p.png',
51
+ }
52
+
53
  dict_cate_names_reverse = {v: k for k, v in dict_cate_names.items()}
54
 
55
 
56
+ # 缓存图片的 base64 编码
57
+ image_base64_cache = {}
58
+
59
+ def get_image_base64(img_path):
60
+ if img_path in image_base64_cache:
61
+ return image_base64_cache[img_path]
62
+
63
+ try:
64
+ img_full_path = Path("resources") / img_path
65
+ if not img_full_path.exists():
66
+ return ""
67
+
68
+ with open(img_full_path, "rb") as img_file:
69
+ encoded = base64.b64encode(img_file.read()).decode('utf-8')
70
+ data_url = f"data:image/png;base64,{encoded}"
71
+ image_base64_cache[img_path] = data_url
72
+ return data_url
73
+ except Exception as e:
74
+ print(f"Error loading image {img_path}: {e}")
75
+ return
76
+
77
+
78
+
79
+
80
  ### 构建 examples
81
 
82
  def build_examples():
 
103
  ## 棋盘检测, 棋子识别
104
 
105
  features: 轻量化模型
106
+
107
 
108
  x 表示 有遮挡位置
109
  . 表示 棋盘上的普通交叉点
 
111
  步骤:
112
  1. 流程分成两步,第一步 keypoints 检测
113
  2. 拉伸棋盘,并预测棋子
114
+
115
+ log:
116
+ 1. 优化棋子识别,增加对游戏棋盘的识别
117
  """
118
  )
119
  with gr.Row():
 
144
  interactive=False,
145
  visible=True,
146
  )
147
+
148
+ # 添加 手风琴
149
+ with gr.Accordion("文字识别", open=False):
150
+ layout_pred_info = gr.Dataframe(
151
+ label="棋子识别",
152
+ interactive=False,
153
+ visible=True,
154
+ )
155
+
156
+ with gr.Accordion("棋子识别", open=True):
157
+ # 10 行 9 列的表格
158
+ table_html = gr.HTML(
159
+ """
160
+ <table>
161
+ </table>
162
+ """
163
+ )
164
 
165
  with gr.Row():
166
  with gr.Column():
167
  gr.Examples(
168
+ full_examples[:10], inputs=[image_input], label="示例图片1", examples_per_page=10,)
169
+
170
+ gr.Examples(
171
+ full_examples[10:20], inputs=[image_input], label="示例图片2", examples_per_page=10,)
172
+
173
+ gr.Examples(
174
+ full_examples[20:], inputs=[image_input], label="示例图片3", examples_per_page=10,)
175
+
176
+
177
+ def gen_table_html(annotation_arr_10_9):
178
+ # 生成表格 HTML
179
+ html = "<table border='1' style='margin: auto;'>"
180
+
181
+ for row in annotation_arr_10_9:
182
+ html += "<tr>"
183
+ for cell in row:
184
+ if cell == '.':
185
+ # 普通交叉点
186
+ html += "<td style='width: 60px; height: 60px; text-align: center;'></td>"
187
+ elif cell == 'x':
188
+ # 遮挡位置
189
+ html += "<td style='width: 60px; height: 60px; text-align: center;'>x</td>"
190
+ else:
191
+ # 获取对应的图片文件名
192
+ img_file = dict_cate_images.get(cell, '')
193
+ img_data_base64 = get_image_base64(img_file)
194
+ # 生成图片标签
195
+ html += f"<td style='width: 60px; height: 60px; text-align: center; padding: 0;'><img src='{img_data_base64}' width='58' height='58'></td>"
196
+ html += "</tr>"
197
+
198
+ html += "</table>"
199
+ return html
200
+
201
 
202
 
203
  def detect_chessboard(image):
 
211
  # cells_labels 通过 \n 分割
212
  annotation_10_rows = [item for item in cells_labels_str.split("\n")]
213
  # 将 annotation_10_rows 转换成为 10 行 9 列的二维数组
214
+ annotation_arr_10_9_short = [list(item) for item in annotation_10_rows]
215
 
216
  # 将 棋子类别 转换为 中文
217
+ annotation_arr_10_9 = [[dict_cate_names_reverse[item] for item in row] for row in annotation_arr_10_9_short]
218
 
219
  except Exception as e:
220
  gr.Warning(f"检测失败 图片或者视频布局错误")
221
  return None, None, None, None
222
 
223
 
224
+ table_html = gen_table_html(annotation_arr_10_9_short)
225
+
226
+ return original_image_with_keypoints, transformed_image, annotation_arr_10_9, table_html, time_info
227
 
228
  image_input.change(fn=detect_chessboard,
229
  inputs=[image_input],
230
+ outputs=[original_image_with_keypoints, transformed_image, layout_pred_info, table_html, use_time])
231
 
232
  if __name__ == "__main__":
233
  demo.launch()
onnx/layout_recognition/nano_v3-0319.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da66ba9809f15127f8ae729b1755e42ee61c100c4f9979ce0ef13602ac471298
3
+ size 31101356
resources/black_a.png ADDED
resources/black_b.png ADDED
resources/black_c.png ADDED
resources/black_k.png ADDED
resources/black_n.png ADDED
resources/black_p.png ADDED
resources/black_r.png ADDED
resources/red_A.png ADDED
resources/red_B.png ADDED
resources/red_C.png ADDED
resources/red_K.png ADDED
resources/red_N.png ADDED
resources/red_P.png ADDED
resources/red_R.png ADDED