# RapidOCR Gradio demo (Hugging Face Spaces app)
# -*- encoding: utf-8 -*-
import json
import time
from pathlib import Path

import cv2
import gradio as gr
from rapidocr_onnxruntime import RapidOCR

from utils import visualize
# Visualization font file (under fonts/) for each recognition language.
# English reuses the Chinese font.
font_dict = dict(
    ch='FZYTK.TTF',
    japan='japan.ttc',
    korean='korean.ttf',
    en='FZYTK.TTF',
)
def inference(img_path, box_thresh=0.5, unclip_ratio=1.6, text_score=0.5,
              text_det=None, text_rec=None):
    """Run RapidOCR on one image and return the visualized result.

    Args:
        img_path (str): path of the input image.
        box_thresh (float): probability threshold that a detected box is
            text; range [0, 1.0].
        unclip_ratio (float): expansion ratio of detected text boxes;
            range [1.5, 2.0].
        text_score (float): confidence threshold for recognition results;
            range [0, 1.0].
        text_det (str | None): detection model file name under
            models/text_det. None falls back to the UI default.
        text_rec (str | None): recognition model file name under
            models/text_rec. None falls back to the UI default.

    Returns:
        tuple: (path of the result image, recognition dict — or a message
        string when nothing is recognized — and the run log string).
    """
    # Fix: the original defaults of None crashed with TypeError in
    # Path(...) / text_det. Fall back to the same defaults the UI uses.
    if text_det is None:
        text_det = 'ch_PP-OCRv3_det_infer.onnx'
    if text_rec is None:
        text_rec = 'ch_PP-OCRv3_rec_infer.onnx'

    out_log_list = ['Init Model']

    det_model_path = str(Path('models') / 'text_det' / text_det)
    rec_model_path = str(Path('models') / 'text_rec' / text_rec)

    # v2 recognition models use a 32-px input height, v3 models 48 px.
    if 'v2' in rec_model_path:
        rec_image_shape = [3, 32, 320]
    else:
        rec_image_shape = [3, 48, 320]

    s = time.time()
    rapid_ocr = RapidOCR(det_model_path=det_model_path,
                         rec_model_path=rec_model_path,
                         rec_img_shape=rec_image_shape)
    elapse = time.time() - s

    # Pick the visualization font from the model file name; 'en' shares
    # the Chinese font (see font_dict), and anything unknown defaults to it.
    if 'ch' in rec_model_path or 'en' in rec_model_path:
        lan_name = 'ch'
    elif 'japan' in rec_model_path:
        lan_name = 'japan'
    elif 'korean' in rec_model_path:
        lan_name = 'korean'
    else:
        lan_name = 'ch'

    out_log_list.append(f'Init Model cost: {elapse:.5f}')
    out_log_list.extend([f'det_model: {det_model_path}',
                         f'rec_model: {rec_model_path}',
                         f'rec_image_shape: {rec_image_shape}'])

    img = cv2.imread(img_path)
    ocr_result, infer_elapse = rapid_ocr(img, box_thresh=box_thresh,
                                         unclip_ratio=unclip_ratio,
                                         text_score=text_score)
    det_cost, cls_cost, rec_cost = infer_elapse
    out_log_list.extend([f'det cost: {det_cost:.5f}',
                         f'cls cost: {cls_cost:.5f}',
                         f'rec cost: {rec_cost:.5f}'])
    out_log = '\n'.join(str(v) for v in out_log_list)

    if not ocr_result:
        # Show the unmodified input image alongside the "no text" message.
        return img_path, '未识别到有效文本', out_log

    dt_boxes, rec_res, scores = list(zip(*ocr_result))
    # Direct indexing instead of .get(): lan_name is always a valid key,
    # and a KeyError here would be clearer than Path(...) / None below.
    font_path = Path('fonts') / font_dict[lan_name]
    img_save_path = visualize(img_path, dt_boxes, rec_res, scores,
                              font_path=str(font_path))
    out_dict = {str(i): {'rec_txt': rec, 'score': score}
                for i, (rec, score) in enumerate(zip(rec_res, scores))}
    return img_save_path, out_dict, out_log
