# alps/unitable/unitable_full.py
from typing import Tuple, List, Sequence, Optional, Union
from pathlib import Path
import re
import torch
import tokenizers as tk
from PIL import Image
from matplotlib import pyplot as plt
from matplotlib import patches
from torchvision import transforms
from torch import nn, Tensor
from functools import partial
import numpy.typing as npt
from numpy import uint8
ImageType = npt.NDArray[uint8]
import warnings
import time
import argparse
from bs4 import BeautifulSoup as bs
from .src.model import EncoderDecoder, ImgLinearBackbone, Encoder, Decoder
from .src.utils import (
    subsequent_mask,
    pred_token_within_range,
    greedy_sampling,
    bbox_str_to_token_list,
    html_str_to_token_list,
    cell_str_to_token_list,
    build_table_from_html_and_cell,
    html_table_template,
)
from .src.trainer.utils import VALID_HTML_TOKEN, VALID_BBOX_TOKEN, INVALID_CELL_TOKEN
"""
ImgLinearBackbone, Encoder, Decoder are in components.py
EncoderDecoder is in encoderdecoder.py
"""
warnings.filterwarnings('ignore')
class UnitableFullPredictor:
    """Full UniTable pipeline: table structure recognition, cell bbox
    detection, and cell content recognition, combined into final HTML."""
def __init__(self):
pass
    def load_vocab_and_model(
        self,
        backbone: nn.Module,
        encoder: nn.Module,
        decoder: nn.Module,
        vocab_path: Union[str, Path],
        max_seq_len: int,
        model_weights: Union[str, Path],
    ) -> Tuple[tk.Tokenizer, EncoderDecoder]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Tokenizer.from_file expects a str path, so cast in case a Path is given.
        vocab = tk.Tokenizer.from_file(str(vocab_path))
d_model = 768
dropout = 0.2
        model = EncoderDecoder(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_size=vocab.get_vocab_size(),
            d_model=d_model,
            padding_idx=vocab.token_to_id("<pad>"),
            max_seq_len=max_seq_len,
            dropout=dropout,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
        )
        # Load the weights onto the CPU first, then move the model to the target device.
model.load_state_dict(torch.load(model_weights, map_location="cpu"))
model = model.to(device)
return vocab, model
def autoregressive_decode(
self,
model: EncoderDecoder,
image: Tensor,
prefix: Sequence[int],
max_decode_len: int,
eos_id: int,
token_whitelist: Optional[Sequence[int]] = None,
token_blacklist: Optional[Sequence[int]] = None,
) -> Tensor:
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with torch.no_grad():
            # The encoder transforms the input image into a high-dimensional
            # feature representation (the "memory" tensor) that captures the
            # information needed to generate the output sequence.
            memory = model.encode(image)
            # Build the decoding context from the prefix tokens, repeat it to
            # match the image batch size, and move it to the target device.
            context = torch.tensor(prefix, dtype=torch.int32).repeat(image.shape[0], 1).to(device)
        for _ in range(max_decode_len):
            # Stop once every sequence in the batch has produced <eos>.
            eos_flag = [eos_id in k for k in context]
            if all(eos_flag):
                break
with torch.no_grad():
causal_mask = subsequent_mask(context.shape[1]).to(device)
logits = model.decode(
memory, context, tgt_mask=causal_mask, tgt_padding_mask=None
)
logits = model.generator(logits)[:, -1, :]
logits = pred_token_within_range(
logits.detach(),
white_list=token_whitelist,
black_list=token_blacklist,
)
next_probs, next_tokens = greedy_sampling(logits)
context = torch.cat([context, next_tokens], dim=1)
return context
@staticmethod
def image_to_tensor(image: Image, size: Tuple[int, int]) -> Tensor:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The normalization statistics below are taken as-is from the released
        # UniTable code (presumably computed on its training data).
        T = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.86597056, 0.88463002, 0.87491087],
                std=[0.20686628, 0.18201602, 0.18485524],
            ),
        ])
image_tensor = T(image)
image_tensor = image_tensor.to(device).unsqueeze(0)
return image_tensor
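    # image_to_tensor returns a (1, 3, H, W) float tensor (unsqueeze adds the
    # batch dimension); predict() below concatenates these along dim 0 to
    # batch the cell crops.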
def rescale_bbox(
self,
bbox: Sequence[Sequence[float]],
src: Tuple[int, int],
tgt: Tuple[int, int]
) -> Sequence[Sequence[float]]:
        assert len(src) == len(tgt) == 2
        # Scale (xmin, ymin, xmax, ymax) by the width and height ratios.
        ratio = [tgt[0] / src[0], tgt[1] / src[1]] * 2
        bbox = [[int(round(i * j)) for i, j in zip(entry, ratio)] for entry in bbox]
return bbox
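    # Worked example for rescale_bbox (illustrative numbers): mapping a box
    # from the 448x448 model frame back to a 1498x971 page gives
    # ratio = [1498/448, 971/448] * 2 ≈ [3.344, 2.167, 3.344, 2.167], so
    # [100, 50, 200, 80] -> [334, 108, 669, 173].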
    def predict(self, images: List[Image.Image], debugfolder_filename_page_name: str):
MODEL_FILE_NAME = ["unitable_large_structure.pt", "unitable_large_bbox.pt", "unitable_large_content.pt"]
MODEL_DIR = Path("./unitable/experiments/unitable_weights")
# UniTable large model
d_model = 768
patch_size = 16
nhead = 12
dropout = 0.2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        backbone = ImgLinearBackbone(d_model=d_model, patch_size=patch_size)
        encoder = Encoder(
d_model=d_model,
nhead=nhead,
dropout=dropout,
activation="gelu",
norm_first=True,
nlayer=12,
ff_ratio=4,
)
        decoder = Decoder(
d_model=d_model,
nhead=nhead,
dropout=dropout,
activation="gelu",
norm_first=True,
nlayer=4,
ff_ratio=4,
)
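        # Note: these backbone/encoder/decoder instances are shared by all
        # three EncoderDecoder models built below, and each load_state_dict
        # call overwrites their weights in place. That is safe here only
        # because each model finishes its inference before the next
        # checkpoint is loaded.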
print("Running table transformer + Unitable Full Model")
"""
Step 1 Load Table Structure Model
"""
start1 = time.time()
# Table structure extraction
vocabS, modelS = self.load_vocab_and_model(
backbone=backbone,
encoder=encoder,
decoder=decoder,
vocab_path="./unitable/vocab/vocab_html.json",
max_seq_len=784,
model_weights=MODEL_DIR / MODEL_FILE_NAME[0]
)
end1 = time.time()
print("time to load table structure model ",end1-start1,"seconds")
"""
Step 2 prepare images to tensor
"""
image_tensors = []
for i in range(len(images)):
image_size = images[i].size
# Image transformation
image_tensor = self.image_to_tensor(images[i], (448, 448))
image_tensors.append(image_tensor)
print("Check if image_tensors is what i want it to be ")
print(type(image_tensors))
# This will be list of arrays(pred_html), which is again list of array
pred_htmls = []
        for i in range(len(image_tensors)):
            print("Processing table " + str(i))
start2 = time.time()
# Inference
pred_html = self.autoregressive_decode(
model= modelS,
image= image_tensors[i],
prefix=[vocabS.token_to_id("[html]")],
max_decode_len=512,
eos_id=vocabS.token_to_id("<eos>"),
                token_whitelist=[vocabS.token_to_id(i) for i in VALID_HTML_TOKEN],
                token_blacklist=None,
)
end2 = time.time()
print("time for inference table structure ",end2-start2,"seconds")
pred_html = pred_html.detach().cpu().numpy()[0]
pred_html = vocabS.decode(pred_html, skip_special_tokens=False)
pred_html = html_str_to_token_list(pred_html)
pred_htmls.append(pred_html)
print(pred_html)
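            # pred_html is now a token list along the lines of
            # ['<thead>', '<tr>', '<td>[]</td>', ..., '</tr>', '</tbody>']
            # (illustrative); '<td>[]</td>' marks a cell whose text content
            # is filled in during Step 6.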
"""
Step 3 Load Table Cell detection
"""
start3 = time.time()
# Table cell bbox detection
vocabB, modelB = self.load_vocab_and_model(
backbone=backbone,
encoder=encoder,
decoder=decoder,
vocab_path="./unitable/vocab/vocab_bbox.json",
max_seq_len=1024,
model_weights=MODEL_DIR / MODEL_FILE_NAME[1],
)
end3 = time.time()
print("time to load cell bbox detection model ",end3-start3,"seconds")
"""
Step 4 do the pred_bboxes detection
"""
        pred_bboxs = []
for i in range(len(image_tensors)):
start4 = time.time()
# Inference
pred_bbox = self.autoregressive_decode(
model=modelB,
image=image_tensors[i],
prefix=[vocabB.token_to_id("[bbox]")],
max_decode_len=1024,
eos_id=vocabB.token_to_id("<eos>"),
                # 449 = 448 + 1: the bbox tokens valid for the 448x448 input frame.
                token_whitelist=[vocabB.token_to_id(i) for i in VALID_BBOX_TOKEN[:449]],
                token_blacklist=None,
)
end4 = time.time()
print("Processing table "+str(i))
print("time to do inference for table cell bbox detection model ",end4-start4,"seconds")
# Convert token id to token text
pred_bbox = pred_bbox.detach().cpu().numpy()[0]
pred_bbox = vocabB.decode(pred_bbox, skip_special_tokens=False)
pred_bbox = bbox_str_to_token_list(pred_bbox)
pred_bbox = self.rescale_bbox(pred_bbox, src=(448, 448), tgt=images[i].size)
            print(pred_bbox)
            print("Image size:", images[i].size)
            print("Number of bounding boxes:", len(pred_bbox))
            # Count the cells predicted by the structure model; the bbox decode
            # above is length-capped, so very large tables need a second pass.
            countcells = 0
            for elem in pred_htmls[i]:
                if elem == '<td>[]</td>' or elem == '>[]</td>':
                    countcells += 1
            print("number of cells:", countcells)
            if countcells > 256:
                # Extra processing for big tables (TODO: refine this heuristic).
                # The bbox decoder's max length can truncate the sequence, so
                # find where the last, possibly incomplete, row starts: the
                # last predicted bbox's ymin marks the cut-off row.
                # IMPORTANT: pred_bbox entries are (xmin, ymin, xmax, ymax).
                cut_off = pred_bbox[-1][1]
                # Count how many cells of the cut-off row were already detected.
                last_cells_redundant = 0
                for cell in reversed(pred_bbox):
                    if cut_off - 5 < cell[1] < cut_off + 5:
                        last_cells_redundant += 1
                    else:
                        break
                width = images[i].size[0]
                height = images[i].size[1]
                # PIL's crop takes (left, upper, right, lower),
                # i.e. (xmin, ymin, xmax, ymax).
                bbox = (0, cut_off, width, height)
# Crop the image to the specified bounding box
cropped_image = images[i].crop(bbox)
#cropped_image.save("./res/table_debug/cropped_image_for_extra_bbox_det_table_num_"+str(i)+".png")
image_tensor = self.image_to_tensor(cropped_image, (448, 448))
pred_bbox_extra = self.autoregressive_decode(
model=modelB,
image=image_tensor,
prefix=[vocabB.token_to_id("[bbox]")],
max_decode_len=1024,
eos_id=vocabB.token_to_id("<eos>"),
token_whitelist=[vocabB.token_to_id(i) for i in VALID_BBOX_TOKEN[: 449]],
token_blacklist = None
)
# Convert token id to token text
pred_bbox_extra = pred_bbox_extra.detach().cpu().numpy()[0]
pred_bbox_extra = vocabB.decode(pred_bbox_extra, skip_special_tokens=False)
pred_bbox_extra = bbox_str_to_token_list(pred_bbox_extra)
                # Skip the boxes for cells the first pass already detected in
                # the cut-off row.
                pred_bbox_extra = pred_bbox_extra[last_cells_redundant - 1:]
pred_bbox_extra = self.rescale_bbox(pred_bbox_extra, src=(448, 448), tgt=cropped_image.size)
                # Shift the extra boxes back into the full-page frame.
                pred_bbox_extra = [[b[0], b[1] + cut_off, b[2], b[3] + cut_off] for b in pred_bbox_extra]
pred_bbox = pred_bbox + pred_bbox_extra
print("extra boxes:")
print(pred_bbox_extra)
print(len(pred_bbox_extra))
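                # Illustrative walk-through of the second pass above, with
                # hypothetical numbers: if the last decoded box has ymin = 600
                # on a 971px-tall page, the page is cropped to
                # (0, 600, width, 971) and decoded again, and each new box's
                # y-coordinates are offset by cut_off = 600 to map them back
                # into the original page frame.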
pred_bboxs.append(pred_bbox)
            # Save a debug rendering of the predicted boxes over the table image.
            fig, ax = plt.subplots(figsize=(12, 10))
            for j in pred_bbox:
                # j is (xmin, ymin, xmax, ymax).
                rect = patches.Rectangle(j[:2], j[2] - j[0], j[3] - j[1], linewidth=1, edgecolor='r', facecolor='none')
                ax.add_patch(rect)
            ax.set_axis_off()
            ax.imshow(images[i])
            fig.savefig(debugfolder_filename_page_name + str(i) + ".png", bbox_inches='tight', dpi=300)
            plt.close(fig)
"""
Step 5 : Load table cell recognition contents
"""
start4 = time.time()
        # Table cell content recognition
vocabC, modelC = self.load_vocab_and_model(
backbone=backbone,
encoder=encoder,
decoder=decoder,
vocab_path="./unitable/vocab/vocab_cell_6k.json",
max_seq_len=200,
model_weights=MODEL_DIR / MODEL_FILE_NAME[2],
)
end4 = time.time()
print("time to load cell recognition model ",end4-start4,"seconds")
pred_cells = []
"""
Step 6 : Decode for all tables
"""
        for i in range(len(images)):
            cell_image_tensors_for_img = []
            for bbox in pred_bboxs[i]:
                cropped_img = images[i].crop(bbox)
                # Skip zero-width crops; note that dropping a box here can
                # misalign the recognized cells with the predicted structure.
                if cropped_img.size[0] > 0:
                    cell_image_tensors_for_img.append(self.image_to_tensor(cropped_img, size=(112, 448)))
            # Batch all cell crops of this table into one tensor.
            cell_image_tensors_for_img = torch.cat(cell_image_tensors_for_img, dim=0).to(device)
start4 = time.time()
# Inference
pred_cell = self.autoregressive_decode(
model=modelC,
image=cell_image_tensors_for_img,
prefix=[vocabC.token_to_id("[cell]")],
max_decode_len=200,
eos_id=vocabC.token_to_id("<eos>"),
token_whitelist=None,
token_blacklist = [vocabC.token_to_id(i) for i in INVALID_CELL_TOKEN]
)
# Convert token id to token text
pred_cell = pred_cell.detach().cpu().numpy()
pred_cell = vocabC.decode_batch(pred_cell, skip_special_tokens=False)
end4 = time.time()
print("Processing table "+str(i))
print("time to do cell recognition ",end4-start4,"seconds")
pred_cell = [cell_str_to_token_list(i) for i in pred_cell]
            # Re-join numbers that decoding split apart: a digit, any single
            # character, then whitespace followed by another digit collapses
            # to "digit.digit" (e.g. "0. 5" -> "0.5").
            pred_cell = [re.sub(r'(\d).\s+(\d)', r'\1.\2', i) for i in pred_cell]
print(pred_cell)
pred_cells.append(pred_cell)
        table_codes = []
for pred_html, pred_cell in zip(pred_htmls, pred_cells):
# Combine the table structure and cell content
pred_code = build_table_from_html_and_cell(pred_html, pred_cell)
pred_code = "".join(pred_code)
pred_code = html_table_template(pred_code)
# Display the HTML table
            soup = bs(pred_code, features="html.parser")
table_code = soup.prettify()
print(table_code)
table_codes.append(table_code)
return table_codes
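
if __name__ == "__main__":
    # Minimal usage sketch (not part of the original pipeline). It assumes the
    # vocab files and unitable_weights referenced above are in place and that
    # this module is run as part of its package (it uses relative imports);
    # "sample_table.png" is a hypothetical input path.
    table_image = Image.open("sample_table.png").convert("RGB")
    predictor = UnitableFullPredictor()
    html_tables = predictor.predict(
        [table_image],
        debugfolder_filename_page_name="./res/table_debug/page_",
    )
    for html in html_tables:
        print(html)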