import os
import time

import gradio as gr
import numpy as np
import torch
from huggingface_hub import login

from plotting_functions import PlotHTR
from segment_image import SegmentImage
from onnx_text_recognition import TextRecognition
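# Local helper modules shipped alongside this demo: overlay plotting,
# region/line segmentation, and ONNX-based text recognition.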


LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
REGION_MODEL_PATH = "Kansallisarkisto/court-records-region-detection"
TROCR_PROCESSOR_PATH = "Kansallisarkisto/multicentury-htr-model-onnx"
TROCR_MODEL_PATH = "Kansallisarkisto/multicentury-htr-model-onnx"
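# Hugging Face Hub repo IDs for the segmentation and recognition models.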

login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)
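# Assumes an HF_TOKEN environment variable is set (e.g. as a Space secret);
# the login is needed if the model repositories require authentication.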

print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

def get_segmenter():
    """Initialize segmentation class."""
    try:
        segmenter = SegmentImage(
            line_model_path=LINE_MODEL_PATH,
            region_model_path=REGION_MODEL_PATH,
            device='cuda:0',
            line_iou=0.3,
            region_iou=0.5,
            line_overlap=0.5,
            line_nms_iou=0.7,
            region_nms_iou=0.3,
            line_conf_threshold=0.25,
            region_conf_threshold=0.5,
            order_regions=True,
            region_half_precision=False,
            line_half_precision=False,
        )
        return segmenter
    except Exception as e:
        print(f'Failed to initialize SegmentImage class: {e}')

def get_recognizer():
    """Initialize text recognition class."""
    try:
        recognizer = TextRecognition(
            processor_path=TROCR_PROCESSOR_PATH,
            model_path=TROCR_MODEL_PATH,
            device='cuda:0',
            half_precision=True,
            line_threshold=10,
        )
        return recognizer
    except Exception as e:
        print(f'Failed to initialize TextRecognition class: {e}')

segmenter = get_segmenter()
recognizer = get_recognizer()
plotter = PlotHTR()
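# The models are instantiated once at startup, so each request reuses the
# loaded weights instead of reloading them per call.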

color_codes = """**Text region type:** <br>
        Paragraph ![#EE1289](https://placehold.co/15x15/EE1289/EE1289.png) 
        Marginalia ![#00C957](https://placehold.co/15x15/00C957/00C957.png) 
        Page number ![#0000FF](https://placehold.co/15x15/0000FF/0000FF.png)"""
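# The placehold.co images render as 15x15 color swatches in the Markdown
# legend; the colors are assumed to match the region overlay colors of PlotHTR.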

def merge_lines(segment_predictions):
    """Flatten the detected text lines of all regions into a single list."""
    img_lines = []
    for region in segment_predictions:
        img_lines += region['lines']
    return img_lines

def get_text_predictions(image, segment_predictions, recognizer):
    """Collects text prediction data into dicts based on detected text regions."""
    img_lines = merge_lines(segment_predictions)
    height, width = segment_predictions[0]['img_shape']
    # Process all lines of an image
    texts = recognizer.process_lines(img_lines, image, height, width)
    return texts
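# Sketch of the segmentation output assumed above (inferred from how it is
# used here, not from the SegmentImage source): a list of per-region dicts
# along the lines of
#   {'lines': [<detected line geometries>], 'img_shape': (height, width), ...}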

# Build and wire up the Gradio demo UI
with gr.Blocks(theme=gr.themes.Monochrome(), title="Multicentury HTR Demo") as demo:
    gr.Markdown("# Multicentury HTR Demo")
    gr.Markdown("""The HTR pipeline contains three components: text region detection, textline detection and handwritten text recognition.
    The components run machine learning models that have been trained at the National Archives of Finland using mostly handwritten documents
    from 17th, 18th, 19th and 20th centuries. 
    
    Input image can be uploaded using the *Input image* window in the *Text content* tab, and the predicted text content will appear to the window
    on the right side of the image. Results of text region and text line detection can be viewed in the *Text regions* and *Text lines* tabs.
    Best results are obtained when using high quality scans of documents with a regular layout.

    Please note that this is a demo. 24/7 functionality is not quaranteed.

    # Multicentury Handwriting Recognition Model

    The handwriting recognition pipeline contains three models: text region detection, text line detection and text recognition. The models
    were trained mainly on handwritten material from the National Archives of Finland dating from the 17th to the 20th century.

    The image to be recognized can be uploaded to the *Input image* box in the *Text content* tab. Processing is started with the
    *Process image* button, and once the image has been processed, the recognized text appears in the box on the right named
    *Predicted text content*. The text region and text line detections can be viewed in the *Text regions* and *Text lines* tabs.
    The best results are obtained with high-quality images that have a regular, book-like layout.

    Note that this is a demo application. Round-the-clock availability is not guaranteed.
    """)
    
    with gr.Tab("Text content"):
        with gr.Row():
            input_img = gr.Image(label="Input image", type="pil")
            textbox = gr.Textbox(label="Predicted text content", lines=10)
        button = gr.Button("Process image")
        processing_time = gr.Markdown()
    with gr.Tab("Text regions"):
        region_img = gr.Image(label="Predicted text regions", type="numpy")
        gr.Markdown(color_codes)
    with gr.Tab("Text lines"):
        line_img = gr.Image(label="Predicted text lines", type="numpy")
        gr.Markdown(color_codes)
        
    def run_pipeline(image):
        """Segment the image, recognize the text of the detected lines, and time the run."""
        start = time.time()
        # Predict region and line segments
        segment_predictions = segmenter.get_segmentation(image)
        region_plot, line_plot, text = None, None, None
        if segment_predictions:
            print('Segmentation OK')
            region_plot = plotter.plot_regions(segment_predictions, image)
            line_plot = plotter.plot_lines(segment_predictions, image)
            text_predictions = get_text_predictions(np.array(image), segment_predictions, recognizer)
            print('Text prediction OK')
            text = "\n".join(text_predictions)
        proc_time_str = f"Processing time: {time.time() - start:.4f}s"
        return {
            region_img: region_plot,
            line_img: line_plot,
            textbox: text,
            processing_time: proc_time_str
        }


    button.click(fn=run_pipeline, 
                 inputs=input_img, 
                 outputs=[region_img, line_img, textbox, processing_time])
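    # run_pipeline returns a dict keyed by output components; Gradio matches
    # the keys against the components listed in `outputs`.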

if __name__ == "__main__":
    # Queue incoming requests so concurrent users are served in order.
    demo.queue()
    # show_error=True surfaces Python exceptions in the browser UI.
    demo.launch(show_error=True)