from functools import partial
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
import argparse
import re
import time
import warnings

import torch
import tokenizers as tk
import numpy.typing as npt
from bs4 import BeautifulSoup as bs
from matplotlib import patches
from matplotlib import pyplot as plt
from numpy import uint8
from PIL import Image
from torch import nn, Tensor
from torchvision import transforms

from .src.model import EncoderDecoder, ImgLinearBackbone, Encoder, Decoder
from .src.utils import (
    subsequent_mask,
    pred_token_within_range,
    greedy_sampling,
    bbox_str_to_token_list,
    html_str_to_token_list,
    cell_str_to_token_list,
    build_table_from_html_and_cell,
    html_table_template,
)
from .src.trainer.utils import VALID_HTML_TOKEN, VALID_BBOX_TOKEN, INVALID_CELL_TOKEN

ImageType = npt.NDArray[uint8]
""" | |
ImgLinearBackbone, Encoder, Decoder are in components.py | |
EncoderDecoder is in encoderdecoder.py | |
""" | |
warnings.filterwarnings('ignore') | |

class UnitableFullPredictor:
    def __init__(self):
        pass
    def load_vocab_and_model(
        self,
        backbone: nn.Module,
        encoder: nn.Module,
        decoder: nn.Module,
        vocab_path: Union[str, Path],
        max_seq_len: int,
        model_weights: Union[str, Path],
    ) -> Tuple[tk.Tokenizer, EncoderDecoder]:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        vocab = tk.Tokenizer.from_file(str(vocab_path))
        d_model = 768
        dropout = 0.2
        model = EncoderDecoder(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_size=vocab.get_vocab_size(),
            d_model=d_model,
            padding_idx=vocab.token_to_id("<pad>"),
            max_seq_len=max_seq_len,
            dropout=dropout,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
        )
        # Load the weights onto the CPU first, then move the model to the target device.
        model.load_state_dict(torch.load(model_weights, map_location="cpu"))
        model = model.to(device)
        return vocab, model
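
    # Example call (paths taken from predict() below) for loading just the
    # structure model:
    #   vocab, model = self.load_vocab_and_model(
    #       backbone, encoder, decoder,
    #       vocab_path="./unitable/vocab/vocab_html.json",
    #       max_seq_len=784,
    #       model_weights="./unitable/experiments/unitable_weights/unitable_large_structure.pt",
    #   )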
    def autoregressive_decode(
        self,
        model: EncoderDecoder,
        image: Tensor,
        prefix: Sequence[int],
        max_decode_len: int,
        eos_id: int,
        token_whitelist: Optional[Sequence[int]] = None,
        token_blacklist: Optional[Sequence[int]] = None,
    ) -> Tensor:
        model.eval()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with torch.no_grad():
            # The encoder transforms the input image into a high-dimensional
            # feature representation. This memory tensor captures the information
            # from the input that the decoder needs to generate the output sequence.
            memory = model.encode(image)
        # Build a context tensor from the prefix, repeat it to match the batch size
        # of the image, and move it to the appropriate device.
        context = torch.tensor(prefix, dtype=torch.int32).repeat(image.shape[0], 1).to(device)
        for _ in range(max_decode_len):
            # Stop once every sequence in the batch contains <eos>.
            eos_flag = [eos_id in k for k in context]
            if all(eos_flag):
                break
            with torch.no_grad():
                causal_mask = subsequent_mask(context.shape[1]).to(device)
                logits = model.decode(
                    memory, context, tgt_mask=causal_mask, tgt_padding_mask=None
                )
                logits = model.generator(logits)[:, -1, :]
            logits = pred_token_within_range(
                logits.detach(),
                white_list=token_whitelist,
                black_list=token_blacklist,
            )
            next_probs, next_tokens = greedy_sampling(logits)
            context = torch.cat([context, next_tokens], dim=1)
        return context
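
    # Decoding proceeds one token per iteration: the context starts as the prefix
    # (e.g. [vocab.token_to_id("[html]")]), the decoder predicts logits for the
    # next position under a causal mask, the logits are restricted to the allowed
    # token range, and the greedy choice is appended, so the context grows by one
    # column per step until every row of the batch contains <eos>.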
    def image_to_tensor(self, image: Image.Image, size: Tuple[int, int]) -> Tensor:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        T = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.86597056, 0.88463002, 0.87491087],
                std=[0.20686628, 0.18201602, 0.18485524],
            ),
        ])
        image_tensor = T(image)
        image_tensor = image_tensor.to(device).unsqueeze(0)
        return image_tensor
    def rescale_bbox(
        self,
        bbox: Sequence[Sequence[float]],
        src: Tuple[int, int],
        tgt: Tuple[int, int],
    ) -> Sequence[Sequence[float]]:
        assert len(src) == len(tgt) == 2
        # Per-axis scale factors, repeated to cover (xmin, ymin, xmax, ymax).
        ratio = [tgt[0] / src[0], tgt[1] / src[1]] * 2
        print(ratio)
        bbox = [[int(round(i * j)) for i, j in zip(entry, ratio)] for entry in bbox]
        return bbox
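
    # Sanity-check example (illustrative values): a box predicted on the 448x448
    # model input is mapped back to the original image resolution, e.g.
    #   rescale_bbox([[112, 112, 224, 224]], src=(448, 448), tgt=(896, 896))
    #   -> [[224, 224, 448, 448]]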
    def predict(self, images: List[Image.Image], debugfolder_filename_page_name: str):
        MODEL_FILE_NAME = [
            "unitable_large_structure.pt",
            "unitable_large_bbox.pt",
            "unitable_large_content.pt",
        ]
        MODEL_DIR = Path("./unitable/experiments/unitable_weights")
        # UniTable large model
        d_model = 768
        patch_size = 16
        nhead = 12
        dropout = 0.2
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        backbone = ImgLinearBackbone(d_model=d_model, patch_size=patch_size)
        encoder = Encoder(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            activation="gelu",
            norm_first=True,
            nlayer=12,
            ff_ratio=4,
        )
        decoder = Decoder(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            activation="gelu",
            norm_first=True,
            nlayer=4,
            ff_ratio=4,
        )
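        # The same backbone/encoder/decoder modules are reused for all three
        # checkpoints (structure, bbox, cell content); only the loaded weights
        # and the vocabulary differ between the three load_vocab_and_model calls.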
        print("Running table transformer + UniTable full model")
        """
        Step 1: load the table structure model.
        """
        start1 = time.time()
        # Table structure extraction
        vocabS, modelS = self.load_vocab_and_model(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_path="./unitable/vocab/vocab_html.json",
            max_seq_len=784,
            model_weights=MODEL_DIR / MODEL_FILE_NAME[0],
        )
        end1 = time.time()
        print("time to load table structure model:", end1 - start1, "seconds")
""" | |
Step 2 prepare images to tensor | |
""" | |
image_tensors = [] | |
for i in range(len(images)): | |
image_size = images[i].size | |
# Image transformation | |
image_tensor = self.image_to_tensor(images[i], (448, 448)) | |
image_tensors.append(image_tensor) | |
print("Check if image_tensors is what i want it to be ") | |
print(type(image_tensors)) | |
# This will be list of arrays(pred_html), which is again list of array | |
pred_htmls = [] | |
        for i in range(len(image_tensors)):
            print("Processing table " + str(i))
            start2 = time.time()
            # Inference: predict the HTML structure tokens
            pred_html = self.autoregressive_decode(
                model=modelS,
                image=image_tensors[i],
                prefix=[vocabS.token_to_id("[html]")],
                max_decode_len=512,
                eos_id=vocabS.token_to_id("<eos>"),
                token_whitelist=[vocabS.token_to_id(t) for t in VALID_HTML_TOKEN],
                token_blacklist=None,
            )
            end2 = time.time()
            print("time for table structure inference:", end2 - start2, "seconds")
            # Convert token ids back to token text
            pred_html = pred_html.detach().cpu().numpy()[0]
            pred_html = vocabS.decode(pred_html, skip_special_tokens=False)
            pred_html = html_str_to_token_list(pred_html)
            pred_htmls.append(pred_html)
            print(pred_html)
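
            # pred_html is now a list of structure tokens, e.g. (illustrative):
            #   ['<thead>', '<tr>', '<td>[]</td>', '<td>[]</td>', '</tr>', ...]
            # where '<td>[]</td>' (or '>[]</td>' after a spanning attribute) marks
            # a cell whose text content is filled in during Step 6.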
""" | |
Step 3 Load Table Cell detection | |
""" | |
start3 = time.time() | |
# Table cell bbox detection | |
vocabB, modelB = self.load_vocab_and_model( | |
backbone=backbone, | |
encoder=encoder, | |
decoder=decoder, | |
vocab_path="./unitable/vocab/vocab_bbox.json", | |
max_seq_len=1024, | |
model_weights=MODEL_DIR / MODEL_FILE_NAME[1], | |
) | |
end3 = time.time() | |
print("time to load cell bbox detection model ",end3-start3,"seconds") | |
""" | |
Step 4 do the pred_bboxes detection | |
""" | |
pred_bboxs =[] | |
for i in range(len(image_tensors)): | |
start4 = time.time() | |
# Inference | |
pred_bbox = self.autoregressive_decode( | |
model=modelB, | |
image=image_tensors[i], | |
prefix=[vocabB.token_to_id("[bbox]")], | |
max_decode_len=1024, | |
eos_id=vocabB.token_to_id("<eos>"), | |
token_whitelist=[vocabB.token_to_id(i) for i in VALID_BBOX_TOKEN[: 449]], | |
token_blacklist = None | |
) | |
end4 = time.time() | |
print("Processing table "+str(i)) | |
print("time to do inference for table cell bbox detection model ",end4-start4,"seconds") | |
# Convert token id to token text | |
pred_bbox = pred_bbox.detach().cpu().numpy()[0] | |
pred_bbox = vocabB.decode(pred_bbox, skip_special_tokens=False) | |
pred_bbox = bbox_str_to_token_list(pred_bbox) | |
pred_bbox = self.rescale_bbox(pred_bbox, src=(448, 448), tgt=images[i].size) | |
print(pred_bbox) | |
print("Size of the image ") | |
#(1498, 971) | |
print(images[i].size) | |
print("Number of bounding boxes ") | |
print(len(pred_bbox)) | |
countcells = 0 | |
for elem in pred_htmls[i] : | |
if elem == '<td>[]</td>' or elem == '>[]</td>': | |
countcells+=1 | |
#275 | |
print("number of countcells") | |
print(countcells) | |
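
            # With max_decode_len=1024 and four coordinate tokens per box, a single
            # pass can emit at most ~256 boxes, which is presumably why tables with
            # more than 256 cells get the extra cropped-region pass below.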
            if countcells > 256:
                # TODO Extra processing for big tables:
                # find the last, incomplete row. The last bbox's ymin gives the
                # coordinate where the cut-off row starts.
                # IMPORTANT: pred_bbox entries are (xmin, ymin, xmax, ymax).
                cut_off = pred_bbox[-1][1]
                # Count how many cells of the cut-off row were already detected.
                last_cells_redundant = 0
                for cell in reversed(pred_bbox):
                    if cut_off - 5 < cell[1] < cut_off + 5:
                        last_cells_redundant += 1
                    else:
                        break
                width = images[i].size[0]
                height = images[i].size[1]
                # PIL's crop takes (left, upper, right, lower), i.e. (xmin, ymin, xmax, ymax).
                bbox = (0, cut_off, width, height)
                # Crop the image to the region below the cut-off row
                cropped_image = images[i].crop(bbox)
                # cropped_image.save("./res/table_debug/cropped_image_for_extra_bbox_det_table_num_" + str(i) + ".png")
                image_tensor = self.image_to_tensor(cropped_image, (448, 448))
                pred_bbox_extra = self.autoregressive_decode(
                    model=modelB,
                    image=image_tensor,
                    prefix=[vocabB.token_to_id("[bbox]")],
                    max_decode_len=1024,
                    eos_id=vocabB.token_to_id("<eos>"),
                    token_whitelist=[vocabB.token_to_id(t) for t in VALID_BBOX_TOKEN[:449]],
                    token_blacklist=None,
                )
                # Convert token ids to token text
                pred_bbox_extra = pred_bbox_extra.detach().cpu().numpy()[0]
                pred_bbox_extra = vocabB.decode(pred_bbox_extra, skip_special_tokens=False)
                pred_bbox_extra = bbox_str_to_token_list(pred_bbox_extra)
                # Drop the boxes of the cut-off row that were already detected
                pred_bbox_extra = pred_bbox_extra[last_cells_redundant - 1:]
                pred_bbox_extra = self.rescale_bbox(pred_bbox_extra, src=(448, 448), tgt=cropped_image.size)
                # Shift y coordinates back into full-image coordinates
                pred_bbox_extra = [[b[0], b[1] + cut_off, b[2], b[3] + cut_off] for b in pred_bbox_extra]
                pred_bbox = pred_bbox + pred_bbox_extra
                print("extra boxes:")
                print(pred_bbox_extra)
                print(len(pred_bbox_extra))
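
                # Worked example of the coordinate shift: if the crop starts at
                # cut_off=900, a box detected at y=10 in the cropped image ends up
                # at y=910 in the original image after the shift above.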
            pred_bboxs.append(pred_bbox)
            # Debug visualization: draw the predicted boxes on the table image
            fig, ax = plt.subplots(figsize=(12, 10))
            for j in pred_bbox:
                # j is (xmin, ymin, xmax, ymax) based on the function usage
                rect = patches.Rectangle(j[:2], j[2] - j[0], j[3] - j[1], linewidth=1, edgecolor='r', facecolor='none')
                ax.add_patch(rect)
            ax.set_axis_off()
            ax.imshow(images[i])
            fig.savefig(debugfolder_filename_page_name + str(i) + ".png", bbox_inches='tight', dpi=300)
""" | |
Step 5 : Load table cell recognition contents | |
""" | |
start4 = time.time() | |
# Table cell bbox detection | |
vocabC, modelC = self.load_vocab_and_model( | |
backbone=backbone, | |
encoder=encoder, | |
decoder=decoder, | |
vocab_path="./unitable/vocab/vocab_cell_6k.json", | |
max_seq_len=200, | |
model_weights=MODEL_DIR / MODEL_FILE_NAME[2], | |
) | |
end4 = time.time() | |
print("time to load cell recognition model ",end4-start4,"seconds") | |
        pred_cells = []
        """
        Step 6: decode cell content for all tables.
        """
        for i in range(len(images)):
            # Crop every detected cell and batch the crops into one tensor
            cell_image_tensors_for_img = []
            for bbox in pred_bboxs[i]:
                cropped_img = images[i].crop(bbox)
                if cropped_img.size[0] > 0:
                    cell_image_tensors_for_img.append(self.image_to_tensor(cropped_img, size=(112, 448)))
            cell_image_tensors_for_img = torch.cat(cell_image_tensors_for_img, dim=0).to(device)
            start6 = time.time()
            # Inference: recognize the text content of each cell crop
            pred_cell = self.autoregressive_decode(
                model=modelC,
                image=cell_image_tensors_for_img,
                prefix=[vocabC.token_to_id("[cell]")],
                max_decode_len=200,
                eos_id=vocabC.token_to_id("<eos>"),
                token_whitelist=None,
                token_blacklist=[vocabC.token_to_id(t) for t in INVALID_CELL_TOKEN],
            )
            # Convert token ids to token text
            pred_cell = pred_cell.detach().cpu().numpy()
            pred_cell = vocabC.decode_batch(pred_cell, skip_special_tokens=False)
            end6 = time.time()
            print("Processing table " + str(i))
            print("time for cell recognition:", end6 - start6, "seconds")
            pred_cell = [cell_str_to_token_list(c) for c in pred_cell]
            # Repair decimal numbers that decoding split apart: a digit, one stray
            # character, and whitespace followed by another digit are collapsed
            # into "digit.digit".
            pred_cell = [re.sub(r'(\d).\s+(\d)', r'\1.\2', c) for c in pred_cell]
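
            # e.g. (illustrative) re.sub(r'(\d).\s+(\d)', r'\1.\2', "0. 57") -> "0.57",
            # and "3  4" -> "3.4"; the single '.' in the pattern also consumes one
            # stray character between the digits.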
            print(pred_cell)
            pred_cells.append(pred_cell)
        print(type(pred_cells))
        table_codes = []
        for pred_html, pred_cell in zip(pred_htmls, pred_cells):
            # Combine the table structure and cell content into the final HTML
            pred_code = build_table_from_html_and_cell(pred_html, pred_cell)
            pred_code = "".join(pred_code)
            pred_code = html_table_template(pred_code)
            # Pretty-print the HTML table
            soup = bs(pred_code, "html.parser")
            table_code = soup.prettify()
            print(table_code)
            table_codes.append(table_code)
        return table_codes
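

# Minimal usage sketch (illustrative: the image path and debug prefix are
# hypothetical, and the weight/vocab files must exist at the locations
# hard-coded in predict()):
#
#   from PIL import Image
#
#   predictor = UnitableFullPredictor()
#   tables = [Image.open("table_crop.png").convert("RGB")]
#   html_tables = predictor.predict(tables, debugfolder_filename_page_name="./res/table_debug/page_")
#   print(html_tables[0])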