SWHL commited on
Commit
fab3f7c
·
verified ·
1 Parent(s): f6bf2a8

Rename app_gradio.py to app.py

Browse files
Files changed (1) hide show
  1. app_gradio.py → app.py +215 -18
app_gradio.py → app.py RENAMED
@@ -1,24 +1,176 @@
1
  # -*- encoding: utf-8 -*-
2
  # @Author: SWHL
3
  # @Contact: [email protected]
 
 
 
 
4
  import gradio as gr
5
- from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- img_path = "images/1.jpg"
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def test():
11
- return Image.open(img_path)
12
 
 
13
 
14
- example_images = [
15
- "images/1.jpg",
16
- "images/ch_en_num.jpg",
17
- "images/air_ticket.jpg",
18
- "images/car_plate.jpeg",
19
- "images/train_ticket.jpeg",
20
- "images/japan_2.jpg",
21
- "images/korean_1.jpg",
 
 
 
 
 
 
 
22
  ]
23
 
24
  custom_css = """
@@ -76,17 +228,62 @@ with gr.Blocks(
76
  info="控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0],默认值为1.6",
77
  )
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  img_input = gr.Image(label="Upload or Select Image", sources="upload")
 
80
  run_btn = gr.Button("Run")
81
 
82
- run_btn.click(test, inputs=img_input, outputs=gr.Image())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  examples = gr.Examples(
85
- examples=example_images,
86
- examples_per_page=len(example_images),
87
- inputs=img_input,
88
- fn=lambda x: x, # 简单返回图片路径
89
- outputs=img_input,
90
  cache_examples=False,
91
  )
92
 
 
1
  # -*- encoding: utf-8 -*-
2
  # @Author: SWHL
3
  # @Contact: [email protected]
4
+ from enum import Enum
5
+ from pathlib import Path
6
+ from typing import List, Union
7
+
8
  import gradio as gr
9
+ import numpy as np
10
+ from rapidocr import RapidOCR
11
+
12
+
13
+ class InferEngine(Enum):
14
+ ort = "ONNXRuntime"
15
+ vino = "OpenVino"
16
+ paddle = "PaddlePaddle"
17
+ torch = "PyTorch"
18
+
19
+
20
+ def get_ocr_engine(infer_engine: str, lang_det: str, lang_rec: str) -> RapidOCR:
21
+ if infer_engine == InferEngine.vino.value:
22
+ return RapidOCR(
23
+ params={
24
+ "Global.with_openvino": True,
25
+ "Global.lang_det": lang_det,
26
+ "Global.lang_rec": lang_rec,
27
+ }
28
+ )
29
+
30
+ if infer_engine == InferEngine.paddle.value:
31
+ return RapidOCR(
32
+ params={
33
+ "Global.with_paddle": True,
34
+ "Global.lang_det": lang_det,
35
+ "Global.lang_rec": lang_rec,
36
+ }
37
+ )
38
+
39
+ if infer_engine == InferEngine.torch.value:
40
+ return RapidOCR(
41
+ params={
42
+ "Global.with_torch": True,
43
+ "Global.lang_det": lang_det,
44
+ "Global.lang_rec": lang_rec,
45
+ }
46
+ )
47
+
48
+ return RapidOCR(
49
+ params={
50
+ "Global.with_onnx": True,
51
+ "Global.lang_det": lang_det,
52
+ "Global.lang_rec": lang_rec,
53
+ }
54
+ )
55
+
56
+
57
+ def get_ocr_result(
58
+ img: np.ndarray,
59
+ text_score,
60
+ box_thresh,
61
+ unclip_ratio,
62
+ lang_det,
63
+ lang_rec,
64
+ infer_engine,
65
+ is_word: str,
66
+ ):
67
+ return_word_box = True if is_word == "Yes" else False
68
 
69
+ ocr_engine = get_ocr_engine(infer_engine, lang_det=lang_det, lang_rec=lang_rec)
70
 
71
+ ocr_result = ocr_engine(
72
+ img,
73
+ text_score=text_score,
74
+ box_thresh=box_thresh,
75
+ unclip_ratio=unclip_ratio,
76
+ return_word_box=return_word_box,
77
+ )
78
+ vis_img = ocr_result.vis()
79
+ if return_word_box:
80
+ txts, scores, _ = list(zip(*ocr_result.word_results))
81
+ ocr_txts = [[i, txt, score] for i, (txt, score) in enumerate(zip(txts, scores))]
82
+ return vis_img, ocr_txts, ocr_result.elapse
83
+
84
+ ocr_txts = [
85
+ [i, txt, score]
86
+ for i, (txt, score) in enumerate(zip(ocr_result.txts, ocr_result.scores))
87
+ ]
88
+ return vis_img, ocr_txts, ocr_result.elapse
89
+
90
+
91
+ def create_examples() -> List[List[Union[str, float]]]:
92
+ examples = [
93
+ [
94
+ "images/ch_en_num.jpg",
95
+ 0.5,
96
+ 0.5,
97
+ 1.6,
98
+ "ch_mobile",
99
+ "ch_mobile",
100
+ "ONNXRuntime",
101
+ "No",
102
+ ],
103
+ [
104
+ "images/japan.jpg",
105
+ 0.5,
106
+ 0.5,
107
+ 1.6,
108
+ "multi_mobile",
109
+ "japan_mobile",
110
+ "ONNXRuntime",
111
+ "No",
112
+ ],
113
+ [
114
+ "images/korean.jpg",
115
+ 0.5,
116
+ 0.5,
117
+ 1.6,
118
+ "multi_mobile",
119
+ "korean_mobile",
120
+ "ONNXRuntime",
121
+ "No",
122
+ ],
123
+ [
124
+ "images/air_ticket.jpg",
125
+ 0.5,
126
+ 0.5,
127
+ 1.6,
128
+ "ch_mobile",
129
+ "ch_mobile",
130
+ "ONNXRuntime",
131
+ "No",
132
+ ],
133
+ [
134
+ "images/car_plate.jpeg",
135
+ 0.5,
136
+ 0.5,
137
+ 1.6,
138
+ "ch_mobile",
139
+ "ch_mobile",
140
+ "ONNXRuntime",
141
+ "No",
142
+ ],
143
+ [
144
+ "images/train_ticket.jpeg",
145
+ 0.5,
146
+ 0.5,
147
+ 1.6,
148
+ "ch_mobile",
149
+ "ch_mobile",
150
+ "ONNXRuntime",
151
+ "No",
152
+ ],
153
+ ]
154
+ return examples
155
 
 
 
156
 
157
+ infer_engine_list = [InferEngine[v].value for v in InferEngine.__members__]
158
 
159
+ lang_det_list = ["ch_mobile", "ch_server", "en_mobile", "en_server", "multi_mobile"]
160
+ lang_rec_list = [
161
+ "ch_mobile",
162
+ "ch_server",
163
+ "chinese_cht",
164
+ "en_mobile",
165
+ "ar_mobile",
166
+ "cyrillic_mobile",
167
+ "devanagari_mobile",
168
+ "japan_mobile",
169
+ "ka_mobile",
170
+ "korean_mobile",
171
+ "latin_mobile",
172
+ "ta_mobile",
173
+ "te_mobile",
174
  ]
175
 
176
  custom_css = """
 
