SWHL commited on
Commit
dadc86d
·
verified ·
1 Parent(s): 3c23b0b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +292 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- encoding: utf-8 -*-
2
+ # @Author: SWHL
3
+ # @Contact: [email protected]
4
+ from enum import Enum
5
+ from pathlib import Path
6
+ from typing import List, Union
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ from rapidocr import RapidOCR
11
+
12
+
13
+ class InferEngine(Enum):
14
+ ort = "ONNXRuntime"
15
+ vino = "OpenVino"
16
+ paddle = "PaddlePaddle"
17
+ torch = "PyTorch"
18
+
19
+
20
+ def get_ocr_engine(infer_engine: str, lang_det: str, lang_rec: str) -> RapidOCR:
21
+ if infer_engine == InferEngine.vino.value:
22
+ return RapidOCR(
23
+ params={
24
+ "Global.with_openvino": True,
25
+ "Global.lang_det": lang_det,
26
+ "Global.lang_rec": lang_rec,
27
+ }
28
+ )
29
+
30
+ if infer_engine == InferEngine.paddle.value:
31
+ return RapidOCR(
32
+ params={
33
+ "Global.with_paddle": True,
34
+ "Global.lang_det": lang_det,
35
+ "Global.lang_rec": lang_rec,
36
+ }
37
+ )
38
+
39
+ if infer_engine == InferEngine.torch.value:
40
+ return RapidOCR(
41
+ params={
42
+ "Global.with_torch": True,
43
+ "Global.lang_det": lang_det,
44
+ "Global.lang_rec": lang_rec,
45
+ }
46
+ )
47
+
48
+ return RapidOCR(
49
+ params={
50
+ "Global.with_onnx": True,
51
+ "Global.lang_det": lang_det,
52
+ "Global.lang_rec": lang_rec,
53
+ }
54
+ )
55
+
56
+
57
+ def get_ocr_result(
58
+ img: np.ndarray,
59
+ text_score,
60
+ box_thresh,
61
+ unclip_ratio,
62
+ lang_det,
63
+ lang_rec,
64
+ infer_engine,
65
+ is_word: str,
66
+ ):
67
+ return_word_box = True if is_word == "Yes" else False
68
+
69
+ ocr_engine = get_ocr_engine(infer_engine, lang_det=lang_det, lang_rec=lang_rec)
70
+
71
+ ocr_result = ocr_engine(
72
+ img,
73
+ text_score=text_score,
74
+ box_thresh=box_thresh,
75
+ unclip_ratio=unclip_ratio,
76
+ return_word_box=return_word_box,
77
+ )
78
+ vis_img = ocr_result.vis()
79
+ if return_word_box:
80
+ txts, scores, _ = list(zip(*ocr_result.word_results))
81
+ ocr_txts = [[i, txt, score] for i, (txt, score) in enumerate(zip(txts, scores))]
82
+ return vis_img, ocr_txts, ocr_result.elapse
83
+
84
+ ocr_txts = [
85
+ [i, txt, score]
86
+ for i, (txt, score) in enumerate(zip(ocr_result.txts, ocr_result.scores))
87
+ ]
88
+ return vis_img, ocr_txts, ocr_result.elapse
89
+
90
+
91
+ def create_examples() -> List[List[Union[str, float]]]:
92
+ examples = [
93
+ [
94
+ "images/ch_en_num.jpg",
95
+ 0.5,
96
+ 0.5,
97
+ 1.6,
98
+ "ch_mobile",
99
+ "ch_mobile",
100
+ "ONNXRuntime",
101
+ "No",
102
+ ],
103
+ [
104
+ "images/japan.jpg",
105
+ 0.5,
106
+ 0.5,
107
+ 1.6,
108
+ "multi_mobile",
109
+ "japan_mobile",
110
+ "ONNXRuntime",
111
+ "No",
112
+ ],
113
+ [
114
+ "images/korean.jpg",
115
+ 0.5,
116
+ 0.5,
117
+ 1.6,
118
+ "multi_mobile",
119
+ "korean_mobile",
120
+ "ONNXRuntime",
121
+ "No",
122
+ ],
123
+ [
124
+ "images/air_ticket.jpg",
125
+ 0.5,
126
+ 0.5,
127
+ 1.6,
128
+ "ch_mobile",
129
+ "ch_mobile",
130
+ "ONNXRuntime",
131
+ "No",
132
+ ],
133
+ [
134
+ "images/car_plate.jpeg",
135
+ 0.5,
136
+ 0.5,
137
+ 1.6,
138
+ "ch_mobile",
139
+ "ch_mobile",
140
+ "ONNXRuntime",
141
+ "No",
142
+ ],
143
+ [
144
+ "images/train_ticket.jpeg",
145
+ 0.5,
146
+ 0.5,
147
+ 1.6,
148
+ "ch_mobile",
149
+ "ch_mobile",
150
+ "ONNXRuntime",
151
+ "No",
152
+ ],
153
+ ]
154
+ return examples
155
+
156
+
157
+ infer_engine_list = [InferEngine[v].value for v in InferEngine.__members__]
158
+
159
+ lang_det_list = ["ch_mobile", "ch_server", "en_mobile", "en_server", "multi_mobile"]
160
+ lang_rec_list = [
161
+ "ch_mobile",
162
+ "ch_server",
163
+ "chinese_cht",
164
+ "en_mobile",
165
+ "ar_mobile",
166
+ "cyrillic_mobile",
167
+ "devanagari_mobile",
168
+ "japan_mobile",
169
+ "ka_mobile",
170
+ "korean_mobile",
171
+ "latin_mobile",
172
+ "ta_mobile",
173
+ "te_mobile",
174
+ ]
175
+
176
+ custom_css = """
177
+ body {font-family: body {font-family: 'Helvetica Neue', Helvetica;}
178
+ .gr-button {background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 5px;}
179
+ .gr-button:hover {background-color: #45a049;}
180
+ .gr-textbox {margin-bottom: 15px;}
181
+ .example-button {background-color: #1E90FF; color: white; border: none; padding: 8px 15px; border-radius: 5px; margin: 5px;}
182
+ .example-button:hover {background-color: #FF4500;}
183
+ .tall-radio .gr-radio-item {padding: 15px 0; min-height: 50px; display: flex; align-items: center;}
184
+ .tall-radio label {font-size: 16px;}
185
+ .output-image, .input-image, .image-preview {height: 300px !important}
186
+ """
187
+
188
+ with gr.Blocks(
189
+ title="Rapid⚡OCR Demo", css="custom_css", theme=gr.themes.Soft()
190
+ ) as demo:
191
+ gr.Markdown(
192
+ "<h1 style='text-align: center;'><a href='https://rapidai.github.io/RapidOCRDocs/' style='text-decoration: none;'>Rapid⚡OCR</a></h1>"
193
+ )
194
+ gr.HTML(
195
+ """
196
+ <div style="display: flex; justify-content: center; gap: 10px;">
197
+ <a href=""><img src="https://img.shields.io/badge/Python->=3.6-aff.svg"></a>
198
+ <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
199
+ <a href="https://pepy.tech/project/rapidocr"><img src="https://static.pepy.tech/personalized-badge/rapidocr?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20rapidocr"></a>
200
+ <a href="https://pypi.org/project/rapidocr/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapidocr"></a>
201
+ <a href="https://github.com/RapidAI/RapidOCR"><img src="https://img.shields.io/github/stars/RapidAI/RapidOCR?color=ccf"></a>
202
+ </div>
203
+ """
204
+ )
205
+ with gr.Row():
206
+ text_score = gr.Slider(
207
+ label="text_score",
208
+ minimum=0,
209
+ maximum=1.0,
210
+ value=0.5,
211
+ step=0.1,
212
+ info="文本识别结果是正确的置信度,值越大,显示出的识别结果更准确。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
213
+ )
214
+ box_thresh = gr.Slider(
215
+ label="box_thresh",
216
+ minimum=0,
217
+ maximum=1.0,
218
+ value=0.5,
219
+ step=0.1,
220
+ info="检测到的框是文本的概率,值越大,框中是文本的概率就越大。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
221
+ )
222
+ unclip_ratio = gr.Slider(
223
+ label="unclip_ratio",
224
+ minimum=1.5,
225
+ maximum=2.0,
226
+ value=1.6,
227
+ step=0.1,
228
+ info="控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0],默认值为1.6",
229
+ )
230
+
231
+ with gr.Row():
232
+ select_infer_engine = gr.Dropdown(
233
+ choices=infer_engine_list,
234
+ label="Infer Engine (推理引擎)",
235
+ value="ONNXRuntime",
236
+ interactive=True,
237
+ )
238
+ lang_det = gr.Dropdown(
239
+ choices=lang_det_list,
240
+ label="Det model (文本检测模型)",
241
+ value=lang_det_list[0],
242
+ interactive=True,
243
+ )
244
+ lang_rec = gr.Dropdown(
245
+ choices=lang_rec_list,
246
+ label="Rec model (文本识别模型)",
247
+ value=lang_rec_list[0],
248
+ interactive=True,
249
+ )
250
+ is_word = gr.Radio(
251
+ ["Yes", "No"], label="Return word box (返回单字符)", value="No"
252
+ )
253
+
254
+ img_input = gr.Image(label="Upload or Select Image", sources="upload")
255
+
256
+ run_btn = gr.Button("Run")
257
+
258
+ img_output = gr.Image(label="Output Image")
259
+ elapse = gr.Textbox(label="Elapse(s)")
260
+ ocr_results = gr.Dataframe(
261
+ label="OCR Txts",
262
+ headers=["Index", "Txt", "Score"],
263
+ datatype=["number", "str", "number"],
264
+ show_copy_button=True,
265
+ )
266
+
267
+ ocr_inputs = [
268
+ img_input,
269
+ text_score,
270
+ box_thresh,
271
+ unclip_ratio,
272
+ lang_det,
273
+ lang_rec,
274
+ select_infer_engine,
275
+ is_word,
276
+ ]
277
+ run_btn.click(
278
+ get_ocr_result, inputs=ocr_inputs, outputs=[img_output, ocr_results, elapse]
279
+ )
280
+
281
+ examples = gr.Examples(
282
+ examples=create_examples(),
283
+ examples_per_page=5,
284
+ inputs=ocr_inputs,
285
+ fn=get_ocr_result,
286
+ outputs=[img_output, ocr_results, elapse],
287
+ cache_examples=False,
288
+ )
289
+
290
+
291
+ if __name__ == "__main__":
292
+ demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ rapidocr
2
+ torch
3
+ onnxruntime
4
+ paddlepaddle
5
+ openvino