import copy
import json
import os

import gradio as gr

from surya.input.langs import replace_lang_with_code, get_unique_langs
from surya.input.load import load_from_folder, load_from_file
from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.model.recognition.tokenizer import _tokenize
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.ordering.model import load_model as load_order_model
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.ordering import batch_ordering
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings

# Load models: line detection, layout, and reading order
det_model = load_detection_model()
det_processor = load_detection_processor()
layout_model = load_detection_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_detection_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
order_model = load_order_model()
order_processor = load_order_processor()

# Map of human-readable language names to Surya language codes
with open("languages.json", "r", encoding="utf-8") as file:
    language_map = json.load(file)


def ocr_main(input_path, max_pages=None, start_page=0, langs=None, lang_file=None,
             det_model=det_model, det_processor=det_processor):
    assert langs or lang_file, "Must provide either langs or lang_file"

    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages, start_page)
    else:
        images, names = load_from_file(input_path, max_pages, start_page)

    langs = langs.split(",")
    replace_lang_with_code(langs)
    image_langs = [langs] * len(images)

    # Prune the recognition model's moe layer to only include the languages we need
    _, lang_tokens = _tokenize("", get_unique_langs(image_langs))
    rec_model = load_recognition_model(langs=lang_tokens)
    rec_processor = load_recognition_processor()

    predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)

    for idx, (name, image, pred, langs) in enumerate(zip(names, images, predictions_by_image, image_langs)):
        bboxes = [l.bbox for l in pred.text_lines]
        pred_text = [l.text for l in pred.text_lines]
        page_image = draw_text_on_image(bboxes, pred_text, image.size, langs, has_math="_math" in langs)

    return page_image


def layout_main(input_path, max_pages=None, det_model=det_model, det_processor=det_processor,
                model=layout_model, processor=layout_processor):
    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages)
    else:
        images, names = load_from_file(input_path, max_pages)

    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, model, processor, line_predictions)

    for idx, (image, layout_pred, name) in enumerate(zip(images, layout_predictions, names)):
        polygons = [p.polygon for p in layout_pred.bboxes]
        labels = [p.label for p in layout_pred.bboxes]
        bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels)

    return bbox_image


def reading_main(input_path, max_pages=None, model=order_model, processor=order_processor,
                 layout_model=layout_model, layout_processor=layout_processor,
                 det_model=det_model, det_processor=det_processor):
    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages)
    else:
        images, names = load_from_file(input_path, max_pages)

    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)

    # Reading order is predicted from the layout bboxes of each page
    bboxes = []
    for layout_pred in layout_predictions:
        bbox = [l.bbox for l in layout_pred.bboxes]
        bboxes.append(bbox)

    order_predictions = batch_ordering(images, bboxes, model, processor)

    for idx, (image, layout_pred, order_pred, name) in enumerate(
            zip(images, layout_predictions, order_predictions, names)):
        polys = [l.polygon for l in order_pred.bboxes]
        labels = [str(l.position) for l in order_pred.bboxes]
        bbox_image = draw_polys_on_image(polys, copy.deepcopy(image), labels=labels, label_font_size=20)

    return bbox_image


def model1(image_path, languages):
    # Default to English when no language is selected
    langs = ",".join(languages) if languages else "English"
    annotated = ocr_main(image_path, langs=langs)
    return annotated


def model2(image_path):
    annotated = layout_main(image_path)
    return annotated


def model3(image_path):
    annotated = reading_main(image_path)
    return annotated


with gr.Blocks() as demo:
    gr.Markdown("# Surya - Image OCR/Layout/Reading Order")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(type="filepath", label="Input Image", sources="upload")
            with gr.Row():
                dropdown = gr.Dropdown(label="Select Languages for OCR", choices=list(language_map.keys()),
                                       multiselect=True, value=["English"], interactive=True)
            with gr.Row():
                btn1 = gr.Button("OCR", variant="primary")
                btn2 = gr.Button("Layout", variant="primary")
                btn3 = gr.Button("Reading Order", variant="primary")
            with gr.Row():
                clear = gr.ClearButton()
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("OCR"):
                    output_image1 = gr.Image()
                with gr.TabItem("Layout"):
                    output_image2 = gr.Image()
                with gr.TabItem("Reading Order"):
                    output_image3 = gr.Image()

    btn1.click(fn=model1, inputs=[input_image, dropdown], outputs=output_image1)
    btn2.click(fn=model2, inputs=[input_image], outputs=output_image2)
    btn3.click(fn=model3, inputs=[input_image], outputs=output_image3)
    clear.add(components=[input_image, output_image1, output_image2, output_image3])

demo.launch()