|
from optimum.onnxruntime import ORTModelForVision2Seq |
|
from transformers import TrOCRProcessor |
|
from huggingface_hub import login |
|
import gradio as gr |
|
import numpy as np |
|
import onnxruntime |
|
import torch |
|
import time |
|
import os |
|
|
|
from plotting_functions import PlotHTR |
|
from segment_image import SegmentImage |
|
from onnx_text_recognition import TextRecognition |
|
|
|
|
|
# Hugging Face Hub repositories for the three pipeline stages.
LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
REGION_MODEL_PATH = "Kansallisarkisto/court-records-region-detection"
TROCR_PROCESSOR_PATH = "Kansallisarkisto/multicentury-htr-model-onnx"
TROCR_MODEL_PATH = "Kansallisarkisto/multicentury-htr-model-onnx"

# Authenticate only when a token is configured: login(token=None) falls back to
# an interactive prompt, which cannot be answered on a headless server.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token, add_to_git_credential=True)
else:
    print("HF_TOKEN is not set; skipping Hugging Face Hub login.")

print(f"Is CUDA available: {torch.cuda.is_available()}")
# Querying the current device raises when CUDA is unavailable, so only report
# the device name when there is one.
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
|
|
def get_segmenter(device='cuda:0'):
    """Initialize the text region and text line segmentation class.

    Args:
        device: Device string forwarded to SegmentImage. Defaults to 'cuda:0',
            the value this demo was previously hard-coded to use.

    Returns:
        A SegmentImage instance, or None if construction raised. The error is
        printed rather than propagated so the UI can still be built; callers
        must check for None before use.
    """
    try:
        return SegmentImage(
            line_model_path=LINE_MODEL_PATH,
            device=device,
            line_iou=0.3,
            region_iou=0.5,
            line_overlap=0.5,
            line_nms_iou=0.7,
            region_nms_iou=0.3,
            line_conf_threshold=0.25,
            region_conf_threshold=0.5,
            region_model_path=REGION_MODEL_PATH,
            order_regions=True,
            region_half_precision=False,
            line_half_precision=False,
        )
    except Exception as e:
        # Best-effort startup: report the failure and make the None explicit.
        print(f'Failed to initialize SegmentImage class: {e}')
        return None
|
|
|
def get_recognizer(device='cuda:0'):
    """Initialize the text recognition class.

    Args:
        device: Device string forwarded to TextRecognition. Defaults to
            'cuda:0', the value this demo was previously hard-coded to use.

    Returns:
        A TextRecognition instance, or None if construction raised. The error is
        printed rather than propagated so the UI can still be built; callers
        must check for None before use.
    """
    try:
        return TextRecognition(
            processor_path=TROCR_PROCESSOR_PATH,
            model_path=TROCR_MODEL_PATH,
            device=device,
            half_precision=True,
            line_threshold=10,
        )
    except Exception as e:
        # Best-effort startup: report the failure and make the None explicit.
        print(f'Failed to initialize TextRecognition class: {e}')
        return None
|
|
|
# Module-level model/plotting objects used by run_pipeline.
# get_segmenter() and get_recognizer() end up returning None when construction
# raises (their except branches only print), so these may be None.
segmenter = get_segmenter()
recognizer = get_recognizer()
plotter = PlotHTR()

# Markdown legend rendered under the region and line plots.
# NOTE(review): the legend names the region types but no colour swatches are
# visible here — confirm it matches the colours PlotHTR draws.
color_codes = """**Text region type:** <br>
Paragraph 
Marginalia 
Page number """
|
|
|
def merge_lines(segment_predictions):
    """Flatten the ``'lines'`` of every region into one list.

    Regions are visited in the given order and each region's lines keep their
    own order, so the result is the concatenation of all per-region lists.
    """
    return [line for region in segment_predictions for line in region['lines']]