if __name__ == '__main__':
    # Example images shown below the UI; each entry fills the inputs.
    examples = [['images/1.jpg'],
                ['images/ch_en_num.jpg'],
                ['images/air_ticket.jpg'],
                ['images/car_plate.jpeg'],
                ['images/idcard.jpg'],
                ['images/train_ticket.jpeg'],
                ['images/japan_2.jpg'],
                ['images/korean_1.jpg']]

    with gr.Blocks(title='RapidOCR') as demo:
        gr.Markdown("""
<h1><center><a href="https://github.com/RapidAI/RapidOCR" target="_blank">Rapid⚡OCR</a></center></h1>

### Docs: [Docs](https://rapidocr.rtfd.io/)
### 运行环境:
Python: 3.8 | onnxruntime: 1.14.1 | rapidocr_onnxruntime: 1.2.5""")
        gr.Markdown(
            '''**[超参数调节](https://github.com/RapidAI/RapidOCR/tree/main/python#configyaml%E4%B8%AD%E5%B8%B8%E7%94%A8%E5%8F%82%E6%95%B0%E4%BB%8B%E7%BB%8D)**
- **box_thresh**: 检测到的框是文本的概率,值越大,框中是文本的概率就越大。存在漏检时,调低该值。取值范围:[0, 1.0]
- **unclip_ratio**: 控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0]
- **text_score**: 文本识别结果是正确的置信度,值越大,显示出的识别结果更准确。存在漏检时,调低该值。取值范围:[0, 1.0]
''')

        # Hyper-parameter sliders; defaults mirror inference()'s defaults.
        with gr.Row():
            box_thresh = gr.Slider(minimum=0, maximum=1.0, value=0.5,
                                   label='box_thresh', step=0.1,
                                   interactive=True,
                                   info='[0, 1.0]')
            unclip_ratio = gr.Slider(minimum=1.5, maximum=2.0, value=1.6,
                                     label='unclip_ratio', step=0.1,
                                     interactive=True,
                                     info='[1.5, 2.0]')
            text_score = gr.Slider(minimum=0, maximum=1.0, value=0.5,
                                   label='text_score', step=0.1,
                                   interactive=True,
                                   info='[0, 1.0]')

        gr.Markdown('**[模型选择](https://github.com/RapidAI/RapidOCR/blob/main/docs/models.md)** (模型转换→[PaddleOCRModelConverter](https://github.com/RapidAI/PaddleOCRModelConverter))')
        with gr.Row():
            text_det = gr.Dropdown(['ch_PP-OCRv3_det_infer.onnx',
                                    'ch_PP-OCRv2_det_infer.onnx',
                                    'ch_ppocr_server_v2.0_det_infer.onnx'],
                                   label='选择文本检测模型',
                                   value='ch_PP-OCRv3_det_infer.onnx',
                                   interactive=True)

            # Recognition choices are discovered from disk so newly added
            # models show up without a code change.
            rec_model_list = [v.name for v in Path('models/text_rec').iterdir()]
            text_rec = gr.Dropdown(rec_model_list,
                                   label='选择文本识别模型(包括中英文和多语言,欢迎提交更多模型)',
                                   value='ch_PP-OCRv3_rec_infer.onnx',
                                   interactive=True)

        with gr.Row():
            input_img = gr.Image(type='filepath', label='Input')
            out_img = gr.Image(type='filepath', label='Output')

        out_json = gr.JSON(label='Rec Res')
        # Fix: gr.outputs.Textbox is the deprecated pre-Blocks API (removed
        # in Gradio 4); use the component class like the rest of this file.
        out_log = gr.Textbox(label='Run Log')

        button = gr.Button('Submit')
        button.click(fn=inference,
                     inputs=[input_img, box_thresh, unclip_ratio, text_score,
                             text_det, text_rec],
                     outputs=[out_img, out_json, out_log])
        gr.Examples(examples=examples,
                    inputs=[input_img, box_thresh, unclip_ratio, text_score,
                            text_det, text_rec],
                    outputs=[out_img, out_json, out_log], fn=inference)

    # NOTE(review): launch(enable_queue=...) is deprecated in Gradio 3.x in
    # favor of demo.queue(); kept as-is to preserve current behavior.
    demo.launch(debug=True, enable_queue=True)