# Source: Hugging Face Space (status: running on a T4 GPU).
# File size: 15,699 bytes; blame revisions: 2311a8c, f8a998a, efebdb3, b32227d, 3351c8b.
from huggingface_hub import hf_hub_download
from shapely.validation import make_valid
from shapely.geometry import Polygon
from ultralytics import YOLO
from PIL import Image
import numpy as np
import os
from reading_order import OrderPolygons
class SegmentImage:
    """Segment document images into ordered text regions and text lines.

    Text lines are detected with a YOLO segmentation model fetched from the
    Hugging Face Hub. When ``region_model_path`` is given, text regions are
    detected with a second YOLO model; otherwise (or when no region is found)
    a single region covering the whole image is used. Every line is attached
    to one region, lines are put into reading order inside their region and,
    optionally, the regions themselves are ordered (both via ``OrderPolygons``).
    """

    # Checkpoint files downloaded from the model repositories. Kept as class
    # attributes so that a subclass can point at other weight files.
    LINE_MODEL_FILENAME = "lines_20240827.pt"
    REGION_MODEL_FILENAME = "tuomiokirja_regions_04122023.pt"

    def __init__(self,
                 line_model_path,
                 device,
                 line_iou=0.5,
                 region_iou=0.5,
                 line_overlap=0.5,
                 line_nms_iou=0.7,
                 region_nms_iou=0.3,
                 line_conf_threshold=0.25,
                 region_conf_threshold=0.25,
                 region_model_path=None,
                 order_regions=True,
                 region_half_precision=False,
                 line_half_precision=False):
        """Store the configuration and load the detection model(s).

        Args:
            line_model_path: Hugging Face Hub repo id of the line detection model.
            device: Device passed to the models' ``predict`` ('cpu', gpu '0',
                gpu '1', etc.).
            line_iou: IoU above which detected line polygons are merged.
            region_iou: IoU above which detected region polygons are merged.
            line_overlap: Two line polygons are also merged when their
                intersection area exceeds this fraction of the smaller one's area.
            line_nms_iou: IoU threshold of the line model's non-maximum
                suppression (NMS) step.
            region_nms_iou: IoU threshold of the region model's NMS step.
            line_conf_threshold: Minimum confidence for line detections.
            region_conf_threshold: Minimum confidence for region detections.
            region_model_path: Hugging Face Hub repo id of the region detection
                model, or None to skip region detection entirely.
            order_regions: Whether a reading order is also estimated for regions.
            region_half_precision: Run the region model in half precision (FP16).
            line_half_precision: Run the line model in half precision (FP16).
        """
        self.line_model_path = line_model_path
        self.region_model_path = region_model_path
        self.line_nms_iou = line_nms_iou
        self.region_nms_iou = region_nms_iou
        self.line_iou = line_iou
        self.region_iou = region_iou
        self.line_overlap = line_overlap
        self.line_conf_threshold = line_conf_threshold
        self.region_conf_threshold = region_conf_threshold
        self.device = device
        self.order_regions = order_regions
        self.region_half_precision = region_half_precision
        self.line_half_precision = line_half_precision
        self.order_poly = OrderPolygons()
        # Initialize segmentation model(s)
        self.line_model = self.init_line_model()
        # Always define the attribute: without a region model (or when loading
        # it fails) get_region_preds returns None instead of raising
        # AttributeError, and the default page-sized region is used.
        self.region_model = None
        if self.region_model_path:
            self.region_model = self.init_region_model()

    def init_line_model(self):
        """Download and load the line detection model.

        Returns:
            The loaded ``YOLO`` model, or None if downloading or loading
            failed (the error is printed, not raised).
        """
        try:
            cached_model_path = hf_hub_download(repo_id=self.line_model_path,
                                                filename=self.LINE_MODEL_FILENAME)
            return YOLO(cached_model_path)
        except Exception as e:
            # Best effort: report the failure instead of aborting construction.
            print('Failed to load the line detection model: %s' % e)
            return None

    def init_region_model(self):
        """Download and load the region detection model.

        Returns:
            The loaded ``YOLO`` model, or None if downloading or loading
            failed (the error is printed, not raised).
        """
        try:
            cached_model_path = hf_hub_download(repo_id=self.region_model_path,
                                                filename=self.REGION_MODEL_FILENAME)
            return YOLO(cached_model_path)
        except Exception as e:
            # Best effort: report the failure instead of aborting construction.
            print('Failed to load the region detection model: %s' % e)
            return None

    def get_region_ids(self, coords, max_min, classes, names, box_confs, img_shape):
        """Build one dict per detected region with a simple index-based id.

        Args:
            coords: Region polygons (sequences of (x, y) points).
            max_min: Per-polygon ``[max_x, min_x, max_y, min_y]`` rows.
            classes: Predicted class id per polygon (aligned with ``coords``).
            names: Mapping from class id to class name.
            box_confs: Detection confidence per polygon (aligned with ``coords``).
            img_shape: Shape of the original image.

        Returns:
            A list of dicts with keys 'coords', 'max_min', 'class', 'name',
            'conf', 'id' and 'img_shape'. Only the first
            ``min(len(classes), len(coords))`` polygons are included.
        """
        n = min(len(classes), len(coords))
        res = []
        for i in range(n):
            res.append({'coords': coords[i],
                        'max_min': max_min[i],
                        'class': str(classes[i]),
                        'name': names[classes[i]],
                        'conf': box_confs[i],
                        'id': str(i),
                        'img_shape': img_shape})
        return res

    def get_max_min(self, polygons):
        """Return an (n, 4) array of ``[max_x, min_x, max_y, min_y]`` per polygon.

        Rows of empty polygons are left as zeros. The values are used for
        ordering the polygons.
        """
        xy_array = np.zeros([len(polygons), 4])
        for i, poly in enumerate(polygons):
            x = [point[0] for point in poly]
            y = [point[1] for point in poly]
            if x:
                xy_array[i, 0] = max(x)
                xy_array[i, 1] = min(x)
            if y:
                xy_array[i, 2] = max(y)
                xy_array[i, 3] = min(y)
        return xy_array

    def validate_polygon(self, polygon):
        """Convert a point sequence into a valid shapely geometry.

        Returns None for sequences with fewer than three points. Invalid
        polygons are repaired with ``make_valid``, which may return another
        geometry type (e.g. MultiPolygon or GeometryCollection).
        """
        if len(polygon) <= 2:
            return None
        polygon = Polygon(polygon)
        if not polygon.is_valid:
            polygon = make_valid(polygon)
        return polygon

    def get_iou(self, poly1, poly2):
        """Return the Intersection over Union (IoU) of two point sequences.

        Returns 0 when either input is degenerate (fewer than three points),
        when they do not intersect, or when their union has zero area.
        """
        poly1 = self.validate_polygon(poly1)
        poly2 = self.validate_polygon(poly2)
        if not (poly1 and poly2) or not poly1.intersects(poly2):
            return 0
        union_area = poly1.union(poly2).area
        if union_area == 0:
            # make_valid can turn a polygon into a zero-area geometry; avoid
            # a ZeroDivisionError.
            return 0
        return poly1.intersection(poly2).area / union_area

    def _merge_polygons_with_sources(self, polygons, iou_threshold, overlap_threshold=None):
        """Merge overlapping polygons and record where each result came from.

        Returns a list of ``(polygon, source_index)`` pairs in the same order as
        ``merge_polygons`` returns polygons: first the input polygons that took
        part in no merge (with their own index, original order), then the
        merged polygons (carrying the index of the polygon the merge was
        anchored on, i.e. the lowest index of the group).
        """
        # Validate every polygon once instead of once per pair.
        geoms = [self.validate_polygon(poly) for poly in polygons]
        merged_results = []
        dropped = set()
        for i, poly1 in enumerate(geoms):
            merged = None
            for j in range(i + 1, len(geoms)):
                poly2 = geoms[j]
                if not (poly1 and poly2) or not poly1.intersects(poly2):
                    continue
                intersect = poly1.intersection(poly2)
                uni = poly1.union(poly2)
                iou = intersect.area / uni.area if uni.area else 0.0
                overlap = False
                if overlap_threshold:
                    overlap = intersect.area > (overlap_threshold * min(poly1.area, poly2.area))
                if (iou > iou_threshold) or overlap:
                    # Every polygon overlapping the anchor above the threshold
                    # is merged into one geometry.
                    merged = uni if merged is None else uni.union(merged)
                    # Merged inputs are dropped from the kept list.
                    dropped.add(i)
                    dropped.add(j)
            # NOTE: a polygon already merged into an earlier anchor still acts
            # as an anchor itself, so chains of overlaps can yield several
            # overlapping merged outputs.
            if merged:
                if merged.geom_type in ['GeometryCollection', 'MultiPolygon']:
                    for geom in merged.geoms:
                        if geom.geom_type == 'Polygon':
                            merged_results.append((list(geom.exterior.coords), i))
                elif merged.geom_type == 'Polygon':
                    merged_results.append((list(merged.exterior.coords), i))
        res = [(poly, idx) for idx, poly in enumerate(polygons) if idx not in dropped]
        res += merged_results
        return res

    def merge_polygons(self, polygons, iou_threshold, overlap_threshold=None):
        """Merge polygons whose IoU (or relative overlap) exceeds a threshold.

        Args:
            polygons: Sequence of point sequences.
            iou_threshold: Pairs with IoU above this value are merged.
            overlap_threshold: Optional; pairs whose intersection area exceeds
                this fraction of the smaller polygon's area are also merged.

        Returns:
            The untouched input polygons (original order) followed by the
            merged polygons as lists of (x, y) tuples.
        """
        return [poly for poly, _ in
                self._merge_polygons_with_sources(polygons, iou_threshold, overlap_threshold)]

    def get_region_preds(self, img):
        """Predict text regions.

        Returns:
            A list of region dicts (see ``get_region_ids``), or None when no
            region model is available or nothing was detected.
        """
        if self.region_model is None:
            return None
        results = self.region_model.predict(source=img,
                                            device=self.device,
                                            conf=self.region_conf_threshold,
                                            half=bool(self.region_half_precision),
                                            iou=self.region_nms_iou)
        results = results[0].cpu()
        if not results.masks:
            return None
        all_classes = results.boxes.cls.tolist()
        all_confs = results.boxes.conf.tolist()
        n_boxes = min(len(all_classes), len(all_confs))
        # Merging reorders (and may add or remove) polygons, so classes and
        # confidences are re-indexed through each polygon's source detection.
        merged = [(poly, src)
                  for poly, src in self._merge_polygons_with_sources(results.masks.xy, self.region_iou)
                  if src < n_boxes]
        coords = [poly for poly, _ in merged]
        classes = [all_classes[src] for _, src in merged]
        box_confs = [all_confs[src] for _, src in merged]
        # Maximum and minimum x and y values, used for ordering the polygons
        max_min = self.get_max_min(coords).tolist()
        return self.get_region_ids(coords, max_min, classes, results.names,
                                   box_confs, results.orig_shape)

    def get_line_preds(self, img):
        """Predict text lines.

        Returns:
            A dict with keys 'coords' (line polygons), 'max_min' (per-polygon
            ``[max_x, min_x, max_y, min_y]``) and 'confs' (confidence per
            polygon), or None when nothing was detected.
        """
        results = self.line_model.predict(source=img,
                                          device=self.device,
                                          conf=self.line_conf_threshold,
                                          half=bool(self.line_half_precision),
                                          iou=self.line_nms_iou)
        results = results[0].cpu()
        if not results.masks:
            return None
        merged = self._merge_polygons_with_sources(results.masks.xy, self.line_iou, self.line_overlap)
        coords = [poly for poly, _ in merged]
        all_confs = results.boxes.conf.tolist()
        # Keep confidences aligned with the (reordered) merged polygons.
        confs = [all_confs[src] if src < len(all_confs) else 0.0 for _, src in merged]
        max_min = self.get_max_min(coords).tolist()
        return {'coords': coords, 'max_min': max_min, 'confs': confs}

    def get_dist(self, line_polygon, regions):
        """Return the id of the region closest to the line, or None.

        None is returned when the line or every region polygon is degenerate.
        """
        dist, reg_id = float('inf'), None
        line_polygon = self.validate_polygon(line_polygon)
        if line_polygon:
            for region in regions:
                region_polygon = self.validate_polygon(region['coords'])
                if region_polygon:
                    line_reg_dist = line_polygon.distance(region_polygon)
                    if line_reg_dist < dist:
                        dist = line_reg_dist
                        reg_id = region['id']
        return reg_id

    def get_line_regions(self, lines, regions):
        """Attach each text line to one region.

        A line goes to the region with the highest IoU; if it intersects no
        region, the nearest region (by polygon distance) is used.

        Returns:
            A list of dicts with keys 'polygon', 'reg_id', 'max_min', 'conf'.
            Missing 'max_min'/'confs' entries default to zeros / 0.0.
        """
        lines_list = []
        for i, polygon in enumerate(lines['coords']):
            iou, reg_id = 0, ''
            for region in regions:
                line_reg_iou = self.get_iou(polygon, region['coords'])
                if line_reg_iou > iou:
                    iou = line_reg_iou
                    reg_id = region['id']
            if iou == 0:
                reg_id = self.get_dist(polygon, regions)
            max_min = lines['max_min'][i] if i < len(lines['max_min']) else [0.0, 0.0, 0.0, 0.0]
            conf = lines['confs'][i] if i < len(lines['confs']) else 0.0
            lines_list.append({'polygon': polygon, 'reg_id': reg_id, 'max_min': max_min, 'conf': conf})
        return lines_list

    def order_regions_lines(self, lines, regions):
        """Group lines by region and put them into reading order.

        Regions without any line are dropped. When ``self.order_regions`` is
        set, the remaining regions are ordered too.

        Returns:
            A list of dicts with keys 'region_coords', 'region_name', 'lines',
            'line_confs', 'region_conf' and 'img_shape'.
        """
        regions_with_rows = []
        region_max_mins = []
        for region in regions:
            member_lines = [line for line in lines if line['reg_id'] == region['id']]
            if not member_lines:
                continue
            # OrderPolygons.order is expected to return a permutation of the
            # indices of its input (TODO confirm against reading_order).
            line_order = self.order_poly.order([line['max_min'] for line in member_lines])
            line_polygons = [member_lines[k]['polygon'] for k in line_order]
            line_confs = [member_lines[k]['conf'] for k in line_order]
            regions_with_rows.append({'region_coords': region['coords'],
                                      'region_name': region['name'],
                                      'lines': line_polygons,
                                      'line_confs': line_confs,
                                      'region_conf': region['conf'],
                                      'img_shape': region['img_shape']})
            region_max_mins.append(region['max_min'])
        if self.order_regions:
            region_order = self.order_poly.order(region_max_mins)
            regions_with_rows = [regions_with_rows[k] for k in region_order]
        return regions_with_rows

    def get_default_region(self, image):
        """Return a single 'paragraph' region covering the whole image.

        ``image`` may be a PIL image (``size`` is ``(width, height)``) or a
        numpy array of shape ``(height, width[, channels])``.
        """
        if isinstance(image, np.ndarray):
            h, w = image.shape[:2]
        else:
            w, h = image.size
        region = {'coords': [[0.0, 0.0], [w, 0.0], [w, h], [0.0, h]],
                  'max_min': [w, 0.0, h, 0.0],
                  'class': '0',
                  'name': "paragraph",
                  'conf': 0.0,
                  'id': '0',
                  'img_shape': (h, w)}
        return [region]

    def get_segmentation(self, image):
        """Segment an image into ordered regions, each with ordered text lines.

        Returns:
            The list produced by ``order_regions_lines``, or None when no text
            lines were detected.
        """
        line_preds = self.get_line_preds(image)
        if not line_preds:
            print(f'No text lines detected from image {image}')
            return None
        # Regions are detected only if a region model is available; otherwise
        # (or when none is found) one page-sized default region is used.
        region_preds = self.get_region_preds(image)
        if not region_preds:
            region_preds = self.get_default_region(image)
            if self.region_model is not None:
                print(f'No regions detected from image {image}')
        lines_with_regions = self.get_line_regions(line_preds, region_preds)
        return self.order_regions_lines(lines_with_regions, region_preds)
# End of file.