MikkoLipsanen commited on
Commit
07944e2
·
verified ·
1 Parent(s): d8a5332

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -4,6 +4,7 @@ from huggingface_hub import login
4
  import gradio as gr
5
  import numpy as np
6
  import onnxruntime
 
7
  import time
8
  import os
9
 
@@ -19,11 +20,14 @@ TROCR_MODEL_PATH = "Kansallisarkisto/multicentury-htr-model-onnx"
19
 
20
  login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)
21
 
 
 
 
22
  def get_segmenter():
23
  """Initialize segmentation class."""
24
  try:
25
  segmenter = SegmentImage(line_model_path=LINE_MODEL_PATH,
26
- device='cpu',
27
  line_iou=0.3,
28
  region_iou=0.5,
29
  line_overlap=0.5,
@@ -45,7 +49,7 @@ def get_recognizer():
45
  recognizer = TextRecognition(
46
  processor_path = TROCR_PROCESSOR_PATH,
47
  model_path = TROCR_MODEL_PATH,
48
- device = 'cpu',
49
  half_precision = True,
50
  line_threshold = 10
51
  )
@@ -99,9 +103,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), title="Multicentury HTR Demo") as d
99
  print('segmentation ok')
100
  if segment_predictions:
101
  region_plot = plotter.plot_regions(segment_predictions, image)
102
- print('region plot ok')
103
  line_plot = plotter.plot_lines(segment_predictions, image)
104
- print('line plot ok')
105
  text_predictions = get_text_predictions(np.array(image), segment_predictions, recognizer)
106
  print('text pred ok')
107
  text = "\n".join(text_predictions)
 
4
  import gradio as gr
5
  import numpy as np
6
  import onnxruntime
7
+ import torch
8
  import time
9
  import os
10
 
 
20
 
21
  login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)
22
 
23
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
24
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
25
+
26
  def get_segmenter():
27
  """Initialize segmentation class."""
28
  try:
29
  segmenter = SegmentImage(line_model_path=LINE_MODEL_PATH,
30
+ device='cuda:0',
31
  line_iou=0.3,
32
  region_iou=0.5,
33
  line_overlap=0.5,
 
49
  recognizer = TextRecognition(
50
  processor_path = TROCR_PROCESSOR_PATH,
51
  model_path = TROCR_MODEL_PATH,
52
+ device = '0',
53
  half_precision = True,
54
  line_threshold = 10
55
  )
 
103
  print('segmentation ok')
104
  if segment_predictions:
105
  region_plot = plotter.plot_regions(segment_predictions, image)
 
106
  line_plot = plotter.plot_lines(segment_predictions, image)
 
107
  text_predictions = get_text_predictions(np.array(image), segment_predictions, recognizer)
108
  print('text pred ok')
109
  text = "\n".join(text_predictions)