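"""Gradio demo app for multicentury handwritten text recognition (HTR).

The app segments an uploaded page image into text regions and lines,
recognizes the text content of each line with an ONNX TrOCR model,
and shows the predictions (text, regions, lines) in separate tabs.
"""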
from optimum.onnxruntime import ORTModelForVision2Seq
from transformers import TrOCRProcessor
from huggingface_hub import login
from ultralytics import YOLO
import gradio as gr
import numpy as np
import onnxruntime
import time
import os
from plotting_functions import PlotHTR
from segment_image import SegmentImage
from onnx_text_recognition import TextRecognition
LINE_MODEL_PATH = "Kansallisarkisto/multicentury-textline-detection"
REGION_MODEL_PATH = "Kansallisarkisto/court-records-region-detection"
TROCR_PROCESSOR_PATH = "Kansallisarkisto/multicentury-htr-model-onnx/202405_processor"
TROCR_MODEL_PATH = "Kansallisarkisto/multicentury-htr-model-onnx/202405_onnx"
login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)
def get_segmenter():
    """Initialize segmentation class."""
    try:
        segmenter = SegmentImage(line_model_path=LINE_MODEL_PATH,
                                 device='cpu',
                                 line_iou=0.3,
                                 region_iou=0.5,
                                 line_overlap=0.5,
                                 line_nms_iou=0.7,
                                 region_nms_iou=0.3,
                                 line_conf_threshold=0.25,
                                 region_conf_threshold=0.5,
                                 region_model_path=REGION_MODEL_PATH,
                                 order_regions=True,
                                 region_half_precision=False,
                                 line_half_precision=False)
        return segmenter
    except Exception as e:
        print('Failed to initialize SegmentImage class: %s' % e)

def get_recognizer():
    """Initialize text recognition class."""
    try:
        recognizer = TextRecognition(
            processor_path=TROCR_PROCESSOR_PATH,
            model_path=TROCR_MODEL_PATH,
            device='cpu',
            half_precision=True,
            line_threshold=100
        )
        return recognizer
    except Exception as e:
        print('Failed to initialize TextRecognition class: %s' % e)

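# Create the segmentation, recognition and plotting components once at startup
# so they are shared by all requests.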
segmenter = get_segmenter()
recognizer = get_recognizer()
plotter = PlotHTR()
color_codes = """**Text region type:** <br>
Paragraph 
Marginalia 
Page number """
def merge_lines(segment_predictions):
    """Combine the text lines of all detected regions into a single list."""
    img_lines = []
    for region in segment_predictions:
        img_lines += region['lines']
    return img_lines

def get_text_predictions(image, segment_predictions, recognizer):
    """Collects text prediction data into dicts based on detected text regions."""
    img_lines = merge_lines(segment_predictions)
    height, width = segment_predictions[0]['img_shape']
    # Process all lines of an image
    texts = recognizer.process_lines(img_lines, image, height, width)
    return texts

# Run demo code
with gr.Blocks(theme=gr.themes.Monochrome(), title="Multicentury HTR Demo") as demo:
    gr.Markdown("# Multicentury HTR Demo")
    with gr.Tab("Text content"):
        with gr.Row():
            input_img = gr.Image(label="Input image", type="pil")
            textbox = gr.Textbox(label="Predicted text content", lines=10)
        button = gr.Button("Process image")
        processing_time = gr.Markdown()
    with gr.Tab("Text regions"):
        region_img = gr.Image(label="Predicted text regions", type="numpy")
        gr.Markdown(color_codes)
    with gr.Tab("Text lines"):
        line_img = gr.Image(label="Predicted text lines", type="numpy")
        gr.Markdown(color_codes)
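    # Full pipeline: segment the page, plot the detected regions and lines,
    # and recognize the text content of each line, reporting the elapsed time.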
    def run_pipeline(image):
        # Predict region and line segments
        start = time.time()
        segment_predictions = segmenter.get_segmentation(image)
        if segment_predictions:
            region_plot = plotter.plot_regions(segment_predictions, image)
            line_plot = plotter.plot_lines(segment_predictions, image)
            text_predictions = get_text_predictions(np.array(image), segment_predictions, recognizer)
            text = "\n".join(text_predictions)
            end = time.time()
            proc_time = end - start
            proc_time_str = f"Processing time: {proc_time:.4f}s"
            return {
                region_img: region_plot,
                line_img: line_plot,
                textbox: text,
                processing_time: proc_time_str
            }
        else:
            end = time.time()
            proc_time = end - start
            proc_time_str = f"Processing time: {proc_time:.4f}s"
            return {
                region_img: None,
                line_img: None,
                textbox: None,
                processing_time: proc_time_str
            }

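    # Wire the button to the pipeline; the returned dict updates all three tabs
    # and the processing time text.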
    button.click(fn=run_pipeline,
                 inputs=input_img,
                 outputs=[region_img, line_img, textbox, processing_time])
if __name__ == "__main__":
    demo.launch()