from typing import Tuple, List, Sequence, Optional, Union
from pathlib import Path
import re
import torch
import tokenizers as tk
from PIL import Image
from matplotlib import pyplot as plt
from matplotlib import patches
from torchvision import transforms
from torch import nn, Tensor
from functools import partial
import numpy.typing as npt
from numpy import uint8

ImageType = npt.NDArray[uint8]

import warnings
import time
import argparse
from bs4 import BeautifulSoup as bs

from .src.model import EncoderDecoder, ImgLinearBackbone, Encoder, Decoder
from .src.utils import (
    subsequent_mask,
    pred_token_within_range,
    greedy_sampling,
    bbox_str_to_token_list,
    html_str_to_token_list,
    cell_str_to_token_list,
    build_table_from_html_and_cell,
    html_table_template,
)
from .src.trainer.utils import VALID_HTML_TOKEN, VALID_BBOX_TOKEN, INVALID_CELL_TOKEN

# ImgLinearBackbone, Encoder and Decoder are defined in components.py;
# EncoderDecoder is defined in encoderdecoder.py (both imported here via .src.model).

# Silences *every* warning in the process, not only those raised by this module.
warnings.filterwarnings('ignore')

# Special tokens of the UniTable vocabularies.
# NOTE(review): in the copy of this file under review these literals had been reduced
# to "" / '[]' / '>[]' (apparently by HTML-tag stripping). token_to_id("") presumably
# returns None (not a vocab entry), so they are restored to the UniTable token strings
# here -- confirm against the vocab_*.json files.
_PAD_TOKEN = "<pad>"
_EOS_TOKEN = "<eos>"
# Structure tokens that each stand for one table cell (presumably a plain <td> and the
# closing part of a <td ...> with attributes) -- confirm.
_CELL_TOKENS = ("<td>[]</td>", ">[]</td>")

# Digit, any single character, whitespace, digit -> "digit.digit"
# (e.g. "3, 5" becomes "3.5"); applied to every recognised cell text.
_SPLIT_NUMBER_RE = re.compile(r'(\d).\s+(\d)')


