File size: 6,936 Bytes
970b7b6 29fbcb4 970b7b6 07944e2 970b7b6 29fbcb4 970b7b6 3ce9f40 970b7b6 cf00bcc 970b7b6 07944e2 970b7b6 a4c9237 970b7b6 7926f67 970b7b6 a9f6f8b 970b7b6 7c01649 3867b35 ed3c65a 3867b35 ed3c65a ee247bb ed3c65a 970b7b6 de53dd1 970b7b6 de53dd1 970b7b6 6dc468e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from optimum.onnxruntime import ORTModelForVision2Seq
from transformers import TrOCRProcessor
from huggingface_hub import login
import gradio as gr
import numpy as np
import onnxruntime
import torch
import time
import os
from plotting_functions import PlotHTR
from segment_image import SegmentImage
from onnx_text_recognition import TextRecognition
LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
REGION_MODEL_PATH = "Kansallisarkisto/court-records-region-detection"
TROCR_PROCESSOR_PATH = "Kansallisarkisto/multicentury-htr-model-onnx"
TROCR_MODEL_PATH = "Kansallisarkisto/multicentury-htr-model-onnx"
login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)
print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
def get_segmenter():
    """Create the text region and text line segmentation model wrapper.

    Returns:
        A configured ``SegmentImage`` instance, or ``None`` when construction
        raised; in that case the exception is printed instead of propagated.
    """
    try:
        # NOTE(review): the device is hard-coded to the first CUDA GPU.
        return SegmentImage(
            line_model_path=LINE_MODEL_PATH,
            device='cuda:0',
            line_iou=0.3,
            region_iou=0.5,
            line_overlap=0.5,
            line_nms_iou=0.7,
            region_nms_iou=0.3,
            line_conf_threshold=0.25,
            region_conf_threshold=0.5,
            region_model_path=REGION_MODEL_PATH,
            order_regions=True,
            region_half_precision=False,
            line_half_precision=False,
        )
    except Exception as err:
        # Best effort: report the failure and hand the caller None.
        print('Failed to initialize SegmentImage class: %s' % err)
        return None
def get_recognizer():
    """Create the ONNX TrOCR text recognition wrapper.

    Returns:
        A configured ``TextRecognition`` instance, or ``None`` when
        construction raised; in that case the exception is printed instead
        of propagated.
    """
    try:
        # NOTE(review): runs on the first CUDA GPU with half precision enabled.
        return TextRecognition(
            processor_path=TROCR_PROCESSOR_PATH,
            model_path=TROCR_MODEL_PATH,
            device='cuda:0',
            half_precision=True,
            line_threshold=10,
        )
    except Exception as err:
        # Best effort: report the failure and hand the caller None.
        print('Failed to initialize TextRecognition class: %s' % err)
        return None
# Build the model wrappers once at import time; run_pipeline reuses these
# module-level objects for every request. segmenter / recognizer are None
# if their initialisation failed (the error has already been printed).
segmenter, recognizer, plotter = get_segmenter(), get_recognizer(), PlotHTR()

# Legend shown beneath the region and line plots.
# NOTE(review): it names the region types but shows no colours; swatch markup
# may have been lost at some point — confirm against the plotting colours.
color_codes = (
    "**Text region type:** <br>\n"
    "Paragraph \n"
    "Marginalia \n"
    "Page number "
)
def merge_lines(segment_predictions):
    """Flatten the per-region line collections into one list.

    Args:
        segment_predictions: iterable of region mappings, each holding an
            iterable under the ``'lines'`` key.

    Returns:
        A new list containing every region's lines, in region order.
    """
    return [line for region in segment_predictions for line in region['lines']]
def get_text_predictions(image, segment_predictions, recognizer):
    """Recognise the text of every detected line on a page.

    Args:
        image: page image handed unchanged to ``recognizer.process_lines``
            (run_pipeline passes a NumPy array).
        segment_predictions: non-empty list of region mappings; every entry
            provides ``'lines'`` and the first one provides ``'img_shape'``.
        recognizer: object exposing
            ``process_lines(lines, image, height, width)``.

    Returns:
        Whatever ``recognizer.process_lines`` returns for the merged lines.
    """
    all_lines = merge_lines(segment_predictions)
    # 'img_shape' is unpacked as (height, width) — TODO confirm the producer
    # stores it in that order.
    page_height, page_width = segment_predictions[0]['img_shape']
    return recognizer.process_lines(all_lines, image, page_height, page_width)
