Spaces · Runtime error
Devesh Pant committed · commit 1b870f4 · 1 parent: 589dd63
v0
Files changed:
- README.md +3 -3
- app.py +102 -0
- best.onnx +3 -0
- cropped/1.png +0 -0
- cropped/10.png +0 -0
- cropped/11.png +0 -0
- cropped/12.png +0 -0
- cropped/13.png +0 -0
- cropped/14.png +0 -0
- cropped/15.png +0 -0
- cropped/16.png +0 -0
- cropped/17.png +0 -0
- cropped/18.png +0 -0
- cropped/19.png +0 -0
- cropped/2.png +0 -0
- cropped/20.png +0 -0
- cropped/21.png +0 -0
- cropped/22.png +0 -0
- cropped/23.png +0 -0
- cropped/24.png +0 -0
- cropped/25.png +0 -0
- cropped/26.png +0 -0
- cropped/27.png +0 -0
- cropped/28.png +0 -0
- cropped/29.png +0 -0
- cropped/3.png +0 -0
- cropped/4.png +0 -0
- cropped/5.png +0 -0
- cropped/6.png +0 -0
- cropped/7.png +0 -0
- cropped/8.png +0 -0
- cropped/9.png +0 -0
- main.py +52 -0
- order_text_blocks.py +165 -0
- packages.txt +2 -0
- requirements.txt +13 -0
- run_ocr.py +43 -0
- run_yolo.py +248 -0
- runtime.txt +1 -0
README.md
CHANGED
@@ -1,7 +1,7 @@
 ---
-title:
-emoji:
-colorFrom:
+title: Newspaper Demo
+emoji: 🐨
+colorFrom: yellow
 colorTo: green
 sdk: streamlit
 sdk_version: 1.29.0
app.py
ADDED
@@ -0,0 +1,102 @@
+import cv2
+import numpy as np
+import streamlit as st
+from run_yolo import get_layout_results
+from order_text_blocks import get_ordered_data
+from run_ocr import OCR
+from main import driver
+import json
+import pandas as pd
+
+colors = {
+    'Articles': [0, 0, 0],          # Black
+    'Advertisement': [0, 255, 0],   # Green
+    'Headlines': [0, 0, 255],       # Blue
+    'Sub-headlines': [255, 255, 0], # Yellow
+    'Graphics': [255, 0, 255],      # Magenta
+    'Images': [128, 0, 128],        # Purple
+    'Tables': [0, 255, 255],        # Teal
+    'Header': [0, 0, 0],            # Black
+    'Text Block': [255, 0, 0]
+}
+
+try:
+    st.set_page_config(layout="wide", page_title="Newspaper Layout Detection and OCR Demo")
+    st.markdown("<h1 style='text-align: center; color: #333;'>Newspaper Layout Detection and OCR Demo</h1>", unsafe_allow_html=True)
+
+    # Streamlit UI for user input
+    uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+    language_name = st.text_input("Enter the language name: (hin, en, mal, tel, tam, kan)")
+    submit_button = st.button("Submit")
+
+    # Check if the user clicked the submit button
+    if submit_button:
+        # Check if image and language are provided
+        if uploaded_image is not None and language_name:
+            # Convert Streamlit file uploader output to an OpenCV image
+            image_bytes = uploaded_image.read()
+            nparr = np.frombuffer(image_bytes, np.uint8)
+            img_ori = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+            img = img_ori.copy()
+            st.markdown("<p style='text-align: center; color: red'>Image Uploaded Successfully!</p>", unsafe_allow_html=True)
+            # Run layout analysis and OCR
+            output_dict, article_ocr_dict = driver(img_ori, language_name, st)
+
+            # Create a list to store dictionaries for OCR results
+            image_data = []
+
+            # Visualizing Results
+            itr = 1
+            for art_key in article_ocr_dict:
+                art_coords = art_key.split('_')
+                art_x1, art_y1, art_x2, art_y2 = int(art_coords[0]), int(art_coords[1]), int(art_coords[2]), int(art_coords[3])
+
+                # Mark the article bounding box in black
+                img_ori = cv2.rectangle(img, (art_x1, art_y1), (art_x2, art_y2), (0, 0, 0), 4)
+                # Put the article number on the image in large font
+                img_ori = cv2.putText(img_ori, str(itr), (art_x1, art_y1), cv2.FONT_HERSHEY_SIMPLEX, 3, (0, 0, 255), 3, cv2.LINE_AA)
+                ocr_dict = article_ocr_dict[art_key]
+
+                # Initialize variables to store OCR text for each type
+                headlines_text = ''
+                subheadlines_text = ''
+                textblocks_text = ''
+
+                for obj_key in ocr_dict:
+                    # obj_key is one of Headlines, Sub-headlines, Text Block
+                    obj_list = ocr_dict[obj_key]
+                    for obj_dict in obj_list:
+                        for key in obj_dict:
+                            coords = key.split('_')
+                            x1, y1, x2, y2 = int(coords[0]), int(coords[1]), int(coords[2]), int(coords[3])
+                            # Mark the bounding box with the color corresponding to the object type
+                            img = cv2.rectangle(img, (x1, y1), (x2, y2), colors[obj_key], 2)
+
+                            # Add the OCR text to the corresponding variable
+                            if obj_key == 'Headlines':
+                                headlines_text += obj_dict[key] + '\n'
+                            elif obj_key == 'Sub-headlines':
+                                subheadlines_text += obj_dict[key] + '\n'
+                            elif obj_key == 'Text Block':
+                                textblocks_text += obj_dict[key] + '\n'
+
+                # Add a dictionary to the list for the current article
+                image_data.append({'Article': itr, 'Headlines': headlines_text, 'Subheadlines': subheadlines_text, 'Textblocks': textblocks_text})
+
+                itr += 1
+
+            # Create a DataFrame from the list of dictionaries
+            image_df = pd.DataFrame(image_data)
+
+            # Use Streamlit columns to display the image and results side by side
+            col1, col2 = st.columns(2)
+
+            # Display the image with marked bounding boxes in the left column
+            col1.image(img_ori, use_column_width=True, channels="BGR", caption="Image with Marked Bounding Boxes")
+
+            # Display the DataFrame of OCR results for the whole image in the right column
+            col2.table(image_df.set_index('Article').style.set_table_styles([{'selector': 'th', 'props': [('text-align', 'center')]}]))
+
+except Exception as e:
+    st.exception(e)
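For reference, the display loop in app.py relies on the nested structure that driver() in main.py returns: each article key is an "x1_y1_x2_y2" string mapping class names to lists of single-entry {box_key: text} dicts. A minimal sketch of that shape, with made-up coordinates and text:

# Hypothetical article_ocr_dict, matching what driver() builds in main.py
article_ocr_dict = {
    "40_60_980_720": {                                    # article box as "x1_y1_x2_y2"
        "Headlines": [{"50_70_900_140": "Sample headline"}],
        "Text Block": [
            {"50_160_480_700": "First column text ..."},
            {"500_160_960_700": "Second column text ..."},
        ],
    }
}

for art_key in article_ocr_dict:                          # same iteration app.py performs
    art_x1, art_y1, art_x2, art_y2 = map(int, art_key.split("_"))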
best.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64d3fc9d769f8b215e17a4cdf067e47063c1a1b98668e97c6440e5e020c1988f
+size 146238371
cropped/1.png
ADDED
cropped/10.png
ADDED
cropped/11.png
ADDED
cropped/12.png
ADDED
cropped/13.png
ADDED
cropped/14.png
ADDED
cropped/15.png
ADDED
cropped/16.png
ADDED
cropped/17.png
ADDED
cropped/18.png
ADDED
cropped/19.png
ADDED
cropped/2.png
ADDED
cropped/20.png
ADDED
cropped/21.png
ADDED
cropped/22.png
ADDED
cropped/23.png
ADDED
cropped/24.png
ADDED
cropped/25.png
ADDED
cropped/26.png
ADDED
cropped/27.png
ADDED
cropped/28.png
ADDED
cropped/29.png
ADDED
cropped/3.png
ADDED
cropped/4.png
ADDED
cropped/5.png
ADDED
cropped/6.png
ADDED
cropped/7.png
ADDED
cropped/8.png
ADDED
cropped/9.png
ADDED
main.py
ADDED
@@ -0,0 +1,52 @@
+import cv2
+from run_yolo import get_layout_results
+from order_text_blocks import get_ordered_data
+from run_ocr import OCR
+from tqdm import tqdm
+import time
+
+
+def driver(img, language_name, st):
+    onnx_path = "./best.onnx"
+    img_ori = img.copy()
+    labels = get_layout_results(img_ori, onnx_path)
+    output_dict = get_ordered_data(labels, img)
+    st.markdown("<p style='text-align: center; color: red'>Layout Analysis Completed!</p>", unsafe_allow_html=True)
+    article_wise_ocr = {}
+    h, w = img.shape[:2]
+
+    with st.spinner('Performing OCR...'):
+        # Add your spinner message with custom CSS
+        for itr, article in tqdm(enumerate(output_dict['Articles'])):
+            ocr_dict = {}
+            article_key = ""
+            for key in article:
+
+                if article[key] == []:
+                    continue
+
+                if key == 'Articles':
+                    x1, y1, x2, y2 = int(article[key][0][0]), int(article[key][0][1]), int(article[key][0][2]), int(article[key][0][3])
+                    article_key = '_'.join([str(x1), str(y1), str(x2), str(y2)])
+
+                if key == 'Headlines' or key == 'Sub-headlines' or key == 'Text Block':
+                    for coord in article[key]:
+                        x1, y1, x2, y2 = int(coord[0]), int(coord[1]), int(coord[2]), int(coord[3])
+                        # check if the coordinates are valid, w.r.t image dimensions, if not then skip
+                        if x1 < 0 or x2 < 0 or y1 < 0 or y2 < 0 or x1 > w or x2 > w or y1 > h or y2 > h:
+                            continue
+
+                        crop = img[int(coord[1]):int(coord[3]), int(coord[0]):int(coord[2])]
+                        output_text = OCR(crop, lang=language_name)
+
+                        box_key = "_".join([str(int(coord[0])), str(int(coord[1])), str(int(coord[2])), str(int(coord[3]))])
+                        if key not in ocr_dict:
+                            ocr_dict[key] = [{box_key: output_text}]
+                        else:
+                            ocr_dict[key].append({box_key: output_text})
+
+            article_wise_ocr[article_key] = ocr_dict
+
+    st.markdown("<p style='text-align: center; color: red'>OCR Completed!</p>", unsafe_allow_html=True)
+    return output_dict, article_wise_ocr
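driver() only uses the passed st object for markdown() and spinner(), so it can be exercised outside the Streamlit UI with a tiny stand-in. A hedged sketch; the stub class and the image path are illustrative and not part of the repository:

import cv2
from contextlib import contextmanager
from main import driver

class StStub:
    # Minimal stand-in for streamlit: driver() only calls markdown() and spinner()
    def markdown(self, *args, **kwargs):
        pass

    @contextmanager
    def spinner(self, text):
        yield

img = cv2.imread("page.png")                          # hypothetical newspaper page
output_dict, article_ocr = driver(img, "hin", StStub())
print(len(article_ocr), "articles with OCR text")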
order_text_blocks.py
ADDED
@@ -0,0 +1,165 @@
+import os
+import cv2
+import json
+import numpy as np
+
+class_id_to_name = {
+    0: "Articles",
+    1: "Advertisement",
+    2: "Headlines",
+    3: "Sub-headlines",
+    4: "Graphics",
+    5: "Images",
+    6: "Tables",
+    7: "Text Block",
+    8: "Header"
+}
+
+
+def findIntersection(box1, box2):
+    """
+    args:
+        box1: [x1, y1, x2, y2]
+        box2: [x1, y1, x2, y2]
+
+    returns:
+        iou: float (intersection area divided by the area of box1)
+    """
+    x1, y1, w1, h1 = box1[0], box1[1], box1[2] - box1[0], box1[3] - box1[1]
+    x2, y2, w2, h2 = box2[0], box2[1], box2[2] - box2[0], box2[3] - box2[1]
+
+    xA = max(x1, x2)
+    yA = max(y1, y2)
+
+    xB = min(x1 + w1, x2 + w2)
+    yB = min(y1 + h1, y2 + h2)
+
+    # Calculate the intersection area
+    interArea = max(0, xB - xA) * max(0, yB - yA)
+
+    # divide by the box1 area instead of the union
+    if w1 * h1 == 0:
+        return 0
+    else:
+        iou = interArea / (w1 * h1)
+
+    return iou
+
+def get_hierarchy(labels):
+    class_list = [[] for _ in range(len(class_id_to_name))]
+    object_dict = {}
+    article_list = []
+
+    for itr, label in enumerate(labels):
+        x1, y1, x2, y2 = label[0][0], label[0][1], label[0][2], label[0][3]
+        conf = label[1]
+        class_id = label[2]
+        class_list[int(class_id)].append([x1, y1, x2, y2])
+        obj_key = int(class_id) * 1000000 + len(class_list[int(class_id)]) - 1
+        object_dict[obj_key] = [0, [x1, y1, x2, y2], int(class_id)]
+
+    # For each article, find all the objects that belong to it
+    cou = 0
+    for article in class_list[0]:
+        # "Header" is included so article_dict[key] below covers every class id
+        article_dict = {'Articles': [], 'Headlines': [], 'Sub-headlines': [], 'Graphics': [], 'Images': [], 'Text Block': [], "Advertisement": [], "Tables": [], "Header": []}
+        for class_id in range(9):
+            IoUThresh = 0.70; max_article = None
+            for obj in class_list[class_id]:
+                IoU = findIntersection(obj, article)
+                if IoU > IoUThresh:
+                    key = class_id_to_name[class_id]
+
+                    article_dict[key].append(obj)
+                    obj_key = class_id * 1000000 + class_list[class_id].index(obj)
+                    if obj_key in object_dict:
+                        cou += 1
+                        object_dict[obj_key][0] = 1
+
+        article_list.append(article_dict)
+
+    return article_list, object_dict
+
+def textblock_ordering(textblocks, img):
+    coords = []
+    ori_coords = {}
+    for obj in textblocks:
+        x1, y1, x2, y2 = map(float, obj)
+        ori_coords[tuple([x1, y1, x2, y2])] = [x1, y1, x2, y2]
+        coords.append([x1, y1, x2, y2])
+
+    # sort the textblocks by x1
+    coords.sort(key=lambda x: x[0])
+    # create vertical buckets with a horizontal size of 15% of the image width
+    buckets = []
+    bucket_size = int(0.15 * img.shape[1])
+    # put the textblocks in the buckets
+    for coord in coords:
+        if len(buckets) == 0:
+            buckets.append([coord])
+        else:
+            for bucket in buckets:
+                if abs(bucket[0][0] - coord[0]) < bucket_size:
+                    bucket.append(coord)
+                    break
+            else:
+                buckets.append([coord])
+
+    # sort each bucket by y1
+    for bucket in buckets:
+        bucket.sort(key=lambda x: x[1])
+
+    # visualize the buckets one by one, each with a different color
+    # for bucket in buckets:
+    #     color = (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))
+    #     for coord in bucket:
+    #         img = cv2.rectangle(img, (coord[0], coord[1]), (coord[2], coord[3]), color, 5)
+    # cv2.imshow('img', img)
+
+    # change bucket coords to original coords
+    for bucket in buckets:
+        for i in range(len(bucket)):
+            bucket[i] = ori_coords[tuple(bucket[i])]
+
+    # merge all the buckets into one list
+    buckets = [item for sublist in buckets for item in sublist]
+    return buckets
+
+def get_ordered_data(labels, img):
+    article_list, object_dict = get_hierarchy(labels)
+    for article in article_list:
+        sorted_buckets = textblock_ordering(article['Text Block'], img)
+        article['Text Block'] = sorted_buckets
+
+    # Dump the results in a json file
+    # Data structure:
+    # {Article1: {Headlines: [obj1, obj2, ...], Sub-headlines: [obj1, obj2, ...], ...}, Article2: {...}, ...}
+    json_dict = {}
+    json_dict['Articles'] = article_list
+    json_dict['Extra'] = []
+    # Add remaining objects to the json
+    for key in object_dict:
+        if object_dict[key][0] == 0:
+            print("Extra: ", key)
+            json_dict['Extra'].append({class_id_to_name[object_dict[key][2]]: [object_dict[key][1]]})
+
+    return json_dict
+
+# if __name__ == '__main__':
+#     label_path = '/Users/deveshpant/Work/WadhwaniAI/IDSP/eNewspaperPDFs/Language_wise/Results/pred/Hindi2/labels/_Jansatta-Delhi 15-11_5.txt'
+#     img_path = '/Users/deveshpant/Work/WadhwaniAI/IDSP/eNewspaperPDFs/Language_wise/Language_wise_imgs/Hindi/_Jansatta-Delhi 15-11_5.png'
+#     json_dict = get_ordered_data(label_path, img_path)
+
+#     # dump the json
+#     with open('json_dict.json', 'w') as f:
+#         json.dump(json_dict, f)
+
+#     # read the json
+#     with open('json_dict.json', 'r') as f:
+#         json_dict = json.load(f)
+
+#     # visualize the results
+#     img = cv2.imread(img_path)  # kept commented out: img_path is undefined at module scope
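Note that findIntersection is intersection over the area of box1 rather than a symmetric IoU: it measures how much of box1 lies inside box2, which is what the 0.70 grouping threshold in get_hierarchy relies on. A small illustration with made-up boxes:

from order_text_blocks import findIntersection

block   = [100, 100, 200, 200]   # [x1, y1, x2, y2], fully inside the article
article = [50, 50, 400, 400]

print(findIntersection(block, article))   # 1.0  -> the block is assigned to this article
print(findIntersection(article, block))   # ~0.08 -> fraction of the article covered by the block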
packages.txt
ADDED
@@ -0,0 +1,2 @@
+ffmpeg
+tesseract-ocr-all
requirements.txt
ADDED
@@ -0,0 +1,13 @@
+numpy==1.26.2
+onnxruntime==1.15.1
+onnxruntime_gpu==1.16.3
+opencv_contrib_python==4.8.1.78
+opencv_python==4.8.1.78
+pandas==2.1.4
+Pillow
+pytesseract==0.3.10
+Requests==2.31.0
+streamlit==1.24.0
+torch==2.0.1
+torchvision==0.15.2
+tqdm==4.65.0
run_ocr.py
ADDED
@@ -0,0 +1,43 @@
+import cv2
+import numpy
+import argparse
+from pytesseract import *
+from PIL import Image, ImageFont, ImageDraw
+import numpy as np
+
+
+# def preprocess_image(image):
+
+
+def OCR(img, lang='hin', min_conf=0.25):
+    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    # preprocessed_image = preprocess_image(rgb)
+    # write the preprocessed image to disk as a temporary file so we can
+    results = pytesseract.image_to_data(rgb, output_type=Output.DICT, lang=lang)
+    out_text = ""
+    for i in range(0, len(results["text"])):
+
+        # We can then extract the bounding box coordinates
+        # of the text region from the current result
+        x = results["left"][i]
+        y = results["top"][i]
+        w = results["width"][i]
+        h = results["height"][i]
+
+        # We will also extract the OCR text itself along
+        # with the confidence of the text localization
+        text = results["text"][i]
+        conf = int(results["conf"][i])
+
+        # filter out weak confidence text localizations
+        if conf > min_conf:
+            # We then strip out non-ASCII text so we can
+            # draw the text on the image. We will be using
+            # OpenCV, then draw a bounding box around the
+            # text along with the text itself
+            text = "".join(text).strip()
+            out_text += text + " "
+
+    return out_text
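OCR() takes a BGR crop (as produced by cv2.imread or the slicing in main.py) and a Tesseract language code, and returns the recognised words joined by spaces; the language data comes from the tesseract-ocr-all package listed in packages.txt. A hedged usage sketch using one of the bundled sample crops:

import cv2
from run_ocr import OCR

crop = cv2.imread("cropped/1.png")   # any BGR crop of a text region
text = OCR(crop, lang="hin")         # 'hin' is illustrative; pass the code matching the page language
print(text)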
run_yolo.py
ADDED
@@ -0,0 +1,248 @@
+import cv2
+import time
+import requests
+import random
+import numpy as np
+from PIL import Image
+from pathlib import Path
+from collections import OrderedDict, namedtuple
+import onnxruntime as ort
+import torch
+import torchvision
+import math
+
+def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
+    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
+    box2 = box2.T
+
+    # Get the coordinates of bounding boxes
+    if x1y1x2y2:  # x1, y1, x2, y2 = box1
+        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
+        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
+    else:  # transform from xywh to xyxy
+        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
+        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
+        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
+        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
+
+    # Intersection area
+    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
+            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
+
+    # Union Area
+    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+    union = w1 * h1 + w2 * h2 - inter + eps
+
+    iou = inter / union
+
+    if GIoU or DIoU or CIoU:
+        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
+        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
+        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
+            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
+            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
+                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
+            if DIoU:
+                return iou - rho2 / c2  # DIoU
+            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
+                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / (h2 + eps)) - torch.atan(w1 / (h1 + eps)), 2)
+                with torch.no_grad():
+                    alpha = v / (v - iou + (1 + eps))
+                return iou - (rho2 / c2 + v * alpha)  # CIoU
+        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
+            c_area = cw * ch + eps  # convex area
+            return iou - (c_area - union) / c_area  # GIoU
+    else:
+        return iou  # IoU
+
+
+def xywh2xyxy(x):
+    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+    return y
+
+
+def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
+                        labels=()):
+    """Runs Non-Maximum Suppression (NMS) on inference results
+
+    Returns:
+         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
+    """
+
+    nc = prediction.shape[2] - 5  # number of classes
+    xc = prediction[..., 4] > conf_thres  # candidates
+
+    # Settings
+    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
+    max_det = 300  # maximum number of detections per image
+    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
+    time_limit = 10.0  # seconds to quit after
+    redundant = True  # require redundant detections
+    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
+    merge = False  # use merge-NMS
+
+    t = time.time()
+    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
+    for xi, x in enumerate(prediction):  # image index, image inference
+        # Apply constraints
+        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
+        x = x[xc[xi]]  # confidence
+
+        # Cat apriori labels if autolabelling
+        if labels and len(labels[xi]):
+            l = labels[xi]
+            v = torch.zeros((len(l), nc + 5), device=x.device)
+            v[:, :4] = l[:, 1:5]  # box
+            v[:, 4] = 1.0  # conf
+            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
+            x = torch.cat((x, v), 0)
+
+        # If none remain process next image
+        if not x.shape[0]:
+            continue
+
+        # Compute conf
+        if nc == 1:
+            x[:, 5:] = x[:, 4:5]  # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
+                                  # so there is no need to multiplicate.
+        else:
+            x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
+
+        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+        box = xywh2xyxy(x[:, :4])
+
+        # Detections matrix nx6 (xyxy, conf, cls)
+        if multi_label:
+            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
+            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
+        else:  # best class only
+            conf, j = x[:, 5:].max(1, keepdim=True)
+            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
+
+        # Filter by class
+        if classes is not None:
+            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+        # Apply finite constraint
+        # if not torch.isfinite(x).all():
+        #     x = x[torch.isfinite(x).all(1)]
+
+        # Check shape
+        n = x.shape[0]  # number of boxes
+        if not n:  # no boxes
+            continue
+        elif n > max_nms:  # excess boxes
+            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
+
+        # Batched NMS
+        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
+        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
+        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
+        if i.shape[0] > max_det:  # limit detections
+            i = i[:max_det]
+        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
+            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
+            weights = iou * scores[None]  # box weights
+            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
+            if redundant:
+                i = i[iou.sum(1) > 1]  # require redundancy
+
+        output[xi] = x[i]
+        if (time.time() - t) > time_limit:
+            print(f'WARNING: NMS time limit {time_limit}s exceeded')
+            break  # time limit exceeded
+
+    return output
+
+
+def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32):
+    # Resize and pad image while meeting stride-multiple constraints
+    shape = im.shape[:2]  # current shape [height, width]
+    if isinstance(new_shape, int):
+        new_shape = (new_shape, new_shape)
+
+    # Scale ratio (new / old)
+    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+    if not scaleup:  # only scale down, do not scale up (for better val mAP)
+        r = min(r, 1.0)
+
+    # Compute padding
+    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
+
+    if auto:  # minimum rectangle
+        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
+
+    dw /= 2  # divide padding into 2 sides
+    dh /= 2
+
+    if shape[::-1] != new_unpad:  # resize
+        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
+    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
+    return im, r, (dw, dh)
+
+
+def get_layout_results(img, onnx_path):
+    providers = ['CPUExecutionProvider']
+    session = ort.InferenceSession(onnx_path, providers=providers)
+    names = ['Articles', 'Advertisement', 'Headlines', 'Sub-headlines', 'Graphics', 'Images', 'Tables', 'Text Block', 'Header']
+    # colors = {name: [random.randint(0, 255) for _ in range(3)] for i, name in enumerate(names)}
+    # instead of random colors, use specific easily distinguishable colors for each class
+    colors = {
+        'Articles': [255, 0, 0],        # Red
+        'Advertisement': [0, 255, 0],   # Green
+        'Headlines': [0, 0, 255],       # Blue
+        'Sub-headlines': [255, 255, 0], # Yellow
+        'Graphics': [255, 0, 255],      # Magenta
+        'Images': [128, 0, 128],        # Purple
+        'Tables': [0, 255, 255],        # Teal
+        'Text Block': [0, 128, 128],    # Navy
+        'Header': [0, 0, 0]             # Black
+    }
+
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    image = img.copy()
+    image, ratio, dwdh = letterbox(image, auto=False)
+    image = image.transpose((2, 0, 1))
+    image = np.expand_dims(image, 0)
+    image = np.ascontiguousarray(image)
+    im = image.astype(np.float32)
+    im /= 255.0
+    outname = [i.name for i in session.get_outputs()]
+    inname = [i.name for i in session.get_inputs()]
+    inp = {inname[0]: im}
+
+    # ONNX inference
+    outputs = session.run(outname, inp)[0]
+    # convert to torch tensor
+    outputs = torch.from_numpy(outputs)
+    det = non_max_suppression(outputs, 0.25, 0.45, classes=None, agnostic=False)[0]  # conf_thres=0.25, iou_thres=0.45
+    results = []
+    # postprocess the output
+    for i, (x0, y0, x1, y1, score, cls_id) in enumerate(det):
+        box = np.array([x0, y0, x1, y1])
+        box -= np.array(dwdh * 2)
+        box /= ratio
+        box = box.round().astype(np.int32).tolist()
+        cls_id = int(cls_id)
+        score = round(float(score), 3)
+        name = names[cls_id]
+        color = colors[name]
+        results.append([box, score, cls_id, color])
+
+    return results
+
+if __name__ == '__main__':
+    onnx_path = "/home/ubuntu/devesh/yolov7/runs/train/yolov7-custom9/weights/best.onnx"
+    img_ori = cv2.imread('/home/ubuntu/devesh/yolov7/Language_wise_imgs/Hindi/_Dainik_Navajyoti_-_04-11-2023_3.png')
+    lines = get_layout_results(img_ori, onnx_path)
+    print(lines[0])
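Each entry of the returned results list is [box, score, cls_id, color], which is exactly the label layout get_hierarchy in order_text_blocks.py reads back (label[0] box, label[1] confidence, label[2] class id). A hedged sketch, assuming a local page image; the paths are illustrative, the Space itself ships best.onnx at the repository root:

import cv2
from run_yolo import get_layout_results

results = get_layout_results(cv2.imread("page.png"), "./best.onnx")
box, score, cls_id, color = results[0]
print(box, score, cls_id)   # e.g. [x1, y1, x2, y2], 0.87, 2  (2 -> 'Headlines')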
runtime.txt
ADDED
@@ -0,0 +1 @@
+python-3.8.7