# surya-ocr / app.py
# Commit d948a30 (verified) by sanil-55: "Create app.py"
import gradio as gr
import json, os, copy
from surya.input.langs import replace_lang_with_code, get_unique_langs
from surya.input.load import load_from_folder, load_from_file
from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.model.recognition.tokenizer import _tokenize
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.ordering.model import load_model as load_order_model
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.ordering import batch_ordering
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings
# Load the language-independent models once at import time so every request
# reuses them: text-line detection, layout detection and reading order.
# (The recognition model is loaded inside ocr_main because it is pruned to the
# languages requested for each call.)
det_model = load_detection_model()
det_processor = load_detection_processor()
layout_model = load_detection_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_detection_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
order_model = load_order_model()
order_processor = load_order_processor()

# Mapping loaded from languages.json; the UI uses its keys as the language
# choices. Resolve the file next to this script rather than the current
# working directory so the app also starts when launched from elsewhere.
_LANGUAGES_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "languages.json")
with open(_LANGUAGES_PATH, "r", encoding="utf-8") as file:
    language_map = json.load(file)
def ocr_main(input_path, max_pages=None, start_page=0, langs=None, lang_file=None,
             det_model=det_model, det_processor=det_processor):
    """Run OCR on a file or folder and render the recognized text as an image.

    Args:
        input_path: Path to a single input file or to a folder of inputs.
        max_pages: Optional cap on the number of pages to load; passed through
            to the surya loaders.
        start_page: Index of the first page to load.
        langs: Comma-separated language names or codes, e.g. "English,Hindi".
            Surrounding whitespace and empty entries are ignored.
        lang_file: Accepted only for signature compatibility; per-image
            language files are not supported by this app.
        det_model: Text-line detection model.
        det_processor: Processor matching ``det_model``.

    Returns:
        The image produced by ``draw_text_on_image`` for the last loaded page.

    Raises:
        ValueError: If no languages are given or no pages could be loaded.
    """
    # Validate with a real exception: `assert` disappears under `python -O`, and
    # `lang_file` was never read, so a lang_file-only call used to crash on
    # `None.split`.
    lang_list = [lang.strip() for lang in (langs or "").split(",") if lang.strip()]
    if not lang_list:
        raise ValueError("Must provide langs (lang_file is not supported by this app)")

    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages, start_page)
    else:
        images, names = load_from_file(input_path, max_pages, start_page)
    if not images:
        # Previously this fell through the render loop and hit UnboundLocalError.
        raise ValueError(f"No pages could be loaded from {input_path!r}")

    # replace_lang_with_code's return value is unused, so it mutates lang_list in place.
    replace_lang_with_code(lang_list)
    image_langs = [lang_list] * len(images)
    _, lang_tokens = _tokenize("", get_unique_langs(image_langs))
    rec_model = load_recognition_model(langs=lang_tokens)  # Prune model moe layer to only include languages we need
    rec_processor = load_recognition_processor()
    predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)

    # The UI shows a single output image and the old loop kept only the final
    # render, so render just the last page instead of discarding the others.
    image, pred, page_langs = images[-1], predictions_by_image[-1], image_langs[-1]
    bboxes = [line.bbox for line in pred.text_lines]
    pred_text = [line.text for line in pred.text_lines]
    return draw_text_on_image(bboxes, pred_text, image.size, page_langs,
                              has_math="_math" in page_langs)
def layout_main(input_path, max_pages=None,
                det_model=det_model, det_processor=det_processor,
                model=layout_model, processor=layout_processor):
    """Detect layout regions and draw their labelled polygons.

    Args:
        input_path: Path to a single input file or to a folder of inputs.
        max_pages: Optional cap on the number of pages to load; passed through
            to the surya loaders.
        det_model: Text-line detection model.
        det_processor: Processor matching ``det_model``.
        model: Layout detection model.
        processor: Processor matching ``model``.

    Returns:
        The image produced by ``draw_polys_on_image`` for the last loaded page,
        with each layout polygon labelled by its predicted label.

    Raises:
        ValueError: If no pages could be loaded from ``input_path``.
    """
    if os.path.isdir(input_path):
        images, _names = load_from_folder(input_path, max_pages)
    else:
        images, _names = load_from_file(input_path, max_pages)
    if not images:
        # Previously this fell through the render loop and hit UnboundLocalError.
        raise ValueError(f"No pages could be loaded from {input_path!r}")

    # Layout detection takes the text-line detections as an extra input.
    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, model, processor, line_predictions)

    # Only the last page is rendered, matching the single output image in the UI
    # (the old loop discarded every earlier render).
    image, layout_pred = images[-1], layout_predictions[-1]
    polygons = [box.polygon for box in layout_pred.bboxes]
    labels = [box.label for box in layout_pred.bboxes]
    # Pass a copy so drawing cannot modify the loaded page image.
    return draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels)