228
  info="控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0],默认值为1.6",
229
  )
230
 
231
+ with gr.Row():
232
+ select_infer_engine = gr.Dropdown(
233
+ choices=infer_engine_list,
234
+ label="Infer Engine (推理引擎)",
235
+ value="ONNXRuntime",
236
+ interactive=True,
237
+ )
238
+ lang_det = gr.Dropdown(
239
+ choices=lang_det_list,
240
+ label="Det model (文本检测模型)",
241
+ value=lang_det_list[0],
242
+ interactive=True,
243
+ )
244
+ lang_rec = gr.Dropdown(
245
+ choices=lang_rec_list,
246
+ label="Rec model (文本识别模型)",
247
+ value=lang_rec_list[0],
248
+ interactive=True,
249
+ )
250
+ is_word = gr.Radio(
251
+ ["Yes", "No"], label="Return word box (返回单字符)", value="No"
252
+ )
253
+
254
  img_input = gr.Image(label="Upload or Select Image", sources="upload")
255
+
256
  run_btn = gr.Button("Run")
257
 
258
+ img_output = gr.Image(label="Output Image")
259
+ elapse = gr.Textbox(label="Elapse(s)")
260
+ ocr_results = gr.Dataframe(
261
+ label="OCR Txts",
262
+ headers=["Index", "Txt", "Score"],
263
+ datatype=["number", "str", "number"],
264
+ show_copy_button=True,
265
+ )
266
+
267
+ ocr_inputs = [
268
+ img_input,
269
+ text_score,
270
+ box_thresh,
271
+ unclip_ratio,
272
+ lang_det,
273
+ lang_rec,
274
+ select_infer_engine,
275
+ is_word,
276
+ ]
277
+ run_btn.click(
278
+ get_ocr_result, inputs=ocr_inputs, outputs=[img_output, ocr_results, elapse]
279
+ )
280
 
281
  examples = gr.Examples(
282
+ examples=create_examples(),
283
+ examples_per_page=5,
284
+ inputs=ocr_inputs,
285
+ fn=get_ocr_result,
286
+ outputs=[img_output, ocr_results, elapse],
287
  cache_examples=False,
288
  )
289