|
|
|
def get_text_predictions(image, segment_predictions, recognizer):
    """Recognize the text of every detected line in *image*.

    Lines from all regions are flattened (region order, then line order within
    each region) and handed to ``recognizer.process_lines`` together with the
    image height and width read from the first region's ``'img_shape'``.

    Args:
        image: The image the segmentation was computed on (run_pipeline passes
            a NumPy array built from the uploaded image).
        segment_predictions: Sequence of region dicts with ``'lines'`` and
            ``'img_shape'`` keys.
        recognizer: Object exposing ``process_lines(lines, image, height, width)``.

    Returns:
        Whatever ``recognizer.process_lines`` returns (the UI joins it with
        newlines, so presumably one string per line — TODO confirm), or an
        empty list when there are no regions.
    """
    # The image shape is taken from the first region, which does not exist when
    # nothing was detected; report "no text" instead of raising IndexError.
    if not segment_predictions:
        return []
    img_lines = merge_lines(segment_predictions)
    height, width = segment_predictions[0]['img_shape']
    return recognizer.process_lines(img_lines, image, height, width)
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Monochrome(), title="Multicentury HTR Demo") as demo:
    gr.Markdown("# Multicentury HTR Demo")
    gr.Markdown("""The HTR pipeline contains three components: text region detection, textline detection and handwritten text recognition.
The components run machine learning models that have been trained at the National Archives of Finland using mostly handwritten documents
from the 17th, 18th, 19th and 20th centuries.

The input image can be uploaded using the *Input image* window in the *Text content* tab. Processing is started with the *Process image* button,
and the predicted text content will appear in the window on the right side of the image. Results of text region and text line detection can be viewed in the *Text regions* and *Text lines* tabs.
Best results are obtained when using high quality scans of documents with a regular layout.

Please note that this is a demo. 24/7 functionality is not guaranteed.

# Monen vuosisadan käsialantunnistusmalli

Käsialantunnistusputkessa on kolme mallia: tekstialueen tunnistus, tekstirivien tunnistus ja tekstintunnistus. Mallit on koulutettu pääosin
käsinkirjoitetulla Kansallisarkiston aineistolla, joka ajoittuu 1600-luvulta 1900-luvulle.

Tunnistettavan kuvan voi ladata *Input image* -nimiseen laatikkoon *Text content* -välilehdellä. Prosessointi käynnistetään *Process image* -painikkeesta,
ja kun kuva on prosessoitu, tunnistettu teksti ilmaantuu oikeanpuoleiseen laatikkoon nimeltä *Predicted text content*. Tekstialueen ja
tekstirivien tunnistuksia voi tarkastella *Text regions* ja *Text lines* -välilehdiltä. Parhaimman lopputuloksen saa hyvälaatuisilla kuvilla,
joissa on normaalin kirjan mukainen taitto.

Huom! Tämä on demosovellus. Ympärivuorokautista toimivuutta ei luvata.
""")

    with gr.Tab("Text content"):
        with gr.Row():
            input_img = gr.Image(label="Input image", type="pil")
            textbox = gr.Textbox(label="Predicted text content", lines=10)
        button = gr.Button("Process image")
        processing_time = gr.Markdown()

    with gr.Tab("Text regions"):
        region_img = gr.Image(label="Predicted text regions", type="numpy")
        gr.Markdown(color_codes)

    with gr.Tab("Text lines"):
        line_img = gr.Image(label="Predicted text lines", type="numpy")
        gr.Markdown(color_codes)

    def run_pipeline(image):
        """Segment and transcribe *image*, returning updates for the output widgets.

        Args:
            image: The uploaded image from the ``type="pil"`` input component, or
                None when the button is clicked before anything is uploaded.

        Returns:
            A dict mapping the region plot, line plot, text box and
            processing-time components to their new values. When segmentation
            finds nothing, the plots and text are None and only the processing
            time is reported.

        Raises:
            gr.Error: If no image was provided or a model failed to load at
                startup (shown to the user instead of an opaque traceback).
        """
        if image is None:
            raise gr.Error("Please upload an input image before processing.")
        # get_segmenter()/get_recognizer() return None when model loading failed.
        if segmenter is None or recognizer is None:
            raise gr.Error("The HTR models failed to load; see the server log for details.")

        # perf_counter is monotonic, so the reported duration is not affected by
        # wall-clock adjustments.
        start = time.perf_counter()
        region_plot = line_plot = text = None

        segment_predictions = segmenter.get_segmentation(image)
        print('segmentation ok')
        if segment_predictions:
            region_plot = plotter.plot_regions(segment_predictions, image)
            line_plot = plotter.plot_lines(segment_predictions, image)
            # Recognition receives a NumPy copy of the PIL input image.
            text_predictions = get_text_predictions(np.array(image), segment_predictions, recognizer)
            print('text pred ok')
            text = "\n".join(text_predictions)

        proc_time = time.perf_counter() - start
        return {
            region_img: region_plot,
            line_img: line_plot,
            textbox: text,
            processing_time: f"Processing time: {proc_time:.4f}s",
        }

    button.click(fn=run_pipeline,
                 inputs=input_img,
                 outputs=[region_img, line_img, textbox, processing_time])
|
|
|
if __name__ == "__main__":
    # Enable request queueing and start the server in one chained call
    # (Blocks.queue returns the Blocks instance). show_error=True surfaces
    # handler exceptions in the browser.
    demo.queue().launch(show_error=True)
|
|