def reading_main(input_path, max_pages=None, model=order_model, processor=order_processor,
                 layout_model=layout_model, layout_processor=layout_processor,
                 det_model=det_model, det_processor=det_processor):
    """Predict the reading order of layout regions and draw it on the page.

    Pipeline: text-line detection -> layout detection -> reading-order
    prediction over the layout boxes.

    Args:
        input_path: Path to a single input file or to a folder of inputs.
        max_pages: Optional cap on the number of pages to load; passed through
            to the surya loaders.
        model: Reading-order model.
        processor: Processor matching ``model``.
        layout_model: Layout detection model.
        layout_processor: Processor matching ``layout_model``.
        det_model: Text-line detection model.
        det_processor: Processor matching ``det_model``.

    Returns:
        The image produced by ``draw_polys_on_image`` for the last loaded page,
        with each region labelled by its predicted position.

    Raises:
        ValueError: If no pages could be loaded from ``input_path``.
    """
    if os.path.isdir(input_path):
        images, _names = load_from_folder(input_path, max_pages)
    else:
        images, _names = load_from_file(input_path, max_pages)
    if not images:
        # Previously this fell through the render loop and hit UnboundLocalError.
        raise ValueError(f"No pages could be loaded from {input_path!r}")

    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)

    # One list of layout boxes per page is what the ordering model consumes.
    bboxes = [[box.bbox for box in layout_pred.bboxes] for layout_pred in layout_predictions]
    order_predictions = batch_ordering(images, bboxes, model, processor)

    # Only the last page is rendered, matching the single output image in the UI
    # (the old loop discarded every earlier render).
    image, order_pred = images[-1], order_predictions[-1]
    polys = [box.polygon for box in order_pred.bboxes]
    labels = [str(box.position) for box in order_pred.bboxes]
    # Pass a copy so drawing cannot modify the loaded page image.
    return draw_polys_on_image(polys, copy.deepcopy(image), labels=labels, label_font_size=20)
def model1(image_path, languages):
    """Gradio handler for the OCR button.

    Args:
        image_path: Filepath of the uploaded image, or None when nothing has
            been uploaded yet.
        languages: Language names selected in the dropdown (a list, or a single
            string). English is used when the selection is empty.

    Returns:
        The OCR visualisation produced by ``ocr_main``.
    """
    if image_path is None:
        # Without this guard ocr_main fails deep inside os.path.isdir(None) with
        # an opaque TypeError; gr.Error shows a readable message in the UI.
        raise gr.Error("Please upload an image first.")
    if isinstance(languages, str):
        # A single selection as a bare string would otherwise be split per character.
        languages = [languages]
    langs = ",".join(str(lang) for lang in languages) if languages else "English"
    return ocr_main(image_path, langs=langs)
def model2(image_path):
    """Gradio handler for the Layout button.

    Args:
        image_path: Filepath of the uploaded image, or None when nothing has
            been uploaded yet.

    Returns:
        The layout visualisation produced by ``layout_main``.
    """
    if image_path is None:
        # Fail with a readable UI message instead of a TypeError from os.path.isdir(None).
        raise gr.Error("Please upload an image first.")
    return layout_main(image_path)
def model3(image_path):
    """Gradio handler for the Reading Order button.

    Args:
        image_path: Filepath of the uploaded image, or None when nothing has
            been uploaded yet.

    Returns:
        The reading-order visualisation produced by ``reading_main``.
    """
    if image_path is None:
        # Fail with a readable UI message instead of a TypeError from os.path.isdir(None).
        raise gr.Error("Please upload an image first.")
    return reading_main(image_path)
# UI layout: inputs and action buttons on the left; one output tab per task on
# the right.
with gr.Blocks() as demo:
    gr.Markdown("<center><h1>Surya - Image OCR/Layout/Reading Order</h1></center>")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(type="filepath", label="Input Image", sources="upload")
            with gr.Row():
                dropdown = gr.Dropdown(
                    label="Select Languages for OCR",
                    choices=list(language_map.keys()),
                    multiselect=True,
                    value=["English"],
                    interactive=True,
                )
            with gr.Row():
                btn1 = gr.Button("OCR", variant="primary")
                btn2 = gr.Button("Layout", variant="primary")
                btn3 = gr.Button("Reading Order", variant="primary")
            with gr.Row():
                clear = gr.ClearButton()
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("OCR"):
                    output_image1 = gr.Image()
                with gr.TabItem("Layout"):
                    output_image2 = gr.Image()
                with gr.TabItem("Reading Order"):
                    output_image3 = gr.Image()

    # Event listeners are registered inside the Blocks context, as Gradio requires.
    btn1.click(fn=model1, inputs=[input_image, dropdown], outputs=output_image1)
    btn2.click(fn=model2, inputs=[input_image], outputs=output_image2)
    btn3.click(fn=model3, inputs=[input_image], outputs=output_image3)
    # The clear button resets the input image and all three outputs.
    clear.add(components=[input_image, output_image1, output_image2, output_image3])

# Launch only when run as a script so importing this module (tests, `gradio`
# reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()