Devesh Pant committed on
Commit 1b870f4
1 Parent(s): 589dd63
README.md CHANGED
@@ -1,7 +1,7 @@
  ---
- title: Newspapers OCR Demo
- emoji: 👀
- colorFrom: purple
+ title: Newspaper Demo
+ emoji: 🐨
+ colorFrom: yellow
  colorTo: green
  sdk: streamlit
  sdk_version: 1.29.0
app.py ADDED
@@ -0,0 +1,102 @@
+ import cv2
+ import numpy as np
+ import streamlit as st
+ from run_yolo import get_layout_results
+ from order_text_blocks import get_ordered_data
+ from run_ocr import OCR
+ from main import driver
+ import json
+ import pandas as pd
+
+ colors = {
+     'Articles': [0, 0, 0],           # Black
+     'Advertisement': [0, 255, 0],    # Green
+     'Headlines': [0, 0, 255],        # Blue
+     'Sub-headlines': [255, 255, 0],  # Yellow
+     'Graphics': [255, 0, 255],       # Magenta
+     'Images': [128, 0, 128],         # Purple
+     'Tables': [0, 255, 255],         # Cyan
+     'Header': [0, 0, 0],             # Black
+     'Text Block': [255, 0, 0]
+ }
+
+ try:
+     st.set_page_config(layout="wide", page_title="Newspaper Layout Detection and OCR Demo")
+     st.markdown("<h1 style='text-align: center; color: #333;'>Newspaper Layout Detection and OCR Demo</h1>", unsafe_allow_html=True)
+
+     # Streamlit UI for user input
+     uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+     language_name = st.text_input("Enter the language name (hin, en, mal, tel, tam, kan)")
+     submit_button = st.button("Submit")
+
+     # Check if the user clicked the submit button
+     if submit_button:
+         # Check if image and language are provided
+         if uploaded_image is not None and language_name:
+             # Convert the Streamlit upload to an OpenCV image
+             image_bytes = uploaded_image.read()
+             nparr = np.frombuffer(image_bytes, np.uint8)
+             img_ori = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+             img = img_ori.copy()
+             st.markdown("<p style='text-align: center; color: red'>Image Uploaded Successfully!</p>", unsafe_allow_html=True)
+             # Run layout detection, reading-order sorting and OCR
+             output_dict, article_ocr_dict = driver(img_ori, language_name, st)
+
+             # Create a list to store dictionaries for OCR results
+             image_data = []
+
+             # Visualizing results
+             itr = 1
+             for art_key in article_ocr_dict:
+                 art_coords = art_key.split('_')
+                 art_x1, art_y1, art_x2, art_y2 = int(art_coords[0]), int(art_coords[1]), int(art_coords[2]), int(art_coords[3])
+
+                 # Mark the article bounding box in black
+                 img_ori = cv2.rectangle(img, (art_x1, art_y1), (art_x2, art_y2), (0, 0, 0), 4)
+                 # Put the article number on the image in a large font
+                 img_ori = cv2.putText(img_ori, str(itr), (art_x1, art_y1), cv2.FONT_HERSHEY_SIMPLEX, 3, (0, 0, 255), 3, cv2.LINE_AA)
+                 ocr_dict = article_ocr_dict[art_key]
+
+                 # Initialize variables to store OCR text for each type
+                 headlines_text = ''
+                 subheadlines_text = ''
+                 textblocks_text = ''
+
+                 for obj_key in ocr_dict:
+                     # obj_key is one of Headlines, Sub-headlines, Text Block
+                     obj_list = ocr_dict[obj_key]
+                     for obj_dict in obj_list:
+                         for key in obj_dict:
+                             coords = key.split('_')
+                             x1, y1, x2, y2 = int(coords[0]), int(coords[1]), int(coords[2]), int(coords[3])
+                             # Mark the bounding box with the color corresponding to the object type
+                             img = cv2.rectangle(img, (x1, y1), (x2, y2), colors[obj_key], 2)
+
+                             # Add the OCR text to the corresponding variable
+                             if obj_key == 'Headlines':
+                                 headlines_text += obj_dict[key] + '\n'
+                             elif obj_key == 'Sub-headlines':
+                                 subheadlines_text += obj_dict[key] + '\n'
+                             elif obj_key == 'Text Block':
+                                 textblocks_text += obj_dict[key] + '\n'
+
+                 # Add a dictionary to the list for the current article
+                 image_data.append({'Article': itr, 'Headlines': headlines_text, 'Subheadlines': subheadlines_text, 'Textblocks': textblocks_text})
+
+                 itr += 1
+
+             # Create a DataFrame from the list of dictionaries
+             image_df = pd.DataFrame(image_data)
+
+             # Use Streamlit columns to display the image and results side by side
+             col1, col2 = st.columns(2)
+
+             # Display the image with marked bounding boxes in the left column
+             col1.image(img_ori, use_column_width=True, channels="BGR", caption="Image with Marked Bounding Boxes")
+
+             # Display the DataFrame with OCR results for the whole image in the right column
+             col2.table(image_df.set_index('Article').style.set_table_styles([{'selector': 'th', 'props': [('text-align', 'center')]}]))
+
+
+ except Exception as e:
+     st.exception(e)
best.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:64d3fc9d769f8b215e17a4cdf067e47063c1a1b98668e97c6440e5e020c1988f
+ size 146238371
cropped/1.png ADDED
cropped/10.png ADDED
cropped/11.png ADDED
cropped/12.png ADDED
cropped/13.png ADDED
cropped/14.png ADDED
cropped/15.png ADDED
cropped/16.png ADDED
cropped/17.png ADDED
cropped/18.png ADDED
cropped/19.png ADDED
cropped/2.png ADDED
cropped/20.png ADDED
cropped/21.png ADDED
cropped/22.png ADDED
cropped/23.png ADDED
cropped/24.png ADDED
cropped/25.png ADDED
cropped/26.png ADDED
cropped/27.png ADDED
cropped/28.png ADDED
cropped/29.png ADDED
cropped/3.png ADDED
cropped/4.png ADDED
cropped/5.png ADDED
cropped/6.png ADDED
cropped/7.png ADDED
cropped/8.png ADDED
cropped/9.png ADDED
main.py ADDED
@@ -0,0 +1,52 @@
+ import cv2
+ from run_yolo import get_layout_results
+ from order_text_blocks import get_ordered_data
+ from run_ocr import OCR
+ from tqdm import tqdm
+ import time
+
+
+ def driver(img, language_name, st):
+     onnx_path = "./best.onnx"
+     img_ori = img.copy()
+     labels = get_layout_results(img_ori, onnx_path)
+     output_dict = get_ordered_data(labels, img)
+     st.markdown("<p style='text-align: center; color: red'>Layout Analysis Completed!</p>", unsafe_allow_html=True)
+     article_wise_ocr = {}
+     h, w = img.shape[:2]
+
+     with st.spinner('Performing OCR...'):
+         # Crop and OCR the headline, sub-headline and text-block boxes of every article
+         for itr, article in tqdm(enumerate(output_dict['Articles'])):
+             ocr_dict = {}
+             article_key = ""
+             for key in article:
+
+                 if article[key] == []:
+                     continue
+
+                 if key == 'Articles':
+                     x1, y1, x2, y2 = int(article[key][0][0]), int(article[key][0][1]), int(article[key][0][2]), int(article[key][0][3])
+                     article_key = '_'.join([str(x1), str(y1), str(x2), str(y2)])
+
+                 if key == 'Headlines' or key == 'Sub-headlines' or key == 'Text Block':
+                     for coord in article[key]:
+                         x1, y1, x2, y2 = int(coord[0]), int(coord[1]), int(coord[2]), int(coord[3])
+                         # Skip boxes whose coordinates fall outside the image dimensions
+                         if x1 < 0 or x2 < 0 or y1 < 0 or y2 < 0 or x1 > w or x2 > w or y1 > h or y2 > h:
+                             continue
+
+                         crop = img[int(coord[1]):int(coord[3]), int(coord[0]):int(coord[2])]
+                         output_text = OCR(crop, lang=language_name)
+
+                         box_key = "_".join([str(int(coord[0])), str(int(coord[1])), str(int(coord[2])), str(int(coord[3]))])
+                         if key not in ocr_dict:
+                             ocr_dict[key] = [{box_key: output_text}]
+                         else:
+                             ocr_dict[key].append({box_key: output_text})
+
+             article_wise_ocr[article_key] = ocr_dict
+
+     st.markdown("<p style='text-align: center; color: red'>OCR Completed!</p>", unsafe_allow_html=True)
+     return output_dict, article_wise_ocr
+
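For reference, a minimal sketch of how driver() can be exercised outside Streamlit, inferred from the code above; the stub UI object and the page image path are assumptions, not part of the commit.

    import contextlib
    import cv2
    from main import driver

    class StubUI:
        """Minimal stand-in for the streamlit module expected by driver()."""
        def markdown(self, *args, **kwargs):
            pass
        def spinner(self, *args, **kwargs):
            return contextlib.nullcontext()

    img = cv2.imread("page.png")  # hypothetical newspaper page scan (BGR)
    output_dict, article_wise_ocr = driver(img, "hin", StubUI())

    # article_wise_ocr is keyed by the article box "x1_y1_x2_y2"; each value maps
    # 'Headlines' / 'Sub-headlines' / 'Text Block' to lists of {"x1_y1_x2_y2": text}.
    for art_key, ocr_dict in article_wise_ocr.items():
        x1, y1, x2, y2 = map(int, art_key.split('_'))
        print((x1, y1, x2, y2), list(ocr_dict.keys()))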
order_text_blocks.py ADDED
@@ -0,0 +1,165 @@
+ import os
+ import cv2
+ import json
+ import numpy as np
+
+ class_id_to_name = {
+     0: "Articles",
+     1: "Advertisement",
+     2: "Headlines",
+     3: "Sub-headlines",
+     4: "Graphics",
+     5: "Images",
+     6: "Tables",
+     7: "Text Block",
+     8: "Header"
+ }
+
+
+ def findIntersection(box1, box2):
+     """
+     args:
+         box1: [x1, y1, x2, y2]
+         box2: [x1, y1, x2, y2]
+
+     returns:
+         iou: float -- intersection area divided by the area of box1
+     """
+     x1, y1, w1, h1 = box1[0], box1[1], box1[2] - box1[0], box1[3] - box1[1]
+     x2, y2, w2, h2 = box2[0], box2[1], box2[2] - box2[0], box2[3] - box2[1]
+
+     xA = max(x1, x2)
+     yA = max(y1, y2)
+
+     xB = min(x1 + w1, x2 + w2)
+     yB = min(y1 + h1, y2 + h2)
+
+     # Calculate the intersection area
+     interArea = max(0, xB - xA) * max(0, yB - yA)
+
+     # Divide by the area of box1 instead of the union
+
+     if w1 * h1 == 0:
+         return 0
+     else:
+         iou = interArea / (w1 * h1)
+
+     return iou
+
+ def get_hierarchy(labels):
+     class_list = [[] for _ in range(len(class_id_to_name))]
+     object_dict = {}
+     article_list = []
+
+     for itr, label in enumerate(labels):
+         x1, y1, x2, y2 = label[0][0], label[0][1], label[0][2], label[0][3]
+         conf = label[1]
+         class_id = label[2]
+         class_list[int(class_id)].append([x1, y1, x2, y2])
+         obj_key = int(class_id) * 1000000 + len(class_list[int(class_id)]) - 1
+         object_dict[obj_key] = [0, [x1, y1, x2, y2], int(class_id)]
+
+     # For each article, find all the objects that belong to it
+     cou = 0
+     for article in class_list[0]:
+         article_dict = {'Articles': [], 'Headlines': [], 'Sub-headlines': [], 'Graphics': [], 'Images': [], 'Text Block': [], "Advertisement": [], "Tables": [], "Header": []}
+         for class_id in range(9):
+             IoUThresh = 0.70; max_article = None
+             for obj in class_list[class_id]:
+                 IoU = findIntersection(obj, article)
+                 if IoU > IoUThresh:
+                     key = class_id_to_name[class_id]
+
+                     article_dict[key].append(obj)
+                     obj_key = class_id * 1000000 + class_list[class_id].index(obj)
+                     if obj_key in object_dict:
+                         cou += 1
+                         object_dict[obj_key][0] = 1
+
+
+         article_list.append(article_dict)
+
+     return article_list, object_dict
+
+ def textblock_ordering(textblocks, img):
+     coords = []
+     ori_coords = {}
+     for obj in textblocks:
+         x1, y1, x2, y2 = map(float, obj)
+         ori_coords[tuple([x1, y1, x2, y2])] = [x1, y1, x2, y2]
+         coords.append([x1, y1, x2, y2])
+
+     # Sort the text blocks by x1
+     coords.sort(key=lambda x: x[0])
+     # Create vertical buckets whose horizontal width is 15% of the image width
+     buckets = []
+     bucket_size = int(0.15 * img.shape[1])
+     # Put the text blocks in the buckets
+     for coord in coords:
+         if len(buckets) == 0:
+             buckets.append([coord])
+         else:
+             for bucket in buckets:
+                 if abs(bucket[0][0] - coord[0]) < bucket_size:
+                     bucket.append(coord)
+                     break
+             else:
+                 buckets.append([coord])
+
+     # Sort each bucket by y1
+     for bucket in buckets:
+         bucket.sort(key=lambda x: x[1])
+
+     # Visualize the buckets one by one, each with a different color
+     # for bucket in buckets:
+     #     color = (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))
+     #     for coord in bucket:
+     #         img = cv2.rectangle(img, (coord[0], coord[1]), (coord[2], coord[3]), color, 5)
+     #     cv2.imshow('img', img)
+
+     # Change bucket coords back to the original coords
+     for bucket in buckets:
+         for i in range(len(bucket)):
+             bucket[i] = ori_coords[tuple(bucket[i])]
+
+     # Merge all the buckets into one list
+     buckets = [item for sublist in buckets for item in sublist]
+     return buckets
+
+ def get_ordered_data(labels, img):
+     article_list, object_dict = get_hierarchy(labels)
+     for article in article_list:
+         sorted_buckets = textblock_ordering(article['Text Block'], img)
+         article['Text Block'] = sorted_buckets
+
+     # Collect the results in a dict
+     # Data structure:
+     # {'Articles': [{'Headlines': [obj1, obj2, ...], 'Sub-headlines': [...], ...}, ...], 'Extra': [...]}
+     json_dict = {}
+     json_dict['Articles'] = article_list
+     json_dict['Extra'] = []
+     # Add the objects that were not assigned to any article
+     for key in object_dict:
+         if object_dict[key][0] == 0:
+             print("Extra: ", key)
+             json_dict['Extra'].append({class_id_to_name[object_dict[key][2]]: [object_dict[key][1]]})
+
+     return json_dict
+
+ # if __name__ == '__main__':
+ #     label_path = '/Users/deveshpant/Work/WadhwaniAI/IDSP/eNewspaperPDFs/Language_wise/Results/pred/Hindi2/labels/_Jansatta-Delhi 15-11_5.txt'
+ #     img_path = '/Users/deveshpant/Work/WadhwaniAI/IDSP/eNewspaperPDFs/Language_wise/Language_wise_imgs/Hindi/_Jansatta-Delhi 15-11_5.png'
+ #     json_dict = get_ordered_data(label_path, img_path)
+
+ #     # dump the json
+ #     with open('json_dict.json', 'w') as f:
+ #         json.dump(json_dict, f)
+
+ #     # read the json
+ #     with open('json_dict.json', 'r') as f:
+ #         json_dict = json.load(f)
+
+ #     # visualize the results
+ #     img = cv2.imread(img_path)
+
+
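The column-bucket reading order in textblock_ordering is easiest to see on toy coordinates; a small illustrative sketch follows (the numbers are made up for illustration, not taken from the commit).

    import numpy as np
    from order_text_blocks import textblock_ordering

    page = np.zeros((1400, 1000, 3), dtype=np.uint8)  # dummy page; only its width (1000 px) is used
    blocks = [
        [520, 100, 900, 400],   # right column, top
        [100, 700, 480, 1000],  # left column, bottom
        [110, 100, 480, 600],   # left column, top
    ]
    # bucket_size = 0.15 * 1000 = 150 px, so the two left-column blocks share a bucket
    # and are sorted by y1, giving the order: left-top, left-bottom, right-top.
    print(textblock_ordering(blocks, page))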
packages.txt ADDED
@@ -0,0 +1,2 @@
+ ffmpeg
+ tesseract-ocr-all
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ numpy==1.26.2
+ onnxruntime==1.15.1
+ onnxruntime_gpu==1.16.3
+ opencv_contrib_python==4.8.1.78
+ opencv_python==4.8.1.78
+ pandas==2.1.4
+ Pillow
+ pytesseract==0.3.10
+ Requests==2.31.0
+ streamlit==1.24.0
+ torch==2.0.1
+ torchvision==0.15.2
+ tqdm==4.65.0
run_ocr.py ADDED
@@ -0,0 +1,43 @@
+ import cv2
+ import numpy
+ import argparse
+ from pytesseract import *
+ from PIL import Image, ImageFont, ImageDraw
+ import numpy as np
+
+
+
+ # def preprocess_image(image):
+
+
+
+ def OCR(img, lang='hin', min_conf=0.25):
+     rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+     # preprocessed_image = preprocess_image(rgb)
+     # Run Tesseract word-level OCR on the RGB crop
+     results = pytesseract.image_to_data(rgb, output_type=Output.DICT, lang=lang)
+     out_text = ""
+     for i in range(0, len(results["text"])):
+
+         # Extract the bounding box coordinates of the
+         # text region from the current result
+         x = results["left"][i]
+         y = results["top"][i]
+         w = results["width"][i]
+         h = results["height"][i]
+
+         # Also extract the OCR text itself along
+         # with the confidence of the text localization
+         text = results["text"][i]
+         conf = int(results["conf"][i])
+
+         # Filter out weak-confidence text localizations
+         if conf > min_conf:
+             # Keep the detected word: strip any
+             # surrounding whitespace and append it,
+             # separated by a space, to the running
+             # output string for this crop
+             text = "".join(text).strip()
+             out_text += text + " "
+
+     return out_text
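A minimal usage sketch of OCR(): it assumes the Tesseract language packs pulled in by packages.txt are installed, and whether cropped/1.png actually contains Hindi text is an assumption.

    import cv2
    from run_ocr import OCR

    crop = cv2.imread("cropped/1.png")           # one of the text-block crops added in this commit
    text = OCR(crop, lang="hin", min_conf=0.25)  # space-joined words above the confidence threshold
    print(text)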
run_yolo.py ADDED
@@ -0,0 +1,248 @@
+ import cv2
+ import time
+ import requests
+ import random
+ import numpy as np
+ from PIL import Image
+ from pathlib import Path
+ from collections import OrderedDict, namedtuple
+ import onnxruntime as ort
+ import torch
+ import torchvision
+ import math
+
+ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
+     # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
+     box2 = box2.T
+
+     # Get the coordinates of bounding boxes
+     if x1y1x2y2:  # x1, y1, x2, y2 = box1
+         b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
+         b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
+     else:  # transform from xywh to xyxy
+         b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
+         b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
+         b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
+         b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
+
+     # Intersection area
+     inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
+             (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
+
+     # Union Area
+     w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+     w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+     union = w1 * h1 + w2 * h2 - inter + eps
+
+     iou = inter / union
+
+     if GIoU or DIoU or CIoU:
+         cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
+         ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
+         if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
+             c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
+             rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
+                     (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
+             if DIoU:
+                 return iou - rho2 / c2  # DIoU
+             elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
+                 v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / (h2 + eps)) - torch.atan(w1 / (h1 + eps)), 2)
+                 with torch.no_grad():
+                     alpha = v / (v - iou + (1 + eps))
+                 return iou - (rho2 / c2 + v * alpha)  # CIoU
+         else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
+             c_area = cw * ch + eps  # convex area
+             return iou - (c_area - union) / c_area  # GIoU
+     else:
+         return iou  # IoU
+
+
+ def xywh2xyxy(x):
+     # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+     y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+     y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+     y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+     return y
+
+
+ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
+                         labels=()):
+     """Runs Non-Maximum Suppression (NMS) on inference results
+
+     Returns:
+          list of detections, one (n,6) tensor per image [xyxy, conf, cls]
+     """
+
+     nc = prediction.shape[2] - 5  # number of classes
+     xc = prediction[..., 4] > conf_thres  # candidates
+
+     # Settings
+     min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
+     max_det = 300  # maximum number of detections per image
+     max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
+     time_limit = 10.0  # seconds to quit after
+     redundant = True  # require redundant detections
+     multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
+     merge = False  # use merge-NMS
+
+     t = time.time()
+     output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
+     for xi, x in enumerate(prediction):  # image index, image inference
+         # Apply constraints
+         # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
+         x = x[xc[xi]]  # confidence
+
+         # Cat apriori labels if autolabelling
+         if labels and len(labels[xi]):
+             l = labels[xi]
+             v = torch.zeros((len(l), nc + 5), device=x.device)
+             v[:, :4] = l[:, 1:5]  # box
+             v[:, 4] = 1.0  # conf
+             v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
+             x = torch.cat((x, v), 0)
+
+         # If none remain process next image
+         if not x.shape[0]:
+             continue
+
+         # Compute conf
+         if nc == 1:
+             x[:, 5:] = x[:, 4:5]  # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
+                                   # so there is no need to multiply
+         else:
+             x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
+
+         # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+         box = xywh2xyxy(x[:, :4])
+
+         # Detections matrix nx6 (xyxy, conf, cls)
+         if multi_label:
+             i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
+             x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
+         else:  # best class only
+             conf, j = x[:, 5:].max(1, keepdim=True)
+             x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
+
+         # Filter by class
+         if classes is not None:
+             x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+         # Apply finite constraint
+         # if not torch.isfinite(x).all():
+         #     x = x[torch.isfinite(x).all(1)]
+
+         # Check shape
+         n = x.shape[0]  # number of boxes
+         if not n:  # no boxes
+             continue
+         elif n > max_nms:  # excess boxes
+             x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
+
+         # Batched NMS
+         c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
+         boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
+         i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
+         if i.shape[0] > max_det:  # limit detections
+             i = i[:max_det]
+         if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
+             # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+             iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
+             weights = iou * scores[None]  # box weights
+             x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
+             if redundant:
+                 i = i[iou.sum(1) > 1]  # require redundancy
+
+         output[xi] = x[i]
+         if (time.time() - t) > time_limit:
+             print(f'WARNING: NMS time limit {time_limit}s exceeded')
+             break  # time limit exceeded
+
+     return output
+
+
+ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32):
+     # Resize and pad image while meeting stride-multiple constraints
+     shape = im.shape[:2]  # current shape [height, width]
+     if isinstance(new_shape, int):
+         new_shape = (new_shape, new_shape)
+
+     # Scale ratio (new / old)
+     r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+     if not scaleup:  # only scale down, do not scale up (for better val mAP)
+         r = min(r, 1.0)
+
+     # Compute padding
+     new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+     dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
+
+     if auto:  # minimum rectangle
+         dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
+
+     dw /= 2  # divide padding into 2 sides
+     dh /= 2
+
+     if shape[::-1] != new_unpad:  # resize
+         im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
+     top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+     left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+     im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
+     return im, r, (dw, dh)
+
+
+ def get_layout_results(img, onnx_path):
+     providers = ['CPUExecutionProvider']
+     session = ort.InferenceSession(onnx_path, providers=providers)
+     names = ['Articles', 'Advertisement', 'Headlines', 'Sub-headlines', 'Graphics', 'Images', 'Tables', 'Text Block', 'Header']
+     # colors = {name: [random.randint(0, 255) for _ in range(3)] for i, name in enumerate(names)}
+     # Instead of random colors, use specific, easily distinguishable colors for each class
+     colors = {
+         'Articles': [255, 0, 0],         # Red
+         'Advertisement': [0, 255, 0],    # Green
+         'Headlines': [0, 0, 255],        # Blue
+         'Sub-headlines': [255, 255, 0],  # Yellow
+         'Graphics': [255, 0, 255],       # Magenta
+         'Images': [128, 0, 128],         # Purple
+         'Tables': [0, 255, 255],         # Cyan
+         'Text Block': [0, 128, 128],     # Teal
+         'Header': [0, 0, 0]              # Black
+     }
+
+     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+     image = img.copy()
+     image, ratio, dwdh = letterbox(image, auto=False)
+     image = image.transpose((2, 0, 1))
+     image = np.expand_dims(image, 0)
+     image = np.ascontiguousarray(image)
+     im = image.astype(np.float32)
+     im /= 255.0
+     outname = [i.name for i in session.get_outputs()]
+     inname = [i.name for i in session.get_inputs()]
+     inp = {inname[0]: im}
+
+     # ONNX inference
+     outputs = session.run(outname, inp)[0]
+     # Convert to a torch tensor
+     outputs = torch.from_numpy(outputs)
+     det = non_max_suppression(outputs, 0.25, 0.45, classes=None, agnostic=False)[0]  # conf_thres=0.25, iou_thres=0.45
+     results = []
+     # Postprocess the output: undo the letterbox padding and scaling
+     for i, (x0, y0, x1, y1, score, cls_id) in enumerate(det):
+         box = np.array([x0, y0, x1, y1])
+         box -= np.array(dwdh * 2)
+         box /= ratio
+         box = box.round().astype(np.int32).tolist()
+         cls_id = int(cls_id)
+         score = round(float(score), 3)
+         name = names[cls_id]
+         color = colors[name]
+         results.append([box, score, cls_id, color])
+
+     return results
+
+ if __name__ == '__main__':
+     onnx_path = "/home/ubuntu/devesh/yolov7/runs/train/yolov7-custom9/weights/best.onnx"
+     img_ori = cv2.imread('/home/ubuntu/devesh/yolov7/Language_wise_imgs/Hindi/_Dainik_Navajyoti_-_04-11-2023_3.png')
+     lines = get_layout_results(img_ori, onnx_path)
+     print(lines[0])
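A minimal sketch of calling the detector with the weights shipped in this commit; the input image path is hypothetical, and each detection is a [box, score, class id, color] list as built above.

    import cv2
    from run_yolo import get_layout_results

    img = cv2.imread("page.png")                         # hypothetical newspaper page (BGR)
    detections = get_layout_results(img, "./best.onnx")  # ONNX weights added in this commit
    for (x1, y1, x2, y2), score, cls_id, color in detections:
        print(cls_id, score, (x1, y1, x2, y2))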
runtime.txt ADDED
@@ -0,0 +1 @@
+ python-3.8.7