MikkoLipsanen committed on
Commit
970b7b6
·
verified ·
1 Parent(s): 36e7edb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from optimum.onnxruntime import ORTModelForVision2Seq
2
+ from transformers import TrOCRProcessor
3
+ from ultralytics import YOLO
4
+ import gradio as gr
5
+ import numpy as np
6
+ import onnxruntime
7
+ import time
8
+
9
+ from plotting_functions import PlotHTR
10
+ from segment_image import SegmentImage
11
+ from onnx_text_recognition import TextRecognition
12
+
13
+
14
# Model locations handed to SegmentImage / TextRecognition below.
# NOTE(review): presumably Hugging Face Hub identifiers (the TrOCR ones with a
# sub-folder appended) — confirm against how those classes resolve paths.

# Text line and text region detection models used by the segmenter.
LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
REGION_MODEL_PATH = "Kansallisarkisto/court-records-region-detection"

# Processor and ONNX model used by the text recognizer.
TROCR_PROCESSOR_PATH = "Kansallisarkisto/multicentury-htr-model-onnx/202405_processor/"
TROCR_MODEL_PATH = "Kansallisarkisto/multicentury-htr-model-onnx/202405_onnx/"
18
+
19
+
20
def get_segmenter():
    """Create the segmentation object used by the demo.

    Returns:
        A ``SegmentImage`` instance configured with the line and region
        model paths and the thresholds below, running on the CPU.  If
        construction raises, the error is printed and ``None`` is returned
        instead, so callers must be prepared for a ``None`` result.
    """
    try:
        # Everything stays inside the ``try`` so that any failure while
        # building the object is reported the same way.
        return SegmentImage(
            line_model_path=LINE_MODEL_PATH,
            device='cpu',
            line_iou=0.3,
            region_iou=0.5,
            line_overlap=0.5,
            line_nms_iou=0.7,
            region_nms_iou=0.3,
            line_conf_threshold=0.25,
            region_conf_threshold=0.5,
            region_model_path=REGION_MODEL_PATH,
            order_regions=True,
            region_half_precision=False,
            line_half_precision=False,
        )
    except Exception as err:
        print('Failed to initialize SegmentImage class: %s' % err)
        return None
39
+
40
def get_recognizer():
    """Create the text recognition object used by the demo.

    Returns:
        A ``TextRecognition`` instance configured with the TrOCR processor
        and model paths, running on the CPU.  If construction raises, the
        error is printed and ``None`` is returned instead, so callers must
        be prepared for a ``None`` result.
    """
    try:
        # Everything stays inside the ``try`` so that any failure while
        # building the object is reported the same way.
        return TextRecognition(
            processor_path=TROCR_PROCESSOR_PATH,
            model_path=TROCR_MODEL_PATH,
            device='cpu',
            half_precision=True,
            line_threshold=100,
        )
    except Exception as err:
        print('Failed to initialize TextRecognition class: %s' % err)
        return None
53
+
54
# Shared objects, created once when the module is imported and reused by
# every call of the click handler.  Either of the first two is None when its
# initializer failed (the failure is printed by the initializer).
segmenter = get_segmenter()
recognizer = get_recognizer()
plotter = PlotHTR()

# Markdown legend shown beneath the plotted images: one colour swatch per
# text region type.
color_codes = """**Text region type:** <br>
Paragraph ![#EE1289](https://placehold.co/15x15/EE1289/EE1289.png)
Marginalia ![#00C957](https://placehold.co/15x15/00C957/00C957.png)
Page number ![#0000FF](https://placehold.co/15x15/0000FF/0000FF.png)"""
62
+
63
def merge_lines(segment_predictions):
    """Flatten the per-region line lists into a single list.

    Args:
        segment_predictions: Iterable of region dicts, each holding its
            detected lines under the ``'lines'`` key.

    Returns:
        A new list with the lines of every region, in region order.
    """
    return [
        line
        for region in segment_predictions
        for line in region['lines']
    ]
68
+
69
def get_text_predictions(image, segment_predictions, recognizer):
    """Run text recognition over all detected lines of one image.

    Args:
        image: Image data, passed through unchanged to
            ``recognizer.process_lines``.
        segment_predictions: Sequence of region dicts.  Every dict must
            provide its lines under ``'lines'``; the first one must also
            provide ``'img_shape'`` as a ``(height, width)`` pair.
        recognizer: Object exposing
            ``process_lines(lines, image, height, width)``.

    Returns:
        The result of ``recognizer.process_lines`` for the merged lines, or
        an empty list when ``segment_predictions`` is empty or ``None``.
    """
    # Without regions there is no first element to read the image shape
    # from; return early instead of failing with an IndexError.
    if not segment_predictions:
        return []

    img_lines = merge_lines(segment_predictions)
    # The image shape is read from the first region only.
    height, width = segment_predictions[0]['img_shape']
    # Process all lines of the image in one call.
    return recognizer.process_lines(img_lines, image, height, width)
76
+
77
# Run demo code
with gr.Blocks(theme=gr.themes.Monochrome(), title="HTR demo") as demo:
    gr.Markdown("# HTR demo")
    # NOTE(review): the nesting of the widgets below (button and timing label
    # under the tab rather than inside the row) is reconstructed — confirm
    # against the intended layout.
    with gr.Tab("Text content"):
        with gr.Row():
            input_img = gr.Image(label="Input image", type="pil")
            textbox = gr.Textbox(label="Predicted text content", lines=10)
        button = gr.Button("Process image")
        processing_time = gr.Markdown()
    with gr.Tab("Text regions"):
        region_img = gr.Image(label="Predicted text regions", type="numpy")
        gr.Markdown(color_codes)
    with gr.Tab("Text lines"):
        line_img = gr.Image(label="Predicted text lines", type="numpy")
        gr.Markdown(color_codes)

    def run_pipeline(image):
        """Segment the image, recognize its text and build the UI outputs.

        Args:
            image: The image from ``input_img``, or ``None`` when the button
                is pressed before an image has been provided.

        Returns:
            A dict keyed by output component: the region plot, the line
            plot, the recognized text joined with newlines, and a
            "Processing time" string.  The first three are ``None`` when
            there is no image or no segments were found.

        Raises:
            gr.Error: If the segmenter or the recognizer failed to
                initialize (their initializers return ``None`` on failure).
        """
        # Surface a readable message instead of an AttributeError on None.
        if segmenter is None or recognizer is None:
            raise gr.Error("Models failed to load; check the server log.")

        # perf_counter is monotonic, so clock adjustments cannot skew the
        # reported duration.
        start = time.perf_counter()
        region_plot = line_plot = text = None

        # Predict region and line segments (skipped when no image was given).
        segment_predictions = (
            segmenter.get_segmentation(image) if image is not None else None
        )
        if segment_predictions:
            region_plot = plotter.plot_regions(segment_predictions, image)
            line_plot = plotter.plot_lines(segment_predictions, image)
            text_predictions = get_text_predictions(
                np.array(image), segment_predictions, recognizer
            )
            text = "\n".join(text_predictions)

        proc_time = time.perf_counter() - start
        return {
            region_img: region_plot,
            line_img: line_plot,
            textbox: text,
            processing_time: f"Processing time: {proc_time:.4f}s",
        }

    button.click(
        fn=run_pipeline,
        inputs=input_img,
        outputs=[region_img, line_img, textbox, processing_time],
    )

if __name__ == "__main__":
    demo.launch()