# RapidOCRv2 / app.py
# Last commit: 9d23789 "Update app.py" by SWHL (verified)
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import List, Union

import gradio as gr
import numpy as np
from rapidocr import RapidOCR
class InferEngine(Enum):
    """Inference back-ends the demo can run RapidOCR on.

    Member values double as the labels shown in the "Infer Engine" dropdown,
    and are what ``get_ocr_engine`` receives back from the UI.
    """

    ort = "ONNXRuntime"  # the dropdown's initial selection
    vino = "OpenVino"
    paddle = "PaddlePaddle"
    torch = "PyTorch"
# Upper bound on how many distinct (back-end, det model, rec model)
# configurations stay alive at once; each cached RapidOCR instance presumably
# keeps its loaded models in memory, so keep this small.
_OCR_ENGINE_CACHE_SIZE = 4


@lru_cache(maxsize=_OCR_ENGINE_CACHE_SIZE)
def _build_ocr_engine(param_key: str, lang_det: str, lang_rec: str) -> RapidOCR:
    """Construct (and memoise) one RapidOCR instance per configuration."""
    return RapidOCR(
        params={
            f"Global.{param_key}": True,
            "Global.lang_det": lang_det,
            "Global.lang_rec": lang_rec,
        }
    )


def get_ocr_engine(infer_engine: str, lang_det: str, lang_rec: str) -> RapidOCR:
    """Return a RapidOCR engine for the selected back-end and models.

    Engines are cached per configuration, so repeated requests with the same
    dropdown selections reuse one instance instead of rebuilding it (and
    presumably reloading its model files) on every click.

    Args:
        infer_engine: Display value of an ``InferEngine`` member, e.g.
            "OpenVino". Unrecognised values fall back to ONNXRuntime.
        lang_det: Text-detection model name (one of ``lang_det_list``).
        lang_rec: Text-recognition model name (one of ``lang_rec_list``).

    Returns:
        A ``RapidOCR`` instance, possibly shared with earlier calls.
    """
    # NOTE(review): the cached instance is shared across requests. This relies
    # on per-call options (text_score, box_thresh, unclip_ratio,
    # return_word_box) always being passed explicitly, as get_ocr_result does,
    # and on requests not running concurrently on one instance — confirm
    # before raising Gradio's concurrency limit.
    engine_mapping = {
        InferEngine.ort.value: "with_onnx",
        InferEngine.vino.value: "with_openvino",
        InferEngine.paddle.value: "with_paddle",
        InferEngine.torch.value: "with_torch",
    }
    param_key = engine_mapping.get(infer_engine, "with_onnx")
    # Key the cache on the resolved flag so aliases of ONNXRuntime share an entry.
    return _build_ocr_engine(param_key, lang_det, lang_rec)
def _to_table_rows(txts, scores) -> List[List[Union[int, str, float]]]:
    """Pair texts with scores as ``[index, text, score]`` rows for the table.

    Returns an empty list when either sequence is ``None`` (nothing recognised).
    """
    if txts is None or scores is None:
        return []
    return [[i, txt, score] for i, (txt, score) in enumerate(zip(txts, scores))]


