sanil-55 commited on
Commit
d948a30
β€’
1 Parent(s): a99fd09

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json, os, copy
3
+
4
+ from surya.input.langs import replace_lang_with_code, get_unique_langs
5
+ from surya.input.load import load_from_folder, load_from_file
6
+ from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
7
+ from surya.model.recognition.model import load_model as load_recognition_model
8
+ from surya.model.recognition.processor import load_processor as load_recognition_processor
9
+ from surya.model.recognition.tokenizer import _tokenize
10
+ from surya.ocr import run_ocr
11
+ from surya.postprocessing.text import draw_text_on_image
12
+
13
+ from surya.detection import batch_text_detection
14
+ from surya.layout import batch_layout_detection
15
+
16
+ from surya.model.ordering.model import load_model as load_order_model
17
+ from surya.model.ordering.processor import load_processor as load_order_processor
18
+ from surya.ordering import batch_ordering
19
+ from surya.postprocessing.heatmap import draw_polys_on_image
20
+ from surya.settings import settings
21
+
22
+
23
# Load all models once at import time so every Gradio callback reuses them.
# Line detection, layout analysis, and reading-order models (recognition is
# loaded lazily inside ocr_main because it is pruned per requested languages).
det_model = load_detection_model()
det_processor = load_detection_processor()

# Layout detection reuses the detection architecture with a different checkpoint.
layout_model = load_detection_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_detection_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)

order_model = load_order_model()
order_processor = load_order_processor()


# Maps human-readable language names (shown in the dropdown) to surya codes.
# NOTE(review): languages.json must sit next to app.py — confirm at deploy time.
with open("languages.json", "r", encoding='utf-8') as file:
    language_map = json.load(file)
38
+
39
def ocr_main(input_path, max_pages=None, start_page=0, langs=None, lang_file=None,
             det_model=det_model, det_processor=det_processor):
    """Run OCR over an image/PDF path and return the first page rendered with its text.

    Args:
        input_path: File or folder path; folders are loaded page by page.
        max_pages: Optional cap on the number of pages loaded.
        start_page: Index of the first page to load.
        langs: Comma-separated language names (e.g. "English,Hindi"). Required
            in practice — the lang_file path is accepted for CLI parity but is
            not consumed here.
        lang_file: Unused beyond validation; see note above.
        det_model / det_processor: Preloaded detection model/processor.

    Returns:
        A PIL image of the first page with recognized text drawn on it.

    Raises:
        ValueError: If neither langs nor lang_file is provided.
    """
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if not (langs or lang_file):
        raise ValueError("Must provide either langs or lang_file")

    if os.path.isdir(input_path):
        images, names = load_from_folder(input_path, max_pages, start_page)
    else:
        images, names = load_from_file(input_path, max_pages, start_page)

    # langs is required here: lang_file is not read, so splitting None would fail.
    langs = langs.split(",")
    replace_lang_with_code(langs)
    image_langs = [langs] * len(images)

    # Tokenize the language set so the recognition model's MoE layer can be
    # pruned to only the languages we actually need.
    _, lang_tokens = _tokenize("", get_unique_langs(image_langs))
    rec_model = load_recognition_model(langs=lang_tokens)
    rec_processor = load_recognition_processor()

    predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)

    # Only the first page is rendered and returned (single-image Gradio output);
    # page_langs avoids shadowing the outer `langs` list.
    for name, image, pred, page_langs in zip(names, images, predictions_by_image, image_langs):
        bboxes = [l.bbox for l in pred.text_lines]
        pred_text = [l.text for l in pred.text_lines]
        page_image = draw_text_on_image(bboxes, pred_text, image.size, page_langs,
                                        has_math="_math" in page_langs)
        return page_image
