from optimum.onnxruntime import ORTModelForVision2Seq |
from transformers import TrOCRProcessor |
from huggingface_hub import login |
import gradio as gr |
import numpy as np |
import onnxruntime |
import torch |
import time |
import os |
from plotting_functions import PlotHTR |
from segment_image import SegmentImage |
from onnx_text_recognition import TextRecognition |
# Model identifiers handed to the segmentation and recognition classes below.
# NOTE(review): these look like Hugging Face Hub repo ids (`org/name`) - confirm
# against SegmentImage / TextRecognition.
LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
REGION_MODEL_PATH = "Kansallisarkisto/court-records-region-detection"
# The TrOCR processor and the ONNX model are read from the same identifier.
TROCR_PROCESSOR_PATH = "Kansallisarkisto/multicentury-htr-model-onnx"
TROCR_MODEL_PATH = "Kansallisarkisto/multicentury-htr-model-onnx"

# Authenticate with the token taken from the HF_TOKEN environment variable.
login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)

print(f"Is CUDA available: {torch.cuda.is_available()}")
# torch.cuda.current_device() raises when no CUDA device is usable, so the
# device name is only reported when CUDA is actually available. Without this
# guard the script dies here on CPU-only hosts.
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
def get_segmenter():
    """Create the SegmentImage instance used for region and line detection.

    The object is configured for device 'cuda:0' with the line/region model
    paths defined at module level and fixed IoU, NMS and confidence settings.

    Returns:
        The constructed SegmentImage object, or None if construction raised.
        A failure is printed, not propagated, so callers must be prepared
        for a None result.
    """
    try:
        return SegmentImage(
            # Text line detection settings.
            line_model_path=LINE_MODEL_PATH,
            line_iou=0.3,
            line_overlap=0.5,
            line_nms_iou=0.7,
            line_conf_threshold=0.25,
            line_half_precision=False,
            # Text region detection settings.
            region_model_path=REGION_MODEL_PATH,
            region_iou=0.5,
            region_nms_iou=0.3,
            region_conf_threshold=0.5,
            region_half_precision=False,
            order_regions=True,
            # Shared settings.
            device='cuda:0',
        )
    except Exception as err:
        print('Failed to initialize SegmentImage class: %s' % err)
        return None
def get_recognizer():
    """Create the TextRecognition instance used to read text lines.

    The object is configured for device 'cuda:0' with half precision enabled,
    a line threshold of 10, and the processor/model paths defined at module
    level.

    Returns:
        The constructed TextRecognition object, or None if construction
        raised. A failure is printed, not propagated, so callers must be
        prepared for a None result.
    """
    try:
        return TextRecognition(
            processor_path=TROCR_PROCESSOR_PATH,
            model_path=TROCR_MODEL_PATH,
            device='cuda:0',
            half_precision=True,
            line_threshold=10,
        )
    except Exception as err:
        print('Failed to initialize TextRecognition class: %s' % err)
        return None
# Pipeline components are created once at import time and shared by every
# request. get_segmenter() and get_recognizer() return None when their
# construction fails, so either of these names may be None.
segmenter = get_segmenter()
recognizer = get_recognizer()
plotter = PlotHTR()

# Markdown legend rendered under the region and line plots. It is assembled
# from adjacent literals so that the trailing space on each entry stays
# visible and is not lost to editors that strip trailing whitespace.
color_codes = (
    "**Text region type:** <br>\n"
    "Paragraph \n"
    "Marginalia \n"
    "Page number "
)
def merge_lines(segment_predictions):
    """Flatten the per-region line collections into a single list.

    Args:
        segment_predictions: Iterable of region dicts; each one must have a
            'lines' entry holding that region's lines.

    Returns:
        A new list containing every region's lines, kept in region order.
        The input is not modified.
    """
    return [
        line
        for region in segment_predictions
        for line in region['lines']
    ]
def get_text_predictions(image, segment_predictions, recognizer):
    """Run text recognition over every detected line of one image.

    Args:
        image: The page image, forwarded unchanged to the recognizer.
        segment_predictions: Sequence of region dicts. Each must have a
            'lines' entry, and the first one must have an 'img_shape' entry
            that unpacks to (height, width).
        recognizer: Object exposing
            process_lines(lines, image, height, width).

    Returns:
        Whatever recognizer.process_lines returns for the merged lines, or an
        empty list when there are no segment predictions. The caller in this
        file joins the result with newlines, so it is presumably a sequence
        of strings.
    """
    # With no regions there is no first element to read the image shape from
    # and nothing to recognize; indexing [0] would raise IndexError.
    if not segment_predictions:
        return []

    img_lines = merge_lines(segment_predictions)
    # The image shape is read from the first region only.
    height, width = segment_predictions[0]['img_shape']
    return recognizer.process_lines(img_lines, image, height, width)
