# alps / ocrTable2.py
# Uploaded by yumikimi381 via huggingface_hub (commit daf0288, verified).
from typing import Tuple, List, Sequence, Optional, Union
from torchvision import transforms
from torch import nn, Tensor
from PIL import Image
from pathlib import Path
from bs4 import BeautifulSoup as bs
from unitable import UnitableFullPredictor
from unitable import UnitableFullSinglePredictor
from doctrfiles import DoctrWordDetector,DoctrTextRecognizer
import numpy as np
from utils import crop_an_Image,cropImageExtraMargin
import numpy.typing as npt
from numpy import uint8
# Type alias for images handled in this module: NumPy arrays whose elements
# are uint8 (np.uint8 is the same scalar type as the imported `uint8`).
ImageType = npt.NDArray[np.uint8]
class OcrTable2():
#Takes as input the table image - no table detection
def __init__(self):
self.unitablePredictor = UnitableFullPredictor()
#self.unitablePredictor = UnitableFullSinglePredictor()
@staticmethod
def save_detection(detected_lines_images:List[ImageType], prefix = './res/test1/res_'):
i = 0
for img in detected_lines_images:
pilimg = Image.fromarray(img)
pilimg.save(prefix+str(i)+'.png')
i=i+1
def predict(self,images,debug_repo="./res/test1"):
# Step 1: Get table structure and bbox for cell contents from unitable
table_code = self.unitablePredictor.predict(images,debug_repo)