65
+
66
def layout_main(input_path, max_pages=None,
                det_model=det_model, det_processor=det_processor,
                model=layout_model, processor=layout_processor):
    """Detect layout regions for the pages at input_path.

    Returns the first page as a PIL image with labeled layout polygons drawn on.
    """
    # Folders yield one image per page file; single files are paginated by the loader.
    loader = load_from_folder if os.path.isdir(input_path) else load_from_file
    images, names = loader(input_path, max_pages)

    # Layout detection is conditioned on the text-line detections.
    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, model, processor, line_predictions)

    # Render and return only the first page (single-image Gradio output).
    for page_image, page_layout in zip(images, layout_predictions):
        shapes = [region.polygon for region in page_layout.bboxes]
        tags = [region.label for region in page_layout.bboxes]
        return draw_polys_on_image(shapes, copy.deepcopy(page_image), labels=tags)
87
+
88
def reading_main(input_path, max_pages=None, model=order_model, processor=order_processor,
                 layout_model=layout_model, layout_processor=layout_processor,
                 det_model=det_model, det_processor=det_processor):
    """Compute reading order for the pages at input_path.

    Returns the first page as a PIL image with region polygons labeled by
    their reading-order position.
    """
    loader = load_from_folder if os.path.isdir(input_path) else load_from_file
    images, names = loader(input_path, max_pages)

    # Ordering needs layout regions, which in turn need line detections.
    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
    bboxes = [[region.bbox for region in page.bboxes] for page in layout_predictions]

    order_predictions = batch_ordering(images, bboxes, model, processor)

    # Render and return only the first page (single-image Gradio output).
    for page_image, order_pred in zip(images, order_predictions):
        polygons = [item.polygon for item in order_pred.bboxes]
        position_labels = [str(item.position) for item in order_pred.bboxes]
        return draw_polys_on_image(polygons, copy.deepcopy(page_image),
                                   labels=position_labels, label_font_size=20)
113
+
114
+
115
+
116
+
117
def model1(image_path, languages):
    """OCR button callback: run OCR on the uploaded image.

    Args:
        image_path: Filepath of the uploaded image (Gradio `type="filepath"`).
        languages: List of language names from the multiselect dropdown.

    Returns:
        The annotated first-page image from ocr_main.
    """
    # Default to English when nothing is selected; otherwise join the list
    # with commas (idiomatic str.join instead of += loop + trailing-comma strip).
    if not languages:
        langs = "English"
    else:
        langs = ",".join(languages)

    annotated = ocr_main(image_path, langs=langs)
    return annotated
128
+
129
def model2(image_path):
    """Layout button callback: delegate to layout_main for the uploaded image."""
    return layout_main(image_path)
133
+
134
def model3(image_path):
    """Reading-order button callback: delegate to reading_main for the uploaded image."""
    return reading_main(image_path)
138
+
139
+
140
# Gradio UI: left column holds the input image, language picker and action
# buttons; right column shows one output tab per task. Statement order here
# defines the on-page layout, so keep it as-is.
with gr.Blocks() as demo:
    gr.Markdown("<center><h1>Surya - Image OCR/Layout/Reading Order</h1></center>")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(type="filepath", label="Input Image", sources="upload")
            with gr.Row():
                # Dropdown choices come from languages.json loaded at startup.
                dropdown = gr.Dropdown(label="Select Languages for OCR", choices=list(language_map.keys()), multiselect=True, value=["English"], interactive=True)
            with gr.Row():
                btn1 = gr.Button("OCR", variant="primary")
                btn2 = gr.Button("Layout", variant="primary")
                btn3 = gr.Button("Reading Order", variant="primary")
            with gr.Row():
                clear = gr.ClearButton()

        with gr.Column():
            # One result tab per task; each button writes to its own tab.
            with gr.Tabs():
                with gr.TabItem("OCR"):
                    output_image1 = gr.Image()
                with gr.TabItem("Layout"):
                    output_image2 = gr.Image()
                with gr.TabItem("Reading Order"):
                    output_image3 = gr.Image()

    # Wire each button to its callback; only OCR needs the language selection.
    btn1.click(fn=model1, inputs=[input_image, dropdown], outputs=output_image1)
    btn2.click(fn=model2, inputs=[input_image], outputs=output_image2)
    btn3.click(fn=model3, inputs=[input_image], outputs=output_image3)
    # ClearButton resets the input and all three outputs at once.
    clear.add(components=[input_image, output_image1, output_image2, output_image3])

demo.launch()