import io

from pathlib import Path
from typing import Tuple, List, Sequence, Optional, Union

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
from bs4 import BeautifulSoup as bs
from matplotlib.patches import Patch
from numpy import uint8
from PIL import Image
from torch import nn, Tensor
from torchvision import transforms
from transformers import AutoModelForObjectDetection

from unitable import UnitablePredictor
from doctrfiles import DoctrWordDetector, DoctrTextRecognizer
from utils import crop_an_Image, cropImageExtraMargin
from utils import denoisingAndSharpening

ImageType = npt.NDArray[uint8]

# Based on this notebook:
# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Inference_with_Table_Transformer_(TATR)_for_parsing_tables.ipynb


class MaxResize(object):
    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
        return resized_image
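
# Example (hedged): MaxResize(800) maps a 1600x1200 PIL image to 800x600,
# since the longer side (1600) is scaled down to max_size.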

html_table_template = (
    lambda table: f"""<html>
    <head> <meta charset="UTF-8">
    <style>
    table, th, td {{
        border: 1px solid black;
        font-size: 10px;
    }}
    </style> </head>
    <body>
    <table frame="hsides" rules="groups" width="100%">
    {table}
    </table> </body> </html>"""
)
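
# Example (hedged): html_table_template("<tr><td>A1</td></tr>") returns a
# complete HTML document string with that row wrapped in a bordered <table>.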


class DetectionAndOcrTable1:
    def __init__(self, englishFlag=True):
        self.unitablePredictor = UnitablePredictor()
        self.wordDetector = DoctrWordDetector(architecture="db_resnet50",
                                              path_weights="doctrfiles/models/db_resnet50-79bd7d70.pt",
                                              path_config_json="doctrfiles/models/db_resnet50_config.json")
        if englishFlag:
            self.textRecognizer = DoctrTextRecognizer(architecture="master",
                                                      path_weights="./doctrfiles/models/master-fde31e4a.pt",
                                                      path_config_json="./doctrfiles/models/master.json")
        else:
            self.textRecognizer = DoctrTextRecognizer(architecture="parseq",
                                                      path_weights="./doctrfiles/models/doctr-multilingual-parseq.bin",
                                                      path_config_json="./doctrfiles/models/multilingual-parseq-config.json")

    # @staticmethod is required here: this helper is invoked both on the class
    # and via self.build_table_from_html_and_cell(...) in predict() below.
    @staticmethod
    def build_table_from_html_and_cell(
        structure: List[str], content: Optional[List[str]] = None
    ) -> List[str]:
        """Build table from html and cell token list."""
        assert structure is not None
        html_code = list()
        # Deal with an empty table
        if content is None:
            content = ["placeholder"] * len(structure)
        for tag in structure:
            if tag in ("<td>[]</td>", ">[]</td>"):
                if len(content) == 0:
                    continue
                cell = content.pop(0)
                html_code.append(tag.replace("[]", cell))
            else:
                html_code.append(tag)
        return html_code
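
    # Example (hedged): with structure ["<tr>", "<td>[]</td>", "</tr>"] and
    # content ["A1"], the result joins to "<tr><td>A1</td></tr>".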

    @staticmethod
    def save_detection(detected_lines_images: List[ImageType], prefix='./res/test1/res_'):
        for i, img in enumerate(detected_lines_images):
            pilimg = Image.fromarray(img)
            pilimg.save(prefix + str(i) + '.png')

    # For output bounding box post-processing
    @staticmethod
    def box_cxcywh_to_xyxy(x):
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=1)

    @staticmethod
    def rescale_bboxes(out_bbox, size):
        img_w, img_h = size
        b = DetectionAndOcrTable1.box_cxcywh_to_xyxy(out_bbox)
        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
        return b
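
    # Example (hedged): a normalized (cx, cy, w, h) box of (0.5, 0.5, 0.5, 0.5)
    # on an 800x600 image rescales to pixel corners (200.0, 150.0, 600.0, 450.0).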

    @staticmethod
    def outputs_to_objects(outputs, img_size, id2label):
        m = outputs.logits.softmax(-1).max(-1)
        pred_labels = list(m.indices.detach().cpu().numpy())[0]
        pred_scores = list(m.values.detach().cpu().numpy())[0]
        pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
        pred_bboxes = [elem.tolist() for elem in DetectionAndOcrTable1.rescale_bboxes(pred_bboxes, img_size)]
        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if class_label != 'no object':
                objects.append({'label': class_label, 'score': float(score),
                                'bbox': [float(elem) for elem in bbox]})
        return objects

    @staticmethod
    def fig2img(fig):
        """Convert a Matplotlib figure to a PIL Image and return it."""
        buf = io.BytesIO()
        fig.savefig(buf)
        buf.seek(0)
        img = Image.open(buf)
        return img
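
    # objects_to_crops below calls iob(), which is neither defined nor imported
    # in this file. A minimal sketch is supplied here, assuming the usual TATR
    # "intersection over box" metric: the intersection area divided by the area
    # of the first box.
    @staticmethod
    def iob(bbox1, bbox2):
        """Intersection area of bbox1 and bbox2, over the area of bbox1."""
        x1 = max(bbox1[0], bbox2[0])
        y1 = max(bbox1[1], bbox2[1])
        x2 = min(bbox1[2], bbox2[2])
        y2 = min(bbox1[3], bbox2[3])
        intersection = max(0.0, x2 - x1) * max(0.0, y2 - y1)
        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        if area1 <= 0:
            return 0.0
        return intersection / area1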

    # To crop the table out of the image, the TATR authors employ some padding
    # to make sure the borders of the table are included.
    @staticmethod
    def objects_to_crops(img, tokens, objects, class_thresholds, padding=10):
        """
        Process the bounding boxes produced by the table detection model into
        cropped table images and cropped tokens.
        """
        table_crops = []
        for obj in objects:
            # A bit redundant here, since the objects were already filtered by
            # score before this call
            if obj['score'] < class_thresholds[obj['label']]:
                continue
            cropped_table = {}
            bbox = obj['bbox']
            bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding]
            cropped_img = img.crop(bbox)
            # Add a 20-pixel white border on every side of the cropped image
            padded_width = cropped_img.width + 40
            padded_height = cropped_img.height + 40
            new_img_np = np.full((padded_height, padded_width, 3), fill_value=255, dtype=np.uint8)
            y_offset = (padded_height - cropped_img.height) // 2
            x_offset = (padded_width - cropped_img.width) // 2
            new_img_np[y_offset:y_offset + cropped_img.height, x_offset:x_offset + cropped_img.width] = np.array(cropped_img)
            padded_img = Image.fromarray(new_img_np, 'RGB')
            table_tokens = [token for token in tokens if DetectionAndOcrTable1.iob(token['bbox'], bbox) >= 0.5]
            for token in table_tokens:
                token['bbox'] = [token['bbox'][0] - bbox[0] + padding,
                                 token['bbox'][1] - bbox[1] + padding,
                                 token['bbox'][2] - bbox[0] + padding,
                                 token['bbox'][3] - bbox[1] + padding]
            # If the table is predicted to be rotated, rotate the cropped image and tokens/words:
            if obj['label'] == 'table rotated':
                padded_img = padded_img.rotate(270, expand=True)
                for token in table_tokens:
                    bbox = token['bbox']
                    bbox = [padded_img.size[0] - bbox[3] - 1,
                            bbox[0],
                            padded_img.size[0] - bbox[1] - 1,
                            bbox[2]]
                    token['bbox'] = bbox
            cropped_table['image'] = padded_img
            cropped_table['tokens'] = table_tokens
            table_crops.append(cropped_table)
        return table_crops

    @staticmethod
    def visualize_detected_tables(img, det_tables, out_path=None):
        plt.imshow(img, interpolation="lanczos")
        fig = plt.gcf()
        fig.set_size_inches(20, 20)
        ax = plt.gca()
        for det_table in det_tables:
            bbox = det_table['bbox']
            if det_table['label'] == 'table':
                facecolor = (1, 0, 0.45)
                edgecolor = (1, 0, 0.45)
                alpha = 0.3
                linewidth = 2
                hatch = '//////'
            elif det_table['label'] == 'table rotated':
                facecolor = (0.95, 0.6, 0.1)
                edgecolor = (0.95, 0.6, 0.1)
                alpha = 0.3
                linewidth = 2
                hatch = '//////'
            else:
                continue
            rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=linewidth,
                                     edgecolor='none', facecolor=facecolor, alpha=0.1)
            ax.add_patch(rect)
            rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=linewidth,
                                     edgecolor=edgecolor, facecolor='none', linestyle='-', alpha=alpha)
            ax.add_patch(rect)
            rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=0,
                                     edgecolor=edgecolor, facecolor='none', linestyle='-', hatch=hatch, alpha=0.2)
            ax.add_patch(rect)
        plt.xticks([], [])
        plt.yticks([], [])
        legend_elements = [Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45),
                                 label='Table', hatch='//////', alpha=0.3),
                           Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1),
                                 label='Table (rotated)', hatch='//////', alpha=0.3)]
        plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0,
                   fontsize=10, ncol=2)
        plt.gcf().set_size_inches(10, 10)
        plt.axis('off')
        if out_path is not None:
            plt.savefig(out_path, bbox_inches='tight', dpi=150)
        return fig
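
    # Usage sketch (hedged): given the `objects` list returned by
    # outputs_to_objects, calling
    #   DetectionAndOcrTable1.visualize_detected_tables(image, objects, "tables.png")
    # overlays the detected table boxes on the page image and saves the figure.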

    def predict(self, image: Image.Image, debugfolder_filename_page_name, denoise=False):
        """
        0. Locate the table using table detection
        1. Recognize table structure and cell contents with Unitable
        """
        print("Running table transformer + Unitable Hybrid Model")
        # Step 0: Locate the table using table detection
        # First we load a Table Transformer pre-trained for table detection. We use the
        # "no_timm" revision here to load the checkpoint with a Transformers-native backbone.
        model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        # Prepare the image for the model
        detection_transform = transforms.Compose([
            MaxResize(800),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        pixel_values = detection_transform(image).unsqueeze(0)
        pixel_values = pixel_values.to(device)
        # Next, we forward the pixel values through the model.
        # The model outputs logits of shape (batch_size, num_queries, num_labels + 1).
        # The +1 is for the "no object" class.
        with torch.no_grad():
            outputs = model(pixel_values)
        # Update id2label to include "no object"
        id2label = model.config.id2label
        id2label[len(model.config.id2label)] = "no object"
        # Example output:
        # [{'label': 'table', 'score': 0.9999570846557617, 'bbox': [110.24547576904297, 73.31171417236328, 1024.609130859375, 308.7159423828125]}]
        objects = DetectionAndOcrTable1.outputs_to_objects(outputs, image.size, id2label)
        # Only keep objects with a score greater than 0.95
        objects = [obj for obj in objects if obj['score'] > 0.95]
        print("Detected objects from the table transformer:")
        print(objects)
        # Initialize here so an empty detection result still returns a list
        table_codes = []
        if objects:
            # Next, we crop the table out of the image. For that, the TATR authors
            # employ some padding to make sure the borders of the table are included.
            tokens = []
            detection_class_thresholds = {
                "table": 0.95,  # Redundant with the 0.95 filter above, but kept to stay close to the original notebook code
                "table rotated": 0.95,
                "no object": 10
            }
            crop_padding = 10
            tables_crops = DetectionAndOcrTable1.objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=crop_padding)
            cropped_tables = []
            for i in range(len(tables_crops)):
                cropped_table = tables_crops[i]['image'].convert("RGB")
                cropped_table.save(debugfolder_filename_page_name + "cropped_table_" + str(i) + ".png")
                cropped_tables.append(cropped_table)

            # Step 1: Unitable
            # This takes PIL Images as input
            if denoise:
                cropped_tables = denoisingAndSharpening(cropped_tables)
            pred_htmls, pred_bboxs = self.unitablePredictor.predict(cropped_tables, debugfolder_filename_page_name)
            for k in range(len(cropped_tables)):
                pred_html = pred_htmls[k]
                pred_bbox = pred_bboxs[k]
                # Some tables have a lot of words in their header, and the doctr word
                # detector doesn't work well when the images aren't square, so header
                # cells are cropped with extra margin below.
                table_header_cells = 0
                header_exists = False
                for cell in pred_html:
                    if cell == '>[]</td>' or cell == '<td>[]</td>':
                        table_header_cells += 1
                    if cell == '</thead>':
                        header_exists = True
                        break
                if not header_exists:
                    table_header_cells = 0
                pred_cell = []
                cell_imgs_to_viz = []
                cell_img_num = 0
                # Find what one line should be, using the shortest header cell as reference
                one_line_height = 100000
                for i in range(table_header_cells):
                    box = pred_bbox[i]
                    xmin, ymin, xmax, ymax = box
                    current_box_height = abs(ymax - ymin)
                    if current_box_height < one_line_height:
                        one_line_height = current_box_height
                for box in pred_bbox:
                    xmin, ymin, xmax, ymax = box
                    fourbytwo = np.array([
                        [xmin, ymin],
                        [xmax, ymin],
                        [xmax, ymax],
                        [xmin, ymax]
                    ], dtype=np.float32)
                    current_box_height = abs(ymax - ymin)
                    # Reset per cell so a degenerate crop can't reuse the previous cell's images
                    input_to_recog = []
                    # Header cells with more than one line get extra margin and word detection
                    if table_header_cells > 0 and current_box_height > one_line_height + 5:
                        cell_img = cropImageExtraMargin([fourbytwo], cropped_tables[k], margin=1.4)[0]
                        table_header_cells -= 1
                        # List of 4 x 2 boxes
                        detection_results = self.wordDetector.predict(cell_img, sort_vertical=True)
                        if detection_results == []:
                            input_to_recog.append(cell_img)
                        else:
                            for wordbox in detection_results:
                                cropped_image = crop_an_Image(wordbox.box, cell_img)
                                if cropped_image.shape[0] > 0 and cropped_image.shape[1] > 0:
                                    input_to_recog.append(cropped_image)
                                else:
                                    print("Empty image")
                    else:
                        cell_img = crop_an_Image(fourbytwo, cropped_tables[k])
                        if table_header_cells > 0:
                            table_header_cells -= 1
                        if cell_img.shape[0] > 0 and cell_img.shape[1] > 0:
                            input_to_recog = [cell_img]
                        cell_imgs_to_viz.append(cell_img)
                    if input_to_recog != []:
                        words = self.textRecognizer.predict_for_tables(input_to_recog)
                        cell_output = " ".join(words)
                        pred_cell.append(cell_output)
                    else:
                        # Don't lose empty cells
                        pred_cell.append("")
                print(pred_cell)
                # Step 2: Build the final HTML from structure tokens and recognized cell text
                pred_code = self.build_table_from_html_and_cell(pred_html, pred_cell)
                pred_code = "".join(pred_code)
                pred_code = html_table_template(pred_code)
                soup = bs(pred_code, "html.parser")
                # Formatted and indented string representation of the HTML document
                table_code = soup.prettify()
                print(table_code)
                # Append the extracted table to table_codes
                table_codes.append(table_code)
        return table_codes
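

# Minimal usage sketch (hedged): the input path below is hypothetical; the
# pipeline only needs a PIL image and a debug-output filename prefix.
if __name__ == "__main__":
    page = Image.open("sample_page.png").convert("RGB")  # hypothetical input file
    pipeline = DetectionAndOcrTable1(englishFlag=True)
    tables = pipeline.predict(page, debugfolder_filename_page_name="./res/debug_", denoise=False)
    for html in tables:
        print(html)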