# Introductory text shown at the top of the demo page (English, then Finnish).
DEMO_DESCRIPTION = """The HTR pipeline contains three components: text region detection, textline detection and handwritten text recognition.
The components run machine learning models that have been trained at the National Archives of Finland using mostly handwritten documents
from 17th, 18th, 19th and 20th centuries.
Input image can be uploaded using the *Input image* window in the *Text content* tab, and the predicted text content will appear to the window
on the right side of the image. Results of text region and text line detection can be viewed in the *Text regions* and *Text lines* tabs.
Best results are obtained when using high quality scans of documents with a regular layout.
Please note that this is a demo. 24/7 functionality is not guaranteed.
# Monen vuosisadan käsialantunnistus malli
Käsialantunnistus putkessa on kolme mallia: Tekstialueen tunnistus, tekstirivien tunnistus ja tekstintunnistus. Mallit on koulutettu pääosin
käsinkirjoitetulla Kansallisarkiston aineistolla, joka ajoittuu 1600-luvulta 1900-luvulle.
Tunnistettavan kuvan voi ladata *Input image* nimiseen laatikkoon *Text content* välilehdellä. Prosessointi käynnistetään *Process image*
painikkeesta ja kuva on prosessoitu tunnistettu teksti ilmaantuu oikeaan laatikkoon nimeltä *Predicted text content*. Tekstialueen ja
tekstirivien tunnistuksia voi tarkastella *Text regions* ja *Text lines* välilehdiltä. Parhaimman lopputuloksen saa hyvälaatuisilla kuvilla,
joissa on normaalin kirjan mukainen taitto.
Huom! Tämä on demo sovellus. Ympärivuorokautista toimivuutta ei luvata.
"""

# Run demo code
with gr.Blocks(theme=gr.themes.Monochrome(), title="Multicentury HTR Demo") as demo:
    gr.Markdown("# Multicentury HTR Demo")
    gr.Markdown(DEMO_DESCRIPTION)

    with gr.Tab("Text content"):
        with gr.Row():
            input_img = gr.Image(label="Input image", type="pil")
            textbox = gr.Textbox(label="Predicted text content", lines=10)
        button = gr.Button("Process image")
        processing_time = gr.Markdown()
    with gr.Tab("Text regions"):
        region_img = gr.Image(label="Predicted text regions", type="numpy")
        gr.Markdown(color_codes)
    with gr.Tab("Text lines"):
        line_img = gr.Image(label="Predicted text lines", type="numpy")
        gr.Markdown(color_codes)

    def run_pipeline(image):
        """Segment one page image and recognise its text.

        Args:
            image: PIL image from the *Input image* component, or None when
                the user pressed the button without uploading anything.

        Returns:
            Dict mapping the four output components to their new values.
            The plots and text are None when segmentation found nothing.

        Raises:
            gr.Error: if no image was given or a model failed to load, so the
                user sees a clear message instead of an internal traceback.
        """
        if image is None:
            raise gr.Error("Please upload an input image before processing.")
        if segmenter is None or recognizer is None:
            raise gr.Error("The HTR models failed to load; see the server log for details.")

        # perf_counter is monotonic, so it is the right clock for durations.
        start = time.perf_counter()
        # Predict region and line segments
        segment_predictions = segmenter.get_segmentation(image)
        print('segmentation ok')

        region_plot = line_plot = text = None
        if segment_predictions:
            region_plot = plotter.plot_regions(segment_predictions, image)
            line_plot = plotter.plot_lines(segment_predictions, image)
            text_predictions = get_text_predictions(np.array(image), segment_predictions, recognizer)
            print('text pred ok')
            text = "\n".join(text_predictions)

        proc_time = time.perf_counter() - start
        return {
            region_img: region_plot,
            line_img: line_plot,
            textbox: text,
            processing_time: f"Processing time: {proc_time:.4f}s",
        }

    # Event listeners must be registered inside the Blocks context.
    button.click(fn=run_pipeline,
                 inputs=input_img,
                 outputs=[region_img, line_img, textbox, processing_time])
def main():
    """Enable request queuing on the demo and start the Gradio server."""
    demo.queue()
    demo.launch(show_error=True)


if __name__ == "__main__":
    main()
|