import gradio as gr
import json, os, copy

from surya.input.langs import replace_lang_with_code, get_unique_langs
from surya.input.load import load_from_folder, load_from_file
from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.model.recognition.tokenizer import _tokenize
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image

from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection

from surya.model.ordering.model import load_model as load_order_model
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.ordering import batch_ordering
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings

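# Detection, layout, and reading-order models are loaded once at startup.
# The recognition model is loaded inside ocr_main for each request, because it
# is initialised with the tokens of the languages selected by the user.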
det_model = load_detection_model()
det_processor = load_detection_processor()

layout_model = load_detection_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_detection_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)

order_model = load_order_model()
order_processor = load_order_processor()

with open("languages.json", "r", encoding="utf-8") as file:
    language_map = json.load(file)

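# Each of the functions below runs one Surya pipeline end to end and returns
# an annotated image for display in the UI.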
def ocr_main(input_path, max_pages=None, start_page=0, langs=None, lang_file=None,
             det_model=det_model, det_processor=det_processor):
    assert langs or lang_file, "Must provide either langs or lang_file"

    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages, start_page)
    else:
        images, names = load_from_file(input_path, max_pages, start_page)

    # Convert the comma-separated language names into Surya language codes and
    # use the same language list for every loaded image.
    langs = langs.split(",")
    replace_lang_with_code(langs)
    image_langs = [langs] * len(images)

    # The recognition model is initialised with the tokens of the requested languages.
    _, lang_tokens = _tokenize("", get_unique_langs(image_langs))
    rec_model = load_recognition_model(langs=lang_tokens)
    rec_processor = load_recognition_processor()

    predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)

    # The Gradio input is a single image, so return the rendering of the first page.
    for name, image, pred, langs in zip(names, images, predictions_by_image, image_langs):
        bboxes = [l.bbox for l in pred.text_lines]
        pred_text = [l.text for l in pred.text_lines]
        page_image = draw_text_on_image(bboxes, pred_text, image.size, langs, has_math="_math" in langs)
        return page_image


def layout_main(input_path, max_pages=None,
                det_model=det_model, det_processor=det_processor,
                model=layout_model, processor=layout_processor):
    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages)
    else:
        images, names = load_from_file(input_path, max_pages)

    # Layout detection takes the text-line detections as an additional input.
    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, model, processor, line_predictions)

    # Draw the labelled layout polygons on a copy of the first (and only) page.
    for image, layout_pred, name in zip(images, layout_predictions, names):
        polygons = [p.polygon for p in layout_pred.bboxes]
        labels = [p.label for p in layout_pred.bboxes]
        bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels)
        return bbox_image


def reading_main(input_path, max_pages=None, model=order_model, processor=order_processor,
                 layout_model=layout_model, layout_processor=layout_processor,
                 det_model=det_model, det_processor=det_processor):
    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages)
    else:
        images, names = load_from_file(input_path, max_pages)

    # Reading order is predicted from the layout bounding boxes, which in turn
    # are derived from the text-line detections.
    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
    bboxes = []
    for layout_pred in layout_predictions:
        bbox = [l.bbox for l in layout_pred.bboxes]
        bboxes.append(bbox)

    order_predictions = batch_ordering(images, bboxes, model, processor)

    # Draw each region's reading-order position on a copy of the first (and only) page.
    for image, layout_pred, order_pred, name in zip(images, layout_predictions, order_predictions, names):
        polys = [l.polygon for l in order_pred.bboxes]
        labels = [str(l.position) for l in order_pred.bboxes]
        bbox_image = draw_polys_on_image(polys, copy.deepcopy(image), labels=labels, label_font_size=20)
        return bbox_image


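# Thin wrappers used as Gradio callbacks: each takes the uploaded image path
# (and, for OCR, the selected languages) and returns an annotated image.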
def model1(image_path, languages):
    # Default to English when no language is selected in the dropdown.
    if not languages:
        langs = "English"
    else:
        langs = ",".join(languages)

    annotated = ocr_main(image_path, langs=langs)
    return annotated


def model2(image_path):
    annotated = layout_main(image_path)
    return annotated


def model3(image_path):
    annotated = reading_main(image_path)
    return annotated


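# Build the Gradio interface: an image upload and language selection on the left,
# and one output tab per task (OCR, layout, reading order) on the right.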
with gr.Blocks() as demo:
    gr.Markdown("<center><h1>Surya - Image OCR/Layout/Reading Order</h1></center>")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(type="filepath", label="Input Image", sources=["upload"])
            with gr.Row():
                dropdown = gr.Dropdown(label="Select Languages for OCR", choices=list(language_map.keys()), multiselect=True, value=["English"], interactive=True)
            with gr.Row():
                btn1 = gr.Button("OCR", variant="primary")
                btn2 = gr.Button("Layout", variant="primary")
                btn3 = gr.Button("Reading Order", variant="primary")
            with gr.Row():
                clear = gr.ClearButton()

        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("OCR"):
                    output_image1 = gr.Image()
                with gr.TabItem("Layout"):
                    output_image2 = gr.Image()
                with gr.TabItem("Reading Order"):
                    output_image3 = gr.Image()

    # Wire each button to its task, and let the ClearButton reset all images.
    btn1.click(fn=model1, inputs=[input_image, dropdown], outputs=output_image1)
    btn2.click(fn=model2, inputs=[input_image], outputs=output_image2)
    btn3.click(fn=model3, inputs=[input_image], outputs=output_image3)
    clear.add(components=[input_image, output_image1, output_image2, output_image3])

demo.launch()