Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json, os, copy
|
3 |
+
|
4 |
+
from surya.input.langs import replace_lang_with_code, get_unique_langs
|
5 |
+
from surya.input.load import load_from_folder, load_from_file
|
6 |
+
from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
|
7 |
+
from surya.model.recognition.model import load_model as load_recognition_model
|
8 |
+
from surya.model.recognition.processor import load_processor as load_recognition_processor
|
9 |
+
from surya.model.recognition.tokenizer import _tokenize
|
10 |
+
from surya.ocr import run_ocr
|
11 |
+
from surya.postprocessing.text import draw_text_on_image
|
12 |
+
|
13 |
+
from surya.detection import batch_text_detection
|
14 |
+
from surya.layout import batch_layout_detection
|
15 |
+
|
16 |
+
from surya.model.ordering.model import load_model as load_order_model
|
17 |
+
from surya.model.ordering.processor import load_processor as load_order_processor
|
18 |
+
from surya.ordering import batch_ordering
|
19 |
+
from surya.postprocessing.heatmap import draw_polys_on_image
|
20 |
+
from surya.settings import settings
|
21 |
+
|
22 |
+
|
23 |
+
#load models
|
24 |
+
#line detection, layout, order
|
25 |
+
det_model = load_detection_model()
|
26 |
+
det_processor = load_detection_processor()
|
27 |
+
|
28 |
+
layout_model = load_detection_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
|
29 |
+
layout_processor = load_detection_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
|
30 |
+
|
31 |
+
order_model = load_order_model()
|
32 |
+
order_processor = load_order_processor()
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
with open("languages.json", "r", encoding='utf-8') as file:
|
37 |
+
language_map = json.load(file)
|
38 |
+
|
39 |
+
def ocr_main(input_path, max_pages=None, start_page=0, langs=None, lang_file=None,
|
40 |
+
det_model=det_model, det_processor=det_processor):
|
41 |
+
|
42 |
+
assert langs or lang_file, "Must provide either langs or lang_file"
|
43 |
+
|
44 |
+
if os.path.isdir(input_path):
|
45 |
+
images, names = load_from_folder(input_path, max_pages, start_page)
|
46 |
+
else:
|
47 |
+
images, names = load_from_file(input_path, max_pages, start_page)
|
48 |
+
|
49 |
+
|
50 |
+
langs = langs.split(",")
|
51 |
+
replace_lang_with_code(langs)
|
52 |
+
image_langs = [langs] * len(images)
|
53 |
+
|
54 |
+
_, lang_tokens = _tokenize("", get_unique_langs(image_langs))
|
55 |
+
rec_model = load_recognition_model(langs=lang_tokens) # Prune model moe layer to only include languages we need
|
56 |
+
rec_processor = load_recognition_processor()
|
57 |
+
|
58 |
+
predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)
|
59 |
+
|
60 |
+
for idx, (name, image, pred, langs) in enumerate(zip(names, images, predictions_by_image, image_langs)):
|
61 |
+
bboxes = [l.bbox for l in pred.text_lines]
|
62 |
+
pred_text = [l.text for l in pred.text_lines]
|
63 |
+
page_image = draw_text_on_image(bboxes, pred_text, image.size, langs, has_math="_math" in langs)
|
64 |
+
return page_image
|
65 |
+
|
66 |
+
def layout_main(input_path, max_pages=None,
|
67 |
+
det_model=det_model, det_processor=det_processor,
|
68 |
+
model=layout_model, processor=layout_processor):
|
69 |
+
|
70 |
+
if os.path.isdir(input_path):
|
71 |
+
images, names = load_from_folder(input_path, max_pages)
|
72 |
+
|
73 |
+
else:
|
74 |
+
images, names = load_from_file(input_path, max_pages)
|
75 |
+
|
76 |
+
|
77 |
+
line_predictions = batch_text_detection(images, det_model, det_processor)
|
78 |
+
|
79 |
+
layout_predictions = batch_layout_detection(images, model, processor, line_predictions)
|
80 |
+
|
81 |
+
|
82 |
+
for idx, (image, layout_pred, name) in enumerate(zip(images, layout_predictions, names)):
|
83 |
+
polygons = [p.polygon for p in layout_pred.bboxes]
|
84 |
+
labels = [p.label for p in layout_pred.bboxes]
|
85 |
+
bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels)
|
86 |
+
return bbox_image
|
87 |
+
|
88 |
+
def reading_main(input_path, max_pages=None, model=order_model, processor=order_processor,
|
89 |
+
layout_model=layout_model, layout_processor=layout_processor,
|
90 |
+
det_model=det_model, det_processor=det_processor):
|
91 |
+
|
92 |
+
if os.path.isdir(input_path):
|
93 |
+
images, names = load_from_folder(input_path, max_pages)
|
94 |
+
|
95 |
+
else:
|
96 |
+
images, names = load_from_file(input_path, max_pages)
|
97 |
+
|
98 |
+
|
99 |
+
line_predictions = batch_text_detection(images, det_model, det_processor)
|
100 |
+
layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
|
101 |
+
bboxes = []
|
102 |
+
for layout_pred in layout_predictions:
|
103 |
+
bbox = [l.bbox for l in layout_pred.bboxes]
|
104 |
+
bboxes.append(bbox)
|
105 |
+
|
106 |
+
order_predictions = batch_ordering(images, bboxes, model, processor)
|
107 |
+
|
108 |
+
for idx, (image, layout_pred, order_pred, name) in enumerate(zip(images, layout_predictions, order_predictions, names)):
|
109 |
+
polys = [l.polygon for l in order_pred.bboxes]
|
110 |
+
labels = [str(l.position) for l in order_pred.bboxes]
|
111 |
+
bbox_image = draw_polys_on_image(polys, copy.deepcopy(image), labels=labels, label_font_size=20)
|
112 |
+
return bbox_image
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
def model1(image_path, languages):
|
118 |
+
langs = ""
|
119 |
+
if languages == [] or not languages:
|
120 |
+
langs = "English"
|
121 |
+
else:
|
122 |
+
for lang in languages:
|
123 |
+
langs += f"{lang},"
|
124 |
+
langs = langs[:-1]
|
125 |
+
|
126 |
+
annotated = ocr_main(image_path, langs=langs)
|
127 |
+
return annotated
|
128 |
+
|
129 |
+
def model2(image_path):
|
130 |
+
|
131 |
+
annotated = layout_main(image_path)
|
132 |
+
return annotated
|
133 |
+
|
134 |
+
def model3(image_path):
|
135 |
+
|
136 |
+
annotated = reading_main(image_path)
|
137 |
+
return annotated
|
138 |
+
|
139 |
+
|
140 |
+
with gr.Blocks() as demo:
|
141 |
+
gr.Markdown("<center><h1>Surya - Image OCR/Layout/Reading Order</h1></center>")
|
142 |
+
|
143 |
+
with gr.Row():
|
144 |
+
with gr.Column():
|
145 |
+
with gr.Row():
|
146 |
+
input_image = gr.Image(type="filepath", label="Input Image", sources="upload")
|
147 |
+
with gr.Row():
|
148 |
+
dropdown = gr.Dropdown(label="Select Languages for OCR", choices=list(language_map.keys()), multiselect=True, value=["English"], interactive=True)
|
149 |
+
with gr.Row():
|
150 |
+
btn1 = gr.Button("OCR", variant="primary")
|
151 |
+
btn2 = gr.Button("Layout", variant="primary")
|
152 |
+
btn3 = gr.Button("Reading Order", variant="primary")
|
153 |
+
with gr.Row():
|
154 |
+
clear = gr.ClearButton()
|
155 |
+
|
156 |
+
with gr.Column():
|
157 |
+
with gr.Tabs():
|
158 |
+
with gr.TabItem("OCR"):
|
159 |
+
output_image1 = gr.Image()
|
160 |
+
with gr.TabItem("Layout"):
|
161 |
+
output_image2 = gr.Image()
|
162 |
+
with gr.TabItem("Reading Order"):
|
163 |
+
output_image3 = gr.Image()
|
164 |
+
|
165 |
+
btn1.click(fn=model1, inputs=[input_image, dropdown], outputs=output_image1)
|
166 |
+
btn2.click(fn=model2, inputs=[input_image], outputs=output_image2)
|
167 |
+
btn3.click(fn=model3, inputs=[input_image], outputs=output_image3)
|
168 |
+
clear.add(components=[input_image, output_image1, output_image2, output_image3])
|
169 |
+
|
170 |
+
demo.launch()
|