File size: 15,699 Bytes
2311a8c
f8a998a
 
efebdb3
f8a998a
 
b32227d
f8a998a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2311a8c
3351c8b
f8a998a
 
 
 
 
 
 
 
2311a8c
3351c8b
f8a998a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2311a8c
 
 
 
 
f8a998a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2311a8c
 
 
 
 
f8a998a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
from huggingface_hub import hf_hub_download
from shapely.validation import make_valid
from shapely.geometry import Polygon
from ultralytics import YOLO
from PIL import Image
import numpy as np
import os

from reading_order import OrderPolygons

class SegmentImage:
    """Class for segmenting document image regions and text lines."""
    def __init__(self, 
                line_model_path, 
                device, 
                line_iou=0.5,
                region_iou=0.5,
                line_overlap=0.5, 
                line_nms_iou=0.7,
                region_nms_iou=0.3, 
                line_conf_threshold=0.25, 
                region_conf_threshold=0.25, 
                region_model_path=None, 
                order_regions=True, 
                region_half_precision=False, 
                line_half_precision=False):
        """Store the segmentation settings and load the detection model(s).

        Args:
            line_model_path: Repo id passed to ``hf_hub_download`` when
                loading the text line detection model.
            device: Device passed to the models' ``predict`` call
                ('cpu', gpu '0', gpu '1' etc.).
            line_iou: IoU threshold above which text line polygons are merged.
            region_iou: IoU threshold above which region polygons are merged.
            line_overlap: Fraction of the smaller line polygon's area that an
                intersection must exceed for two line polygons to be merged.
            line_nms_iou: IoU threshold passed to the line model's ``predict``
                call (non-maximum suppression).
            region_nms_iou: IoU threshold passed to the region model's
                ``predict`` call (non-maximum suppression).
            line_conf_threshold: Confidence threshold for line detection.
            region_conf_threshold: Confidence threshold for region detection.
            region_model_path: Repo id of the region detection model. When it
                is falsy, no region model is loaded.
            order_regions: Whether a reading order is also estimated for the
                detected regions.
            region_half_precision: Whether the region model predicts in half
                precision (FP16).
            line_half_precision: Whether the line model predicts in half
                precision (FP16).
        """
        # Model locations
        self.line_model_path = line_model_path
        self.region_model_path = region_model_path
        # IoU thresholds used by non-maximum suppression to decide which
        # prediction boxes are suppressed because they overlap other boxes
        self.line_nms_iou = line_nms_iou
        self.region_nms_iou = region_nms_iou
        # IoU thresholds used when merging overlapping polygons
        self.line_iou = line_iou
        self.region_iou = region_iou
        # Extent of line polygon overlap used for merging the polygons
        self.line_overlap = line_overlap
        # Confidence thresholds for the detections
        self.line_conf_threshold = line_conf_threshold
        self.region_conf_threshold = region_conf_threshold
        # Device to be used ('cpu', gpu '0', gpu '1' etc.)
        self.device = device
        # Whether a reading order is also estimated for the region detections
        self.order_regions = order_regions
        # Whether half precision (FP16) is used by the prediction models
        self.region_half_precision = region_half_precision
        self.line_half_precision = line_half_precision
        self.order_poly = OrderPolygons()
        # Initialize segmentation model(s)
        self.line_model = self.init_line_model()
        # Always define the attribute so that code reading self.region_model
        # does not raise AttributeError when no region model is configured
        self.region_model = None
        if self.region_model_path:
            self.region_model = self.init_region_model()

    def init_line_model(self):
        """Load the text line detection model.

        The weights file is fetched with ``hf_hub_download`` from the repo
        given by ``self.line_model_path`` and wrapped in a ``YOLO`` object.

        Returns:
            The loaded model, or None if downloading or loading raised an
            exception (the error is printed, not re-raised).
        """
        try:
            weights_path = hf_hub_download(repo_id=self.line_model_path, filename="lines_20240827.pt")
            return YOLO(weights_path)
        except Exception as err:
            # Best effort: report the failure and let the caller receive None
            print('Failed to load the line detection model: %s' % err)
            return None

    def init_region_model(self):
        """Load the text region detection model.

        The weights file is fetched with ``hf_hub_download`` from the repo
        given by ``self.region_model_path`` and wrapped in a ``YOLO`` object.

        Returns:
            The loaded model, or None if downloading or loading raised an
            exception (the error is printed, not re-raised).
        """
        try:
            weights_path = hf_hub_download(repo_id=self.region_model_path, filename="tuomiokirja_regions_04122023.pt")
            return YOLO(weights_path)
        except Exception as err:
            # Best effort: report the failure and let the caller receive None
            print('Failed to load the region detection model: %s' % err)
            return None
        
    def get_region_ids(self, coords, max_min, classes, names, box_confs, img_shape):
        """Build one description dict per detected region.

        Only the first ``min(len(classes), len(coords))`` entries are used.

        Args:
            coords: Region polygons.
            max_min: Per-region extreme x and y values, indexed like coords.
            classes: Predicted class label per detection.
            names: Mapping from class label to class name.
            box_confs: Confidence value per detection.
            img_shape: Image shape stored unchanged in every dict.

        Returns:
            A list of dicts with the keys 'coords', 'max_min', 'class',
            'name', 'conf', 'id' and 'img_shape'. The id is the list index
            as a string.
        """
        count = min(len(classes), len(coords))
        described = []
        for idx in range(count):
            label = classes[idx]
            # Region name corresponding to the predicted class label
            label_name = names[label]
            described.append({
                'coords': coords[idx],
                'max_min': max_min[idx],
                'class': str(label),
                'name': label_name,
                'conf': box_confs[idx],
                # Simple index-based id
                'id': str(idx),
                'img_shape': img_shape,
            })
        return described

    def get_max_min(self, polygons):
        """Collect the extreme x and y values of each polygon.

        Args:
            polygons: Sequence of polygons, each a sequence of points whose
                first two items are the x and y coordinates.

        Returns:
            A float array of shape (len(polygons), 4) whose rows hold
            [max x, min x, max y, min y]. The row of a polygon without
            points stays all zeros.
        """
        bounds = np.zeros([len(polygons), 4])
        for row, poly in enumerate(polygons):
            xs = [pt[0] for pt in poly]
            ys = [pt[1] for pt in poly]
            if xs:
                bounds[row, :2] = max(xs), min(xs)
            if ys:
                bounds[row, 2:] = max(ys), min(ys)
        return bounds

    def validate_polygon(self, polygon):
        """Convert a point sequence into a valid shapely geometry.

        Args:
            polygon: Sequence of polygon points.

        Returns:
            None when the input has fewer than three points. Otherwise a
            shapely ``Polygon``, or the result of ``make_valid`` when the
            polygon is not valid.
        """
        # A polygon needs at least three points
        if len(polygon) <= 2:
            return None
        shape = Polygon(polygon)
        return shape if shape.is_valid else make_valid(shape)

    def get_iou(self, poly1, poly2):
        """Function for calculating Intersection over Union (IoU) values."""
        # If the polygons don't intersect, IoU is 0
        iou = 0
        poly1 = self.validate_polygon(poly1)
        poly2 = self.validate_polygon(poly2)

        if poly1 and poly2:
            if poly1.intersects(poly2):
                # Calculates intersection of the 2 polygons
                intersect = poly1.intersection(poly2).area
                # Calculates union of the 2 polygons
                uni = poly1.union(poly2)
                # Calculates intersection over union
                iou = intersect / uni.area
        return iou

    def merge_polygons(self, polygons, iou_threshold, overlap_threshold = None):
        """Merges polygons that have an IoU value 
        above the given threshold."""
        new_polygons = []
        dropped = set()
        # Loops over all input polygons and merges them if the
        # IoU value is over the given threshold
        for i in range(0, len(polygons)):
            poly1 = self.validate_polygon(polygons[i])
            merged = None
            for j in range(i+1, len(polygons)):
                poly2 = self.validate_polygon(polygons[j])
                if poly1 and poly2: 
                    if poly1.intersects(poly2):
                        overlap = False
                        intersect = poly1.intersection(poly2)
                        uni = poly1.union(poly2)
                        # Calculates intersection over union
                        iou = intersect.area / uni.area
                        if overlap_threshold:
                            overlap = intersect.area > (overlap_threshold * min(poly1.area, poly2.area))
                        if (iou > iou_threshold) or overlap:
                            if merged:
                                # If there are multiple overlapping polygons
                                # with IoU over the threshold, they are all merged together
                                merged = uni.union(merged)
                                dropped.add(j)
                            else:
                                merged = uni
                                # Polygons that are merged together are dropped from
                                # the list
                                dropped.add(i)
                                dropped.add(j)       
            if merged:
                if merged.geom_type in ['GeometryCollection','MultiPolygon']:
                    for geom in merged.geoms:                
                        if geom.geom_type == 'Polygon':
                            new_polygons.append(list(geom.exterior.coords))
                elif merged.geom_type == 'Polygon':
                    new_polygons.append(list(merged.exterior.coords))
        res = [i for j, i in enumerate(polygons) if j not in dropped]
        res += new_polygons
        
        return res

    def get_region_preds(self, img):
        """Function for predicting text region coordinates."""
        results = self.region_model.predict(source=img,     
                                            device=self.device, 
                                            conf=self.region_conf_threshold, 
                                            half=bool(self.region_half_precision), 
                                            iou=self.region_nms_iou)
        results = results[0].cpu()
        if results.masks:
            # Extracts detected region polygons
            coords = results.masks.xy
            # Merge overlapping polygons
            coords = self.merge_polygons(coords, self.region_iou)
            # Maximum and minimum x and y axis values for detected polygons used for ordering the polygons
            max_min = self.get_max_min(coords).tolist() 
            # Gets a list of the predicted class labels for detected regions
            classes = results.boxes.cls.tolist() 
            # A dictionary with class ids as keys and class names as values
            names = results.names 
            # Confidence values for detections
            box_confs = results.boxes.conf.tolist()
            # A tuple containing the shape of the original image
            img_shape = results.orig_shape 
            res = self.get_region_ids(list(coords), max_min, classes, names, box_confs, img_shape)
            return res
        else:
            return None


    def get_line_preds(self, img):
        """Predict text line polygons for an image.

        Args:
            img: Image passed as ``source`` to the line model's ``predict``.

        Returns:
            A dict with the keys 'coords' (merged line polygons), 'max_min'
            (extreme x and y values per polygon) and 'confs' (confidence
            values of the detections), or None when no masks were predicted.
        """
        prediction = self.line_model.predict(
            source=img,
            device=self.device,
            conf=self.line_conf_threshold,
            half=bool(self.line_half_precision),
            iou=self.line_nms_iou,
        )
        first_result = prediction[0].cpu()
        if not first_result.masks:
            return None

        # Detected text line polygons, with overlapping ones merged
        line_polygons = self.merge_polygons(first_result.masks.xy, self.line_iou, self.line_overlap)
        # Maximum and minimum x and y axis values for the polygons
        extremes = self.get_max_min(line_polygons).tolist()
        # Confidence values for detections
        confidences = first_result.boxes.conf.tolist()
        return {'coords': list(line_polygons), 'max_min': extremes, 'confs': confidences}

    def get_dist(self, line_polygon, regions):
        """Function for finding the closest region to the text line."""
        dist, reg_id = 1000000, None
        line_polygon = self.validate_polygon(line_polygon)

        if line_polygon:
            for region in regions:
                # Calculates dictance between line and regions polygons
                region_polygon = self.validate_polygon(region['coords'])
                if region_polygon:
                    line_reg_dist = line_polygon.distance(region_polygon)
                    if line_reg_dist < dist:
                        dist = line_reg_dist
                        reg_id = region['id']
        return reg_id
    
    def get_line_regions(self, lines, regions):
        """Function for connecting each text line to one region."""
        lines_list = []
        for i in range(len(lines['coords'])):
            iou, reg_id, conf = 0, '', 0.0
            max_min = [0.0, 0.0, 0.0, 0.0]
            polygon = lines['coords'][i]
            for region in regions:
                line_reg_iou = self.get_iou(polygon, region['coords']) 
                if line_reg_iou > iou:
                    iou = line_reg_iou
                    reg_id = region['id']
            # If line polygon does not intersect with any region, a distance metric is used for defining 
            # the region that the line belongs to
            if iou == 0:
                reg_id = self.get_dist(polygon, regions)

            if (len(lines['max_min']) - 1) >= i:
                max_min = lines['max_min'][i]
                
            if (len(lines['confs']) - 1) >= i:
                conf = lines['confs'][i]

            new_line = {'polygon': polygon, 'reg_id': reg_id, 'max_min': max_min, 'conf': conf}
            lines_list.append(new_line)
        return lines_list

    def order_regions_lines(self, lines, regions):
        """Function for ordering line predictions inside each region."""
        regions_with_rows = []
        region_max_mins = []
        for i, region in enumerate(regions):
            line_max_mins = []
            line_confs = []
            line_polygons = []
            for line in lines:
                if line['reg_id'] == region['id']:
                    line_max_mins.append(line['max_min'])
                    line_confs.append(line['conf'])
                    line_polygons.append(line['polygon'])
            if line_polygons:
                # If one or more lines are connected to a region, line order inside the region is defined
                # and the predicted text lines are joined in the same python dict
                line_order = self.order_poly.order(line_max_mins)
                line_polygons = [line_polygons[i] for i in line_order]
                line_confs = [line_confs[i] for i in line_order]
                new_region = {'region_coords': region['coords'], 
                            'region_name': region['name'], 
                            'lines': line_polygons, 
                            'line_confs': line_confs,
                            'region_conf': region['conf'],
                            'img_shape': region['img_shape']}
                region_max_mins.append(region['max_min'])
                regions_with_rows.append(new_region)
            else:
                continue
        # Creates an ordering of the detected regions based on their polygon coordinates
        if self.order_regions:
            region_order = self.order_poly.order(region_max_mins)
            regions_with_rows = [regions_with_rows[i] for i in region_order]
            
        return regions_with_rows

    def get_default_region(self, image):
        """Create a single region that covers the whole image.

        Used as a fallback when no regions are detected.

        Args:
            image: Object whose ``size`` attribute unpacks to
                (width, height).

        Returns:
            A one-element list with a region dict named "paragraph" whose
            polygon is the image rectangle, with confidence 0.0 and id '0'.
        """
        width, height = image.size
        # Image corners: top-left, top-right, bottom-right, bottom-left
        corners = [[0.0, 0.0], [width, 0.0], [width, height], [0.0, height]]
        whole_page = {
            'coords': corners,
            'max_min': [width, 0.0, height, 0.0],
            'class': '0',
            'name': "paragraph",
            'conf': 0.0,
            'id': '0',
            'img_shape': (height, width),
        }
        return [whole_page]

    def get_segmentation(self, image):
        """Segment an image into ordered regions and text lines.

        Text lines are always predicted. Regions are predicted only when a
        region model is available; otherwise, or when no regions are found,
        a default region covering the whole image is used.

        Args:
            image: Input image. It is passed to the prediction models, and
                its ``size`` attribute is read when the default region is
                built.

        Returns:
            The list produced by ``order_regions_lines``, or None when no
            text lines are detected.
        """
        line_preds = self.get_line_preds(image)
        if not line_preds:
            print(f'No text lines detected from image {image}')
            return None

        region_preds = None
        # Regions are predicted only if a region detection model is defined.
        # The attribute may be missing or None when no model path was given
        # or loading failed.
        if getattr(self, 'region_model', None) is not None:
            region_preds = self.get_region_preds(image)
            if not region_preds:
                print(f'No regions detected from image {image}')
        if not region_preds:
            # Fall back to one region covering the whole image
            region_preds = self.get_default_region(image)

        lines_with_regions = self.get_line_regions(line_preds, region_preds)
        ordered_regions = self.order_regions_lines(lines_with_regions, region_preds)
        return ordered_regions