def get_ocr_result(
    img: np.ndarray,
    text_score,
    box_thresh,
    unclip_ratio,
    lang_det,
    lang_rec,
    infer_engine,
    is_word: str,
):
    """Run OCR on one image with the settings chosen in the UI.

    Args:
        img: Image from the ``gr.Image`` input; ``None`` if nothing was provided.
        text_score: Recognition-confidence threshold forwarded to RapidOCR.
        box_thresh: Detection-box threshold forwarded to RapidOCR.
        unclip_ratio: Detection-box expansion ratio forwarded to RapidOCR.
        lang_det: Text-detection model name.
        lang_rec: Text-recognition model name.
        infer_engine: Inference back-end display value (see ``InferEngine``).
        is_word: "Yes" to return word/character-level results instead of
            line-level results.

    Returns:
        ``(vis_img, rows, elapse)``: the visualised result from
        ``ocr_result.vis()``, a list of ``[index, text, score]`` rows (empty
        when nothing is recognised), and ``ocr_result.elapse``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        # Surface a readable message in the UI instead of an engine traceback.
        raise gr.Error("Please upload or select an image first.")

    return_word_box = is_word == "Yes"
    ocr_engine = get_ocr_engine(infer_engine, lang_det=lang_det, lang_rec=lang_rec)
    ocr_result = ocr_engine(
        img,
        text_score=text_score,
        box_thresh=box_thresh,
        unclip_ratio=unclip_ratio,
        return_word_box=return_word_box,
    )
    vis_img = ocr_result.vis()

    if return_word_box:
        # Assumes word_results is a flat sequence of (text, score, box) items,
        # as the original unpacking did. NOTE(review): newer rapidocr releases
        # may group words per line — confirm against the pinned version.
        words = ocr_result.word_results or ()
        txts = [word[0] for word in words]
        scores = [word[1] for word in words]
    else:
        txts, scores = ocr_result.txts, ocr_result.scores

    return vis_img, _to_table_rows(txts, scores), ocr_result.elapse
def create_examples() -> List[List[Union[str, float]]]:
    """Build the rows for the ``gr.Examples`` gallery.

    Every row is laid out like the UI inputs:
    ``[image_path, text_score, box_thresh, unclip_ratio, lang_det, lang_rec,
    infer_engine, is_word]``. Images without overrides use the defaults.
    """
    base_settings = [0.5, 0.5, 1.6, "ch_mobile", "ch_mobile", "ONNXRuntime", "No"]
    # Override keys index into ``base_settings`` (3 = lang_det, 4 = lang_rec).
    per_image = [
        ("images/ch_en_num.jpg", {}),
        ("images/japan.jpg", {3: "multi_mobile", 4: "japan_mobile"}),
        ("images/korean.jpg", {3: "multi_mobile", 4: "korean_mobile"}),
        ("images/air_ticket.jpg", {}),
        ("images/car_plate.jpeg", {}),
        ("images/train_ticket.jpeg", {}),
    ]

    def _row(path, tweaks):
        # A fresh list per image so rows never share state.
        row = [path, *base_settings]
        for settings_idx, setting in tweaks.items():
            row[settings_idx + 1] = setting  # +1 skips the image-path column
        return row

    return [_row(path, tweaks) for path, tweaks in per_image]
# Back-end labels for the "Infer Engine" dropdown, taken straight from the enum
# (in declaration order) so the UI always matches InferEngine's values.
infer_engine_list = [member.value for member in InferEngine]

# Text-detection models for the "Det model" dropdown; the first is the default.
lang_det_list = [
    "ch_mobile",
    "ch_server",
    "en_mobile",
    "en_server",
    "multi_mobile",
]

# Text-recognition models for the "Rec model" dropdown; the first is the default.
lang_rec_list = [
    "ch_mobile",
    "ch_server",
    "chinese_cht",
    "en_mobile",
    "ar_mobile",
    "cyrillic_mobile",
    "devanagari_mobile",
    "japan_mobile",
    "ka_mobile",
    "korean_mobile",
    "latin_mobile",
    "ta_mobile",
    "te_mobile",
]
# Extra stylesheet for the demo UI. The class selectors (.gr-button,
# .example-button, .tall-radio, ...) are assumed to match elements rendered by
# the installed Gradio version — TODO confirm.
# The body rule previously read "body {font-family: body {font-family: ...;}",
# leaving an unclosed brace that made browsers discard every rule after it.
custom_css = """
body {font-family: 'Helvetica Neue', Helvetica;}
.gr-button {background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 5px;}
.gr-button:hover {background-color: #45a049;}
.gr-textbox {margin-bottom: 15px;}
.example-button {background-color: #1E90FF; color: white; border: none; padding: 8px 15px; border-radius: 5px; margin: 5px;}
.example-button:hover {background-color: #FF4500;}
.tall-radio .gr-radio-item {padding: 15px 0; min-height: 50px; display: flex; align-items: center;}
.tall-radio label {font-size: 16px;}
.output-image, .input-image, .image-preview {height: 300px !important}
"""
# Gradio UI: header, OCR threshold sliders, engine/model selectors, the input
# image, and the outputs (visualised image, elapsed time, recognised text).
# NOTE(review): the nesting of widgets after the second row was reconstructed
# from a copy that lost its indentation; confirm the intended layout.
with gr.Blocks(
    # Pass the stylesheet variable itself, not the string "custom_css".
    title="Rapid⚡OCR Demo", css=custom_css, theme=gr.themes.Soft()
) as demo:
    # Page header and project badges.
    gr.HTML(
        """
<h1 style='text-align: center;font-size:40px'>Rapid⚡OCR</h1>
<div style="display: flex; justify-content: center; gap: 10px;">
<a href=""><img src="https://img.shields.io/badge/Python->=3.6-aff.svg"></a>
<a href="https://rapidai.github.io/RapidOCRDocs"><img src="https://img.shields.io/badge/Docs-link-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
<a href="https://pepy.tech/project/rapidocr"><img src="https://static.pepy.tech/personalized-badge/rapidocr?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20rapidocr"></a>
<a href="https://pypi.org/project/rapidocr/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapidocr"></a>
<a href="https://github.com/RapidAI/RapidOCR"><img src="https://img.shields.io/github/stars/RapidAI/RapidOCR?color=ccf"></a>
</div>
"""
    )

    # Thresholds forwarded to every OCR call (help texts are in Chinese).
    with gr.Row():
        text_score = gr.Slider(
            label="text_score",
            minimum=0,
            maximum=1.0,
            value=0.5,
            step=0.1,
            info="文本识别结果是正确的置信度,值越大,显示出的识别结果更准确。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
        )
        box_thresh = gr.Slider(
            label="box_thresh",
            minimum=0,
            maximum=1.0,
            value=0.5,
            step=0.1,
            info="检测到的框是文本的概率,值越大,框中是文本的概率就越大。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
        )
        unclip_ratio = gr.Slider(
            label="unclip_ratio",
            minimum=1.5,
            maximum=2.0,
            value=1.6,
            step=0.1,
            info="控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0],默认值为1.6",
        )

    # Inference back-end, detection/recognition models, and word-box toggle.
    with gr.Row():
        select_infer_engine = gr.Dropdown(
            choices=infer_engine_list,
            label="Infer Engine (推理引擎)",
            value="ONNXRuntime",
            interactive=True,
        )
        lang_det = gr.Dropdown(
            choices=lang_det_list,
            label="Det model (文本检测模型)",
            value=lang_det_list[0],
            interactive=True,
        )
        lang_rec = gr.Dropdown(
            choices=lang_rec_list,
            label="Rec model (文本识别模型)",
            value=lang_rec_list[0],
            interactive=True,
        )
        is_word = gr.Radio(
            ["Yes", "No"], label="Return word box (返回单字符)", value="No"
        )

    img_input = gr.Image(label="Upload or Select Image", sources="upload")
    run_btn = gr.Button("Run")

    img_output = gr.Image(label="Output Image")
    elapse = gr.Textbox(label="Elapse(s)")
    ocr_results = gr.Dataframe(
        label="OCR Txts",
        headers=["Index", "Txt", "Score"],
        datatype=["number", "str", "number"],
        show_copy_button=True,
    )

    # Order must match get_ocr_result's positional parameters; the example
    # rows from create_examples() follow the same order.
    ocr_inputs = [
        img_input,
        text_score,
        box_thresh,
        unclip_ratio,
        lang_det,
        lang_rec,
        select_infer_engine,
        is_word,
    ]
    # Order must match get_ocr_result's return tuple (image, rows, elapse).
    ocr_outputs = [img_output, ocr_results, elapse]

    run_btn.click(get_ocr_result, inputs=ocr_inputs, outputs=ocr_outputs)

    examples = gr.Examples(
        examples=create_examples(),
        examples_per_page=5,
        inputs=ocr_inputs,
        fn=get_ocr_result,
        outputs=ocr_outputs,
        cache_examples=False,
    )
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch(debug=True)