File size: 5,060 Bytes
970b7b6
 
29fbcb4
970b7b6
 
 
 
 
29fbcb4
970b7b6
 
 
 
 
 
 
 
cf00bcc
 
970b7b6
cf00bcc
970b7b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c01649
 
970b7b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from optimum.onnxruntime import ORTModelForVision2Seq
from transformers import TrOCRProcessor
from huggingface_hub import login
from ultralytics import YOLO
import gradio as gr
import numpy as np
import onnxruntime
import time
import os

from plotting_functions import PlotHTR
from segment_image import SegmentImage
from onnx_text_recognition import TextRecognition


# Hugging Face Hub repo IDs for the YOLO segmentation models.
LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
REGION_MODEL_PATH = "Kansallisarkisto/court-records-region-detection"
# Repo paths (repo id + subfolder) for the TrOCR processor and the ONNX-exported model.
TROCR_PROCESSOR_PATH = "Kansallisarkisto/multicentury-htr-model-onnx/202405_processor"
TROCR_MODEL_PATH = "Kansallisarkisto/multicentury-htr-model-onnx/202405_onnx"

# Authenticate against the Hugging Face Hub so the model repos can be downloaded.
# NOTE(review): assumes the HF_TOKEN environment variable is set; login() receives
# None otherwise and will fail at startup -- confirm deployment sets it.
login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)

def get_segmenter():
    """Construct the SegmentImage instance used for page layout segmentation.

    Returns:
        A configured SegmentImage on success, or None if construction fails
        (the error is printed, matching the module's best-effort init style).
    """
    # Keep all tuning knobs in one mapping so they are easy to scan/compare.
    config = {
        'line_model_path': LINE_MODEL_PATH,
        'region_model_path': REGION_MODEL_PATH,
        'device': 'cpu',
        'line_iou': 0.3,
        'region_iou': 0.5,
        'line_overlap': 0.5,
        'line_nms_iou': 0.7,
        'region_nms_iou': 0.3,
        'line_conf_threshold': 0.25,
        'region_conf_threshold': 0.5,
        'order_regions': True,
        'region_half_precision': False,
        'line_half_precision': False,
    }
    try:
        return SegmentImage(**config)
    except Exception as e:
        print('Failed to initialize SegmentImage class: %s' % e)

def get_recognizer():
    """Construct the TextRecognition instance for TrOCR ONNX inference.

    Returns:
        A configured TextRecognition on success, or None if construction fails
        (the error is printed, matching the module's best-effort init style).
    """
    try:
        return TextRecognition(
            processor_path=TROCR_PROCESSOR_PATH,
            model_path=TROCR_MODEL_PATH,
            device='cpu',
            half_precision=True,
            line_threshold=100,
        )
    except Exception as e:
        print('Failed to initialize TextRecognition class: %s' % e)

# Module-level singletons: the heavy models are loaded once at startup and
# shared across all requests. NOTE(review): either may be None if init failed
# (the factories print and return None) -- the pipeline would then crash later.
segmenter = get_segmenter()
recognizer = get_recognizer()
plotter = PlotHTR()

# Markdown legend mapping overlay colors to region types, shown under the plots.
color_codes = """**Text region type:** <br>
        Paragraph ![#EE1289](https://placehold.co/15x15/EE1289/EE1289.png) 
        Marginalia ![#00C957](https://placehold.co/15x15/00C957/00C957.png) 
        Page number ![#0000FF](https://placehold.co/15x15/0000FF/0000FF.png)"""

def merge_lines(segment_predictions):
    """Flatten the per-region 'lines' lists into one list covering the page."""
    return [line for region in segment_predictions for line in region['lines']]

def get_text_predictions(image, segment_predictions, recognizer):
    """Run text recognition over every detected line of the page.

    Args:
        image: page image as passed to the recognizer (numpy array at call site).
        segment_predictions: region dicts from the segmenter; each carries a
            'lines' list, and the first one carries the page 'img_shape'.
        recognizer: TextRecognition instance providing process_lines().

    Returns:
        Whatever recognizer.process_lines returns for all lines of the image.
    """
    all_lines = merge_lines(segment_predictions)
    # Page dimensions come from the first region's metadata.
    height, width = segment_predictions[0]['img_shape']
    return recognizer.process_lines(all_lines, image, height, width)

# Run demo code
# Gradio UI: three tabs (recognized text, region overlay, line overlay) fed by
# a single button-triggered pipeline.
with gr.Blocks(theme=gr.themes.Monochrome(), title="Multicentury HTR Demo") as demo:
    gr.Markdown("# Multicentury HTR Demo")
    with gr.Tab("Text content"):
        with gr.Row():
            input_img = gr.Image(label="Input image", type="pil")
            textbox = gr.Textbox(label="Predicted text content", lines=10)
        button = gr.Button("Process image")
        processing_time = gr.Markdown()
    with gr.Tab("Text regions"):
        region_img = gr.Image(label="Predicted text regions", type="numpy")
        gr.Markdown(color_codes)
    with gr.Tab("Text lines"):
        line_img = gr.Image(label="Predicted text lines", type="numpy")
        gr.Markdown(color_codes)
        
    def run_pipeline(image):
        """Segment the page, recognize its text, and fill all output widgets.

        Returns a dict mapping output components to their new values; all
        image/text outputs are None when segmentation finds nothing.
        """
        # Predict region and line segments
        start = time.time()
        segment_predictions = segmenter.get_segmentation(image)
        if segment_predictions:
            # Overlay plots reuse the PIL image; recognition needs a numpy array.
            region_plot = plotter.plot_regions(segment_predictions, image)
            line_plot = plotter.plot_lines(segment_predictions, image)
            text_predictions = get_text_predictions(np.array(image), segment_predictions, recognizer)
            text = "\n".join(text_predictions)
            end = time.time()
            proc_time = end - start
            proc_time_str = f"Processing time: {proc_time:.4f}s"
            return {
                region_img: region_plot, 
                line_img: line_plot, 
                textbox: text,
                processing_time: proc_time_str
                }
        else:
            # No regions/lines detected: clear outputs but still report timing.
            end = time.time()
            proc_time = end - start
            proc_time_str = f"Processing time: {proc_time:.4f}s"
            return {
                region_img: None, 
                line_img: None, 
                textbox: None,
                processing_time: proc_time_str
                }


    button.click(fn=run_pipeline, 
                 inputs=input_img, 
                 outputs=[region_img, line_img, textbox, processing_time])

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    demo.launch()