with gr.Blocks(theme=gr.themes.Monochrome(), title="Multicentury HTR Demo") as demo:
    gr.Markdown("# Multicentury HTR Demo")
    # The body of this literal starts at column 0 on purpose: its text is
    # rendered as Markdown and is reproduced exactly as it appears in the file.
    gr.Markdown("""The HTR pipeline contains three components: text region detection, textline detection and handwritten text recognition.
The components run machine learning models that have been trained at the National Archives of Finland using mostly handwritten documents
from 17th, 18th, 19th and 20th centuries.
Input image can be uploaded using the *Input image* window in the *Text content* tab, and the predicted text content will appear to the window
on the right side of the image. Results of text region and text line detection can be viewed in the *Text regions* and *Text lines* tabs.
Best results are obtained when using high quality scans of documents with a regular layout.
Please note that this is a demo. 24/7 functionality is not quaranteed.
# Monen vuosisadan käsialantunnistus malli
Käsialantunnistus putkessa on kolme mallia: Tekstialueen tunnistus, tekstirivien tunnistus ja tekstintunnistus. Mallit on koulutettu pääosin
käsinkirjoitetulla Kansallisarkiston aineistolla, joka ajoittuu 1600-luvulta 1900-luvulle.
Tunnistettavan kuvan voi ladata *Input image* nimiseen laatikkoon *Text content* välilehdellä. Prosessointi käynnistetään *Process image*
painikkeesta ja kuva on prosessoitu tunnistettu teksti ilmaantuu oikeaan laatikkoon nimeltä *Predicted text content*. Tekstialueen ja
tekstirivien tunnistuksia voi tarkastella *Text regions* ja *Text lines* välilehdiltä. Parhaimman lopputuloksen saa hyvälaatuisilla kuvilla,
joissa on normaalin kirjan mukainen taitto.
Huom! Tämä on demo sovellus. Ympärivuorokautista toimivuutta ei luvata.
""")

    # NOTE(review): the nesting below (image and textbox side by side in a
    # Row, button and timing label beneath it, one plot per remaining tab) is
    # reconstructed from the intro text above - confirm against the intended
    # layout.
    with gr.Tab("Text content"):
        with gr.Row():
            input_img = gr.Image(label="Input image", type="pil")
            textbox = gr.Textbox(label="Predicted text content", lines=10)
        button = gr.Button("Process image")
        processing_time = gr.Markdown()
    with gr.Tab("Text regions"):
        region_img = gr.Image(label="Predicted text regions", type="numpy")
        gr.Markdown(color_codes)
    with gr.Tab("Text lines"):
        line_img = gr.Image(label="Predicted text lines", type="numpy")
        gr.Markdown(color_codes)

    def run_pipeline(image):
        """Segment the image, recognize its text and build the UI updates.

        Args:
            image: The uploaded image as delivered by the `input_img`
                component (declared with type="pil"), or None when nothing
                has been uploaded.

        Returns:
            A dict keyed by output component: the region plot, the line plot,
            the recognized text joined with newlines, and a
            "Processing time: ...s" message. The first three values are None
            when there is no image or segmentation produced no predictions.

        Raises:
            gr.Error: If the segmenter or recognizer failed to initialize.
        """
        # The initializers return None on failure; report that clearly
        # instead of failing with an AttributeError on None further down.
        if segmenter is None or recognizer is None:
            raise gr.Error("The HTR models failed to load; see the server log for details.")

        start = time.time()
        region_plot = line_plot = text = None

        # `image` is None when the button is pressed before an upload.
        if image is not None:
            segment_predictions = segmenter.get_segmentation(image)
            print('segmentation ok')
            if segment_predictions:
                region_plot = plotter.plot_regions(segment_predictions, image)
                line_plot = plotter.plot_lines(segment_predictions, image)
                # The recognizer is given the image as a NumPy array.
                text_predictions = get_text_predictions(
                    np.array(image), segment_predictions, recognizer
                )
                print('text pred ok')
                text = "\n".join(text_predictions)

        proc_time = time.time() - start
        return {
            region_img: region_plot,
            line_img: line_plot,
            textbox: text,
            processing_time: f"Processing time: {proc_time:.4f}s",
        }

    # Event listeners must be registered while the Blocks context is open.
    button.click(
        fn=run_pipeline,
        inputs=input_img,
        outputs=[region_img, line_img, textbox, processing_time],
    )
if __name__ == "__main__":
    # Enable the request queue, then start the server; show_error=True makes
    # exceptions raised by handlers visible in the browser UI.
    demo.queue().launch(show_error=True)