MikkoLipsanen commited on
Commit
2311a8c
·
verified ·
1 Parent(s): f115adf

Update segment_image.py

Browse files
Files changed (1) hide show
  1. segment_image.py +15 -12
segment_image.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from shapely.validation import make_valid
2
  from shapely.geometry import Polygon
3
  from ultralyticsplus import YOLO
@@ -59,7 +60,8 @@ class SegmentImage:
59
  """Function for initializing the line detection model."""
60
  try:
61
  # Load the trained line detection model
62
- line_model = YOLO(self.line_model_path, hf_token=os.getenv("HF_TOKEN"))
 
63
  return line_model
64
  except Exception as e:
65
  print('Failed to load the line detection model: %s' % e)
@@ -68,7 +70,8 @@ class SegmentImage:
68
  """Function for initializing the region detection model."""
69
  try:
70
  # Load the trained line detection model
71
- region_model = YOLO(self.region_model_path, hf_token=os.getenv("HF_TOKEN"))
 
72
  return region_model
73
  except Exception as e:
74
  print('Failed to load the region detection model: %s' % e)
@@ -182,11 +185,11 @@ class SegmentImage:
182
 
183
  def get_region_preds(self, img):
184
  """Function for predicting text region coordinates."""
185
- results = self.region_model(source=img,
186
- device=self.device,
187
- conf=self.region_conf_threshold,
188
- half=bool(self.region_half_precision),
189
- iou=self.region_nms_iou)
190
  results = results[0].cpu()
191
  if results.masks:
192
  # Extracts detected region polygons
@@ -211,11 +214,11 @@ class SegmentImage:
211
 
212
  def get_line_preds(self, img):
213
  """Function for predicting text line coordinates."""
214
- results = self.line_model(source=img,
215
- device=self.device,
216
- conf=self.line_conf_threshold,
217
- half=bool(self.line_half_precision),
218
- iou=self.line_nms_iou)
219
  results = results[0].cpu()
220
  if results.masks:
221
  # Detected text line polygons
 
1
+ from huggingface_hub import hf_hub_download
2
  from shapely.validation import make_valid
3
  from shapely.geometry import Polygon
4
  from ultralyticsplus import YOLO
 
60
  """Function for initializing the line detection model."""
61
  try:
62
  # Load the trained line detection model
63
+ cached_model_path = hf_hub_download(repo_id=self.line_model_path, filename="lines_20240827.pt")
64
+ line_model = YOLO(cached_model_path, hf_token=os.getenv("HF_TOKEN"))
65
  return line_model
66
  except Exception as e:
67
  print('Failed to load the line detection model: %s' % e)
 
70
  """Function for initializing the region detection model."""
71
  try:
72
  # Load the trained line detection model
73
+ cached_model_path = hf_hub_download(repo_id=self.region_model_path, filename="tuomiokirja_regions_04122023.pt")
74
+ region_model = YOLO(cached_model_path, hf_token=os.getenv("HF_TOKEN"))
75
  return region_model
76
  except Exception as e:
77
  print('Failed to load the region detection model: %s' % e)
 
185
 
186
  def get_region_preds(self, img):
187
  """Function for predicting text region coordinates."""
188
+ results = self.region_model.predict(source=img,
189
+ device=self.device,
190
+ conf=self.region_conf_threshold,
191
+ half=bool(self.region_half_precision),
192
+ iou=self.region_nms_iou)
193
  results = results[0].cpu()
194
  if results.masks:
195
  # Extracts detected region polygons
 
214
 
215
  def get_line_preds(self, img):
216
  """Function for predicting text line coordinates."""
217
+ results = self.line_model.predict(source=img,
218
+ device=self.device,
219
+ conf=self.line_conf_threshold,
220
+ half=bool(self.line_half_precision),
221
+ iou=self.line_nms_iou)
222
  results = results[0].cpu()
223
  if results.masks:
224
  # Detected text line polygons