Spaces:
Running
on
T4
Running
on
T4
Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ from huggingface_hub import login
|
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
6 |
import onnxruntime
|
|
|
7 |
import time
|
8 |
import os
|
9 |
|
@@ -19,11 +20,14 @@ TROCR_MODEL_PATH = "Kansallisarkisto/multicentury-htr-model-onnx"
|
|
19 |
|
20 |
login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)
|
21 |
|
|
|
|
|
|
|
22 |
def get_segmenter():
|
23 |
"""Initialize segmentation class."""
|
24 |
try:
|
25 |
segmenter = SegmentImage(line_model_path=LINE_MODEL_PATH,
|
26 |
-
device='
|
27 |
line_iou=0.3,
|
28 |
region_iou=0.5,
|
29 |
line_overlap=0.5,
|
@@ -45,7 +49,7 @@ def get_recognizer():
|
|
45 |
recognizer = TextRecognition(
|
46 |
processor_path = TROCR_PROCESSOR_PATH,
|
47 |
model_path = TROCR_MODEL_PATH,
|
48 |
-
device = '
|
49 |
half_precision = True,
|
50 |
line_threshold = 10
|
51 |
)
|
@@ -99,9 +103,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), title="Multicentury HTR Demo") as d
|
|
99 |
print('segmentation ok')
|
100 |
if segment_predictions:
|
101 |
region_plot = plotter.plot_regions(segment_predictions, image)
|
102 |
-
print('region plot ok')
|
103 |
line_plot = plotter.plot_lines(segment_predictions, image)
|
104 |
-
print('line plot ok')
|
105 |
text_predictions = get_text_predictions(np.array(image), segment_predictions, recognizer)
|
106 |
print('text pred ok')
|
107 |
text = "\n".join(text_predictions)
|
|
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
6 |
import onnxruntime
|
7 |
+
import torch
|
8 |
import time
|
9 |
import os
|
10 |
|
|
|
20 |
|
21 |
login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)
|
22 |
|
23 |
+
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
24 |
+
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
25 |
+
|
26 |
def get_segmenter():
|
27 |
"""Initialize segmentation class."""
|
28 |
try:
|
29 |
segmenter = SegmentImage(line_model_path=LINE_MODEL_PATH,
|
30 |
+
device='cuda:0',
|
31 |
line_iou=0.3,
|
32 |
region_iou=0.5,
|
33 |
line_overlap=0.5,
|
|
|
49 |
recognizer = TextRecognition(
|
50 |
processor_path = TROCR_PROCESSOR_PATH,
|
51 |
model_path = TROCR_MODEL_PATH,
|
52 |
+
device = '0',
|
53 |
half_precision = True,
|
54 |
line_threshold = 10
|
55 |
)
|
|
|
103 |
print('segmentation ok')
|
104 |
if segment_predictions:
|
105 |
region_plot = plotter.plot_regions(segment_predictions, image)
|
|
|
106 |
line_plot = plotter.plot_lines(segment_predictions, image)
|
|
|
107 |
text_predictions = get_text_predictions(np.array(image), segment_predictions, recognizer)
|
108 |
print('text pred ok')
|
109 |
text = "\n".join(text_predictions)
|