class UnitableFullPredictor:
    """Run the UniTable pipeline (table structure -> cell boxes -> cell text) on table images."""

    # UniTable-large checkpoints; the directory is resolved relative to the working directory.
    _MODEL_DIR = Path("./unitable/experiments/unitable_weights")
    _MODEL_FILE_NAMES = (
        "unitable_large_structure.pt",
        "unitable_large_bbox.pt",
        "unitable_large_content.pt",
    )

    def load_vocab_and_model(
        self,
        backbone,
        encoder,
        decoder,
        vocab_path: Union[str, Path],
        max_seq_len: int,
        model_weights: Union[str, Path],
    ) -> Tuple[tk.Tokenizer, EncoderDecoder]:
        """Load a tokenizer and build an EncoderDecoder with pretrained weights.

        Args:
            backbone, encoder, decoder: module instances passed to EncoderDecoder.
                NOTE(review): presumably stored by reference, not copied -- if so, models
                built from the same instances share those parameters, and loading one
                checkpoint overwrites the others' weights. Confirm in encoderdecoder.py.
            vocab_path: tokenizer JSON file.
            max_seq_len: maximum decoder sequence length given to EncoderDecoder.
            model_weights: state-dict checkpoint file.

        Returns:
            ``(vocab, model)``, with the model moved to CUDA when available, else CPU.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Tokenizer.from_file takes a str path; normalise Path inputs.
        vocab = tk.Tokenizer.from_file(str(vocab_path))
        d_model = 768
        dropout = 0.2
        model = EncoderDecoder(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_size=vocab.get_vocab_size(),
            d_model=d_model,
            padding_idx=vocab.token_to_id(_PAD_TOKEN),
            max_seq_len=max_seq_len,
            dropout=dropout,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
        )
        # torch.load unpickles the file: only load trusted checkpoints.
        # Weights are loaded onto the CPU first, then the whole model is moved to `device`.
        model.load_state_dict(torch.load(model_weights, map_location="cpu"))
        return vocab, model.to(device)

    def autoregressive_decode(
        self,
        model: EncoderDecoder,
        image: Tensor,
        prefix: Sequence[int],
        max_decode_len: int,
        eos_id: int,
        token_whitelist: Optional[Sequence[int]] = None,
        token_blacklist: Optional[Sequence[int]] = None,
    ) -> Tensor:
        """Greedily decode one token sequence per image in the batch.

        The image batch is encoded once into a memory tensor. A context starting with
        `prefix` is then extended one greedily chosen token at a time. Decoding stops
        after `max_decode_len` steps, or earlier once every row contains `eos_id`.

        Args:
            model: encoder-decoder in use (put into eval mode here).
            image: image batch; ``image.shape[0]`` is the batch size.
            prefix: token ids every sequence starts with.
            max_decode_len: maximum number of generated tokens.
            eos_id: end-of-sequence token id.
            token_whitelist / token_blacklist: forwarded to ``pred_token_within_range``
                to restrict which tokens may be predicted.

        Returns:
            Tensor with one row per image: the prefix followed by the generated tokens.
        """
        model.eval()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with torch.no_grad():
            memory = model.encode(image)
            # Same prefix for every image in the batch.
            context = torch.tensor(prefix, dtype=torch.int32).repeat(image.shape[0], 1).to(device)
            for _ in range(max_decode_len):
                if all(eos_id in row for row in context):
                    break
                causal_mask = subsequent_mask(context.shape[1]).to(device)
                logits = model.decode(memory, context, tgt_mask=causal_mask, tgt_padding_mask=None)
                # Only the distribution for the last position is needed.
                logits = model.generator(logits)[:, -1, :]
                logits = pred_token_within_range(
                    logits.detach(),
                    white_list=token_whitelist,
                    black_list=token_blacklist,
                )
                _, next_tokens = greedy_sampling(logits)
                context = torch.cat([context, next_tokens], dim=1)
        return context

    @staticmethod
    def image_to_tensor(image: Image.Image, size: Tuple[int, int]) -> Tensor:
        """Resize, normalise and batch a single PIL image.

        Args:
            image: input image (three normalisation channels, so presumably RGB).
            size: target ``(height, width)`` passed to ``transforms.Resize``.

        Returns:
            Tensor of shape ``(1, C, height, width)`` on CUDA when available, else CPU.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.86597056, 0.88463002, 0.87491087],
                std=[0.20686628, 0.18201602, 0.18485524],
            ),
        ])
        return transform(image).to(device).unsqueeze(0)

    def rescale_bbox(
        self,
        bbox: Sequence[Sequence[float]],
        src: Tuple[int, int],
        tgt: Tuple[int, int],
    ) -> Sequence[Sequence[float]]:
        """Scale ``(xmin, ymin, xmax, ymax)`` boxes from a `src` frame to a `tgt` frame.

        Args:
            bbox: boxes in the `src` coordinate frame.
            src, tgt: ``(width, height)`` of the source and target frames.

        Returns:
            New list of boxes with coordinates rounded to ints.

        Raises:
            ValueError: if `src` or `tgt` does not have exactly two elements.
        """
        if not (len(src) == len(tgt) == 2):
            raise ValueError("src and tgt must both be (width, height) pairs")
        # [x_ratio, y_ratio, x_ratio, y_ratio] matches (xmin, ymin, xmax, ymax).
        ratio = [tgt[0] / src[0], tgt[1] / src[1]] * 2
        return [[int(round(i * j)) for i, j in zip(entry, ratio)] for entry in bbox]

    def predict(self, images: List[Image.Image], debugfolder_filename_page_name: str) -> List[str]:
        """Convert table images into prettified HTML tables.

        The three models are loaded one after another, each after the previous one is no
        longer used, because they are all built on the same backbone/encoder/decoder
        instances.

        Args:
            images: table images.
            debugfolder_filename_page_name: path prefix for debug PNGs showing the
                predicted cell boxes; ``<prefix><index>.png`` is written per table.

        Returns:
            One prettified HTML string per input image.
        """
        parts = self._build_model_parts()
        print("Running table transformer + Unitable Full Model")
        image_tensors = [self.image_to_tensor(image, (448, 448)) for image in images]
        pred_htmls = self._predict_structures(parts, image_tensors)
        pred_bboxs = self._predict_bboxes(parts, images, image_tensors, pred_htmls, debugfolder_filename_page_name)
        pred_cells = self._predict_cells(parts, images, pred_bboxs)
        return self._assemble_tables(pred_htmls, pred_cells)

    @staticmethod
    def _build_model_parts() -> Tuple[ImgLinearBackbone, Encoder, Decoder]:
        """Instantiate the UniTable-large backbone, encoder and decoder."""
        d_model = 768
        patch_size = 16
        nhead = 12
        dropout = 0.2
        backbone = ImgLinearBackbone(d_model=d_model, patch_size=patch_size)
        encoder = Encoder(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            activation="gelu",
            norm_first=True,
            nlayer=12,
            ff_ratio=4,
        )
        decoder = Decoder(
            d_model=d_model,
            nhead=nhead,
            dropout=dropout,
            activation="gelu",
            norm_first=True,
            nlayer=4,
            ff_ratio=4,
        )
        return backbone, encoder, decoder

    def _predict_structures(self, parts, image_tensors: List[Tensor]) -> List[list]:
        """Load the structure model and decode an HTML token list for each table."""
        backbone, encoder, decoder = parts
        start = time.time()
        vocab, model = self.load_vocab_and_model(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_path="./unitable/vocab/vocab_html.json",
            max_seq_len=784,
            model_weights=self._MODEL_DIR / self._MODEL_FILE_NAMES[0],
        )
        print("time to load table structure model ", time.time() - start, "seconds")

        prefix = [vocab.token_to_id("[html]")]
        eos_id = vocab.token_to_id(_EOS_TOKEN)
        whitelist = [vocab.token_to_id(tok) for tok in VALID_HTML_TOKEN]

        pred_htmls = []
        for idx, image_tensor in enumerate(image_tensors):
            print("Processing table " + str(idx))
            start = time.time()
            token_ids = self.autoregressive_decode(
                model=model,
                image=image_tensor,
                prefix=prefix,
                max_decode_len=512,
                eos_id=eos_id,
                token_whitelist=whitelist,
                token_blacklist=None,
            )
            print("time for inference table structure ", time.time() - start, "seconds")
            html = vocab.decode(token_ids.detach().cpu().numpy()[0].tolist(), skip_special_tokens=False)
            pred_html = html_str_to_token_list(html)
            print(pred_html)
            pred_htmls.append(pred_html)
        return pred_htmls

    def _predict_bboxes(self, parts, images, image_tensors, pred_htmls, debug_prefix: str) -> List[list]:
        """Load the bbox model, detect cell boxes per table and save a debug image for each."""
        backbone, encoder, decoder = parts
        start = time.time()
        vocab, model = self.load_vocab_and_model(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_path="./unitable/vocab/vocab_bbox.json",
            max_seq_len=1024,
            model_weights=self._MODEL_DIR / self._MODEL_FILE_NAMES[1],
        )
        print("time to load cell bbox detection model ", time.time() - start, "seconds")

        whitelist = [vocab.token_to_id(tok) for tok in VALID_BBOX_TOKEN[:449]]

        pred_bboxs = []
        for idx, (image, image_tensor) in enumerate(zip(images, image_tensors)):
            print("Processing table " + str(idx))
            start = time.time()
            pred_bbox = self._decode_bboxes(model, vocab, image_tensor, whitelist)
            print("time to do inference for table cell bbox detection model ", time.time() - start, "seconds")
            pred_bbox = self.rescale_bbox(pred_bbox, src=(448, 448), tgt=image.size)
            print(pred_bbox)
            print("Size of the image ")
            print(image.size)
            print("Number of bounding boxes ")
            print(len(pred_bbox))

            countcells = sum(tok in _CELL_TOKENS for tok in pred_htmls[idx])
            print("number of countcells")
            print(countcells)
            # Presumably 1024 decode steps / 4 coordinate tokens per box cap the first
            # pass at 256 boxes, so bigger tables get a second pass -- confirm.
            if countcells > 256 and pred_bbox:
                pred_bbox = self._extend_bboxes_for_large_table(model, vocab, image, pred_bbox, whitelist)

            pred_bboxs.append(pred_bbox)
            self._save_bbox_debug_image(image, pred_bbox, debug_prefix + str(idx) + ".png")
        return pred_bboxs

    def _decode_bboxes(self, model, vocab, image_tensor: Tensor, whitelist) -> list:
        """Decode the cell boxes of one image tensor, in the 448x448 input frame."""
        token_ids = self.autoregressive_decode(
            model=model,
            image=image_tensor,
            prefix=[vocab.token_to_id("[bbox]")],
            max_decode_len=1024,
            eos_id=vocab.token_to_id(_EOS_TOKEN),
            token_whitelist=whitelist,
            token_blacklist=None,
        )
        text = vocab.decode(token_ids.detach().cpu().numpy()[0].tolist(), skip_special_tokens=False)
        return bbox_str_to_token_list(text)

    def _extend_bboxes_for_large_table(self, model, vocab, image, pred_bbox: list, whitelist) -> list:
        """Re-run bbox detection below the last detected row and append the new boxes."""
        # Boxes are (xmin, ymin, xmax, ymax); the last box's ymin is taken as the top of
        # the row where the first pass stopped.
        cut_off = pred_bbox[-1][1]

        # NOTE(review): in the copy under review, the text from the row-membership test
        # through the second decode call was lost (HTML-tag stripping). It is reconstructed
        # from how `last_cells_redudant`, `cropped_image` and `pred_bbox_extra` were used
        # afterwards -- verify against the original source.
        # Count trailing boxes whose ymin lies within 5 px of the cut-off row.
        last_cells_redundant = 0
        for cell in reversed(pred_bbox):
            if cut_off - 5 < cell[1] < cut_off + 5:
                last_cells_redundant += 1
            else:
                break

        width, height = image.size
        cropped_image = image.crop((0, cut_off, width, height))
        extra = self._decode_bboxes(model, vocab, self.image_to_tensor(cropped_image, (448, 448)), whitelist)

        # Presumably the crop's first detected row is the cut-off row, whose boxes are
        # already in `pred_bbox`. NOTE(review): `[n - 1:]` keeps the last of those n
        # overlapping boxes -- confirm whether `[n:]` was intended.
        extra = extra[last_cells_redundant - 1:]
        extra = self.rescale_bbox(extra, src=(448, 448), tgt=cropped_image.size)
        # Shift from crop coordinates back to full-image coordinates.
        extra = [[box[0], box[1] + cut_off, box[2], box[3] + cut_off] for box in extra]
        print("extra boxes:")
        print(extra)
        print(len(extra))
        return pred_bbox + extra

    @staticmethod
    def _save_bbox_debug_image(image, bboxes, path: str) -> None:
        """Draw `bboxes` in red over `image` and save the figure to `path`."""
        fig, ax = plt.subplots(figsize=(12, 10))
        for box in bboxes:
            # box is (xmin, ymin, xmax, ymax)
            rect = patches.Rectangle(
                box[:2], box[2] - box[0], box[3] - box[1],
                linewidth=1, edgecolor='r', facecolor='none',
            )
            ax.add_patch(rect)
        ax.set_axis_off()
        ax.imshow(image)
        fig.savefig(path, bbox_inches='tight', dpi=300)
        # Close the figure; otherwise pyplot keeps every figure alive for the whole run.
        plt.close(fig)

    def _predict_cells(self, parts, images, pred_bboxs) -> List[list]:
        """Load the cell-content model and recognise the text of every detected cell."""
        backbone, encoder, decoder = parts
        start = time.time()
        vocab, model = self.load_vocab_and_model(
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            vocab_path="./unitable/vocab/vocab_cell_6k.json",
            max_seq_len=200,
            model_weights=self._MODEL_DIR / self._MODEL_FILE_NAMES[2],
        )
        print("time to load cell recognition model ", time.time() - start, "seconds")

        prefix = [vocab.token_to_id("[cell]")]
        eos_id = vocab.token_to_id(_EOS_TOKEN)
        blacklist = [vocab.token_to_id(tok) for tok in INVALID_CELL_TOKEN]

        pred_cells = []
        for idx, (image, bboxes) in enumerate(zip(images, pred_bboxs)):
            # Skip crops with no pixels (zero width or height).
            crops = (image.crop(bbox) for bbox in bboxes)
            cell_tensors = [
                self.image_to_tensor(crop, size=(112, 448))
                for crop in crops
                if crop.size[0] > 0 and crop.size[1] > 0
            ]
            print("Processing table " + str(idx))
            if not cell_tensors:
                # torch.cat would fail on an empty list.
                print("no cells to recognise")
                pred_cells.append([])
                continue

            start = time.time()
            # All cells of a table are decoded as one batch.
            token_ids = self.autoregressive_decode(
                model=model,
                image=torch.cat(cell_tensors, dim=0),
                prefix=prefix,
                max_decode_len=200,
                eos_id=eos_id,
                token_whitelist=None,
                token_blacklist=blacklist,
            )
            texts = vocab.decode_batch(token_ids.detach().cpu().numpy().tolist(), skip_special_tokens=False)
            print("time to do cell recognition ", time.time() - start, "seconds")
            pred_cell = [_SPLIT_NUMBER_RE.sub(r'\1.\2', cell_str_to_token_list(text)) for text in texts]
            print(pred_cell)
            pred_cells.append(pred_cell)
        return pred_cells

    @staticmethod
    def _assemble_tables(pred_htmls, pred_cells) -> List[str]:
        """Merge structure and cell text into prettified HTML, one string per table."""
        table_codes = []
        for pred_html, pred_cell in zip(pred_htmls, pred_cells):
            pred_code = "".join(build_table_from_html_and_cell(pred_html, pred_cell))
            pred_code = html_table_template(pred_code)
            # No parser is named, so BeautifulSoup picks the best one installed; the
            # output can therefore differ between environments.
            table_code = bs(pred_code).prettify()
            print(table_code)
            table_codes.append(table_code)
        return table_codes