from typing import Tuple, List, Sequence, Optional, Union
from pathlib import Path
import re
import torch
import tokenizers as tk
from PIL import Image
from matplotlib import pyplot as plt
from matplotlib import patches
from torchvision import transforms
from torch import nn, Tensor
from functools import partial
import numpy.typing as npt
from numpy import uint8
ImageType = npt.NDArray[uint8]
import warnings
import time
import argparse
from bs4 import BeautifulSoup as bs
from .src.model import EncoderDecoder, ImgLinearBackbone, Encoder, Decoder
from .src.utils import subsequent_mask, pred_token_within_range, greedy_sampling, bbox_str_to_token_list, html_str_to_token_list, cell_str_to_token_list, build_table_from_html_and_cell, html_table_template
from .src.trainer.utils import VALID_HTML_TOKEN, VALID_BBOX_TOKEN, INVALID_CELL_TOKEN

warnings.filterwarnings('ignore')


class UnitableFullSinglePredictor:
    def __init__(self):
        MODEL_FILE_NAME = ["unitable_large_structure.pt", "unitable_large_bbox.pt", "unitable_large_content.pt"]
        MODEL_DIR = Path("unitable/experiments/unitable_weights")
        # UniTable large model hyperparameters
        self.d_model = 768
        self.patch_size = 16
        self.nhead = 12
        self.dropout = 0.2
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.backbone = ImgLinearBackbone(d_model=self.d_model, patch_size=self.patch_size)
        self.encoder = Encoder(
            d_model=self.d_model,
            nhead=self.nhead,
            dropout=self.dropout,
            activation="gelu",
            norm_first=True,
            nlayer=12,
            ff_ratio=4,
        )
        self.decoder = Decoder(
            d_model=self.d_model,
            nhead=self.nhead,
            dropout=self.dropout,
            activation="gelu",
            norm_first=True,
            nlayer=4,
            ff_ratio=4,
        )
""" | |
start1 = time.time() | |
# Table structure extraction | |
self.vocabS, self.modelS = self.load_vocab_and_model( | |
backbone= ImgLinearBackbone(d_model=self.d_model, patch_size=self.patch_size), | |
encoder= Encoder( | |
d_model=self.d_model, | |
nhead=self.nhead, | |
dropout=self.dropout, | |
activation="gelu", | |
norm_first=True, | |
nlayer=12, | |
ff_ratio=4, | |
), | |
decoder= Decoder( | |
d_model=self.d_model, | |
nhead=self.nhead, | |
dropout=self.dropout, | |
activation="gelu", | |
norm_first=True, | |
nlayer=4, | |
ff_ratio=4, | |
), | |
d_model= self.d_model, | |
dropout= self.dropout, | |
vocab_path="unitable/vocab/vocab_html.json", | |
max_seq_len=784, | |
model_weights=MODEL_DIR / MODEL_FILE_NAME[0] | |
) | |
end1 = time.time() | |
print("time to load table structure model ",end1-start1,"seconds") | |
start3 = time.time() | |
# Table cell bbox detection | |
self.vocabB, self.modelB = self.load_vocab_and_model( | |
backbone = ImgLinearBackbone(d_model=self.d_model, patch_size=self.patch_size), | |
encoder = Encoder( | |
d_model= self.d_model, | |
nhead= self.nhead, | |
dropout = self.dropout, | |
activation="gelu", | |
norm_first=True, | |
nlayer=12, | |
ff_ratio=4, | |
), | |
decoder = Decoder( | |
d_model= self.d_model, | |
nhead= self.nhead, | |
dropout = self.dropout, | |
activation="gelu", | |
norm_first=True, | |
nlayer=4, | |
ff_ratio=4, | |
), | |
d_model= self.d_model, | |
dropout= self.dropout, | |
vocab_path="unitable/vocab/vocab_bbox.json", | |
max_seq_len=1024, | |
model_weights=MODEL_DIR / MODEL_FILE_NAME[1], | |
) | |
end3 = time.time() | |
print("time to load cell bbox detection model ",end3-start3,"seconds") | |
start4 = time.time() | |
        # Table cell content recognition
        self.vocabC, self.modelC = self.load_vocab_and_model(
            backbone=ImgLinearBackbone(d_model=self.d_model, patch_size=self.patch_size),
            encoder=Encoder(
                d_model=self.d_model,
                nhead=self.nhead,
                dropout=self.dropout,
                activation="gelu",
                norm_first=True,
                nlayer=12,
                ff_ratio=4,
            ),
            decoder=Decoder(
                d_model=self.d_model,
                nhead=self.nhead,
                dropout=self.dropout,
                activation="gelu",
                norm_first=True,
                nlayer=4,
                ff_ratio=4,
            ),
            d_model=self.d_model,
            dropout=self.dropout,
            vocab_path="unitable/vocab/vocab_cell_6k.json",
            max_seq_len=200,
            # Uses the cell content recognition model weights
            model_weights=MODEL_DIR / MODEL_FILE_NAME[2],
        )
        end4 = time.time()
        print("time to load cell recognition model", end4 - start4, "seconds")
        """

    def load_vocab_and_model(
        self,
        vocab_path: Union[str, Path],
        max_seq_len: int,
        model_weights: Union[str, Path],
    ) -> Tuple[tk.Tokenizer, EncoderDecoder]:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # str() so both str and Path vocab paths are accepted by the tokenizer
        vocab = tk.Tokenizer.from_file(str(vocab_path))
        model = EncoderDecoder(
            backbone=self.backbone,
            encoder=self.encoder,
            decoder=self.decoder,
            vocab_size=vocab.get_vocab_size(),
            d_model=self.d_model,
            padding_idx=vocab.token_to_id("<pad>"),
            max_seq_len=max_seq_len,
            dropout=self.dropout,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
        )
        # Load the weights onto the CPU first, then move the model to the target device
        model.load_state_dict(torch.load(model_weights, map_location="cpu"))
        model = model.to(device)
        return vocab, model

    def autoregressive_decode(
        self,
        model: EncoderDecoder,
        image: Tensor,
        prefix: Sequence[int],
        max_decode_len: int,
        eos_id: int,
        token_whitelist: Optional[Sequence[int]] = None,
        token_blacklist: Optional[Sequence[int]] = None,
    ) -> Tensor:
        model.eval()
        with torch.no_grad():
            memory = model.encode(image)
            context = torch.tensor(prefix, dtype=torch.int32).repeat(image.shape[0], 1).to(self.device)

        for _ in range(max_decode_len):
            eos_flag = [eos_id in k for k in context]
            if all(eos_flag):
                break
            with torch.no_grad():
                causal_mask = subsequent_mask(context.shape[1]).to(self.device)
                logits = model.decode(
                    memory, context, tgt_mask=causal_mask, tgt_padding_mask=None
                )
                logits = model.generator(logits)[:, -1, :]
            logits = pred_token_within_range(
                logits.detach(),
                white_list=token_whitelist,
                black_list=token_blacklist,
            )
            next_probs, next_tokens = greedy_sampling(logits)
            context = torch.cat([context, next_tokens], dim=1)
        return context

    def image_to_tensor(self, image: Image.Image, size: Tuple[int, int]) -> Tensor:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Resize the image with padding
        #resized_image = UnitableFullPredictor.resize_with_padding(image, size)
        T = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.86597056, 0.88463002, 0.87491087],
                std=[0.20686628, 0.18201602, 0.18485524],
            ),
        ])
        image_tensor = T(image)
        image_tensor = image_tensor.to(device).unsqueeze(0)
        return image_tensor
""" | |
@staticmethod | |
def resize_with_padding(image: Image, target_size: Tuple[int, int]) -> Image: | |
#Resize the image to fit within the target size while preserving aspect ratio, | |
#then add padding to match the target size. | |
original_width, original_height = image.size | |
target_width, target_height = target_size | |
# Calculate the new size preserving aspect ratio | |
aspect_ratio = original_width / original_height | |
if target_width / target_height > aspect_ratio: | |
new_height = target_height | |
new_width = int(new_height * aspect_ratio) | |
else: | |
new_width = target_width | |
new_height = int(new_width / aspect_ratio) | |
# Resize the image to the new size | |
resized_image = image.resize((new_width, new_height),Image.LANCZOS) | |
# Create a new image with white background | |
new_image = Image.new("RGB", (target_width, target_height), (255, 255, 255)) | |
# Paste the resized image onto the white background | |
paste_position = ((target_width - new_width) // 2, (target_height - new_height) // 2) | |
new_image.paste(resized_image, paste_position) | |
new_image.save("../res/table_resize_with_padding.png") | |
return new_image | |
""" | |

    def rescale_bbox(
        self,
        bbox: Sequence[Sequence[float]],
        src: Tuple[int, int],
        tgt: Tuple[int, int],
    ) -> Sequence[Sequence[float]]:
        assert len(src) == len(tgt) == 2
        # List repetition: [x_ratio, y_ratio] * 2 == [x_ratio, y_ratio, x_ratio, y_ratio],
        # one ratio for each of the four bbox coordinates (x_min, y_min, x_max, y_max)
        ratio = [tgt[0] / src[0], tgt[1] / src[1]] * 2
        print(ratio)
        bbox = [[int(round(i * j)) for i, j in zip(entry, ratio)] for entry in bbox]
        return bbox
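
    # Worked example for rescale_bbox (illustrative numbers, not taken from a real run):
    # with src=(448, 448) and tgt=(896, 448), ratio == [2.0, 1.0, 2.0, 1.0], so a
    # predicted box [10, 20, 30, 40] is rescaled to [20, 20, 60, 40]; x coordinates
    # are doubled while y coordinates stay unchanged.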
""" | |
@staticmethod | |
def rescale_bbox( | |
bbox: Sequence[Sequence[float]], | |
src: Tuple[int, int], | |
tgt: Tuple[int, int] | |
) -> Sequence[Sequence[float]]: | |
#Rescale bounding boxes according to the transformation applied in resize_with_padding. | |
src_width, src_height = src | |
tgt_width, tgt_height = tgt | |
# Calculate the new size preserving aspect ratio | |
aspect_ratio = src_width / src_height | |
if tgt_width / tgt_height > aspect_ratio: | |
new_height = tgt_height | |
new_width = int(new_height * aspect_ratio) | |
else: | |
new_width = tgt_width | |
new_height = int(new_width / aspect_ratio) | |
# Calculate the scale factors | |
#THIS *2 factor was done in their code - why ? i have no clue | |
scale_x = (new_width / src_width ) * 2 | |
scale_y = (new_height / src_height) *2 | |
# Calculate the padding | |
pad_x = (tgt_width - new_width) // 2 | |
pad_y = (tgt_height - new_height) // 2 | |
# Rescale and adjust the bounding boxes | |
rescaled_bbox = [] | |
for entry in bbox: | |
x_min = int(round(entry[0] * scale_x -pad_x)) | |
y_min = int(round(entry[1] * scale_y - pad_y)) | |
x_max = int(round(entry[2] * scale_x - pad_x)) | |
y_max = int(round(entry[3] * scale_y - pad_y)) | |
rescaled_bbox.append([x_min, y_min, x_max, y_max]) | |
return rescaled_bbox | |
""" | |

    def predict(self, image: Image.Image):
        MODEL_FILE_NAME = ["unitable_large_structure.pt", "unitable_large_bbox.pt", "unitable_large_content.pt"]
        MODEL_DIR = Path("unitable/experiments/unitable_weights")
        image_size = image.size
        print("RUNNING SINGLE IMAGE UNITABLE FOR DEBUGGING")
        # Image transformation
        image_tensor = self.image_to_tensor(image, (448, 448))
        #print(image_tensor)

        """
        Step 1: Table structure recognition
        """
        start1 = time.time()
        # Table structure extraction
        vocabS, modelS = self.load_vocab_and_model(
            vocab_path="unitable/vocab/vocab_html.json",
            max_seq_len=784,
            model_weights=MODEL_DIR / MODEL_FILE_NAME[0],
        )
        end1 = time.time()
        print("time to load table structure model", end1 - start1, "seconds")
        start2 = time.time()
        # Inference
        pred_html = self.autoregressive_decode(
            model=modelS,
            image=image_tensor,
            prefix=[vocabS.token_to_id("[html]")],
            max_decode_len=512,
            eos_id=vocabS.token_to_id("<eos>"),
            token_whitelist=[vocabS.token_to_id(i) for i in VALID_HTML_TOKEN],
            token_blacklist=None,
        )
        end2 = time.time()
        print("time for table structure inference", end2 - start2, "seconds")
        # Convert token ids to token text
        pred_html = pred_html.detach().cpu().numpy()[0]
        pred_html = vocabS.decode(pred_html, skip_special_tokens=False)
        #print(pred_html)
        pred_html = html_str_to_token_list(pred_html)
        print(pred_html)
""" | |
Step 2 Table Cell detection | |
""" | |
start3 = time.time() | |
# Table cell bbox detection | |
vocabB, modelB = self.load_vocab_and_model( | |
vocab_path="unitable/vocab/vocab_bbox.json", | |
max_seq_len=1024, | |
model_weights=MODEL_DIR / MODEL_FILE_NAME[1], | |
) | |
end3 = time.time() | |
print("time to load cell bbox detection model ",end3-start3,"seconds") | |
start4 = time.time() | |
# Inference | |
pred_bbox = self.autoregressive_decode( | |
model=modelB, | |
image=image_tensor, | |
prefix=[vocabB.token_to_id("[bbox]")], | |
max_decode_len=1024, | |
eos_id=vocabB.token_to_id("<eos>"), | |
token_whitelist=[vocabB.token_to_id(i) for i in VALID_BBOX_TOKEN[: 449]], | |
token_blacklist = None | |
) | |
end4 = time.time() | |
print("time to do inference for table cell bbox detection model ",end4-start4,"seconds") | |
# Convert token id to token text | |
pred_bbox = pred_bbox.detach().cpu().numpy()[0] | |
pred_bbox = vocabB.decode(pred_bbox, skip_special_tokens=False) | |
pred_bbox = bbox_str_to_token_list(pred_bbox) | |
pred_bbox = self.rescale_bbox(pred_bbox, src=(448, 448), tgt=image.size) | |
print(pred_bbox) | |
print("Size of the image ") | |
#(1498, 971) | |
print(image.size) | |
print("Number of bounding boxes ") | |
print(len(pred_bbox)) | |
        countcells = 0
        #startBody = False
        #startFirstRow = True
        #numElemInRow = 0
        for elem in pred_html:
            #if elem == '<tbody>':
            #    startBody = True
            #elif startBody == True and elem == '<tr>':
            #    startFirstRow = True
            #elif startFirstRow == True and elem == '<td>[]</td>':
            #    numElemInRow += 1
            #elif startBody == True and elem == '</tr>':
            #    startFirstRow = False
            #    startBody = False
            if elem == '<td>[]</td>':
                countcells += 1
        #275
        print(countcells)
        if countcells > len(pred_bbox):
            #TODO Extra processing for big tables
            # Find the last, incomplete row: the last predicted bbox's y_min marks
            # where the cut-off row starts.
            # IMPORTANT: pred_bbox entries are (x_min, y_min, x_max, y_max)
            cut_off = pred_bbox[-1][1]
            width = image.size[0]
            height = image.size[1]
            #bbox = (0, cut_off, width, height)
            # IMPORTANT: PIL's crop takes (left, upper, right, lower) coordinates
            bbox = (0, cut_off, width, height)
            # Crop the image to the specified bounding box
            cropped_image = image.crop(bbox)
            cropped_image.save("./res/cropped_image_for_extra_bbox_det.png")
            image_tensor = self.image_to_tensor(cropped_image, (448, 448))
            pred_bbox_extra = self.autoregressive_decode(
                model=modelB,
                image=image_tensor,
                prefix=[vocabB.token_to_id("[bbox]")],
                max_decode_len=1024,
                eos_id=vocabB.token_to_id("<eos>"),
                token_whitelist=[vocabB.token_to_id(i) for i in VALID_BBOX_TOKEN[:449]],
                token_blacklist=None,
            )
            # Convert token ids to token text
            pred_bbox_extra = pred_bbox_extra.detach().cpu().numpy()[0]
            pred_bbox_extra = vocabB.decode(pred_bbox_extra, skip_special_tokens=False)
            pred_bbox_extra = bbox_str_to_token_list(pred_bbox_extra)
            numberOfCellsToAdd = countcells - len(pred_bbox)
            pred_bbox_extra = pred_bbox_extra[-numberOfCellsToAdd:]
            pred_bbox_extra = self.rescale_bbox(pred_bbox_extra, src=(448, 448), tgt=cropped_image.size)
            #This resulted in table_bbox_test_extra_3.png
            #pred_bbox_extra = [[i[0], i[1]+cut_off, i[2], i[3]+cut_off] for i in pred_bbox_extra]
            # Shift the extra boxes back into full-image coordinates
            pred_bbox_extra = [[i[0], i[1] + cut_off, i[2], i[3] + cut_off] for i in pred_bbox_extra]
            pred_bbox = pred_bbox + pred_bbox_extra
            #[[25, 63, 152, 86], [227, 63, 292, 86], [326, 63, 373, 86], [413, 63, 460, 86], [562, 63, 609, 86], [708, 63, 758, 86], [848, 63, 895, 86], [935, 63, 982, 86], [1025, 63, 1075, 86], [1119, 63, 1165, 86], [1280, 63, 1327, 86]]
            print(pred_bbox_extra)
            #11
            print(len(pred_bbox_extra))

        fig, ax = plt.subplots(figsize=(12, 10))
        for i in pred_bbox:
            # i is (x_min, y_min, x_max, y_max) based on the function usage
            rect = patches.Rectangle(i[:2], i[2] - i[0], i[3] - i[1], linewidth=1, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
        ax.set_axis_off()
        ax.imshow(image)
        fig.savefig('./res/table_debug3/singleimageres.png', bbox_inches='tight', dpi=300)

        """
        Step 3: Table cell content recognition
        """
        start4 = time.time()
        # Table cell content recognition
        vocabC, modelC = self.load_vocab_and_model(
            vocab_path="unitable/vocab/vocab_cell_6k.json",
            max_seq_len=200,
            model_weights=MODEL_DIR / MODEL_FILE_NAME[2],
        )
        end4 = time.time()
        print("time to load cell recognition model", end4 - start4, "seconds")
        # Cell image cropping and transformation
        """
        images = [image.crop(bbox) for bbox in pred_bbox]
        for idx, img in enumerate(images):
            img.save("res/debug/cell_{}.png".format(idx))
        """
        # Cropping boundaries are fine
        image_tensor = [self.image_to_tensor(image.crop(bbox), size=(112, 448)) for bbox in pred_bbox]
        image_tensor = torch.cat(image_tensor, dim=0)
        #print("size of tensor")
        #print(image_tensor.size())
        start4 = time.time()
        # Inference
        pred_cell = self.autoregressive_decode(
            model=modelC,
            image=image_tensor,
            prefix=[vocabC.token_to_id("[cell]")],
            max_decode_len=200,
            eos_id=vocabC.token_to_id("<eos>"),
            token_whitelist=None,
            token_blacklist=[vocabC.token_to_id(i) for i in INVALID_CELL_TOKEN],
        )
        # Convert token ids to token text
        pred_cell = pred_cell.detach().cpu().numpy()
        pred_cell = vocabC.decode_batch(pred_cell, skip_special_tokens=False)
        end4 = time.time()
        print("time to do cell recognition", end4 - start4, "seconds")
        pred_cell = [cell_str_to_token_list(i) for i in pred_cell]
        # Re-join decimal numbers that were split by whitespace during decoding
        pred_cell = [re.sub(r'(\d)\.\s+(\d)', r'\1.\2', i) for i in pred_cell]
        print(pred_cell)
        # Combine the table structure and cell content
        pred_code = build_table_from_html_and_cell(pred_html, pred_cell)
        pred_code = "".join(pred_code)
        pred_code = html_table_template(pred_code)
        # Display the HTML table
        soup = bs(pred_code, "html.parser")
        table_code = soup.prettify()
        print(table_code)
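

# Example usage (a minimal sketch, not part of the original file): shows how the
# predictor is expected to be driven. The image path is a placeholder, and because
# of the relative imports above this file must be run as a module from the package
# root (e.g. `python -m <package>.<this_module>`) or imported from another script.
if __name__ == "__main__":
    table_image = Image.open("path/to/table_image.png").convert("RGB")  # hypothetical path
    predictor = UnitableFullSinglePredictor()
    predictor.predict(table_image)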