# Spaces: Sleeping  (page-header residue from the hosting site; not part of the program)
from typing import Optional

import pandas as pd
from transformers import pipeline
class LayoutLM:
    """Wrapper around a ``document-question-answering`` pipeline for invoices.

    A freshly constructed instance has no model loaded (``self.pipeline`` is
    ``None``); call :meth:`set_model` first. While no model is loaded,
    :meth:`answer_the_question` returns the placeholder ``'-'`` and
    :meth:`answer_the_question_without_filter` returns ``None``.
    """

    # (row label, question) pairs used by ``inference``; the order here is the
    # row order of the returned DataFrame and of the debug summary.
    _INVOICE_FIELDS = (
        ('Merchant ID', 'What is merchant ID?'),
        ('Merchant name', 'What is merchant name?'),
        ('Merchant address', 'What is merchant address?'),
        ('Merchant branch', 'What is branch of merchant?'),
        ('Invoice no.', 'What is invoice number or INV?'),
        ('Products', 'What are buy products?'),
        ('Product codes', 'What are code of buy products?'),
        ('POS no.', 'What is POS number?'),
        ('Net price', 'What is the net-price?'),
        ('Date/Time', 'What date, year and time of the invoice?'),
    )

    def __init__(self, save_pretrained_fpath: Optional[str] = None) -> None:
        """Initialise the wrapper without loading a model.

        Args:
            save_pretrained_fpath: If given, a pipeline for the task is
                created with no explicit model and saved to this path via
                ``save_pretrained``. The pipeline used for that snapshot is
                not kept on the instance.
        """
        self.pipeline_category = 'document-question-answering'
        self.tf_pipeline = pipeline
        self.pipeline = None
        self.default_model = 'impira/layoutlm-invoices'
        self.default_ex_answer = {'score': 0, 'answer': '-'}

        if save_pretrained_fpath is not None:
            # No ``model=`` is passed here, so the snapshot comes from
            # whatever the pipeline factory selects for the task, not from
            # ``self.default_model``.
            pipe = self.tf_pipeline(self.pipeline_category)
            pipe.save_pretrained(save_pretrained_fpath)

    def set_model(self, model: Optional[str] = None) -> None:
        """Load ``model`` into ``self.pipeline``.

        Args:
            model: Model identifier handed to the pipeline factory. ``None``
                selects ``self.default_model``.
        """
        if model is None:
            model = self.default_model
        self.pipeline = self.tf_pipeline(self.pipeline_category, model=model)

    @staticmethod
    def _as_answer_list(answers):
        """Return ``answers`` as a list of answer dicts.

        ``None`` becomes an empty list and a single dict becomes a one-item
        list, so callers can always iterate over answer dicts.
        """
        if answers is None:
            return []
        if isinstance(answers, dict):
            return [answers]
        return list(answers)

    def answer_the_question_without_filter(self, img, question: str, is_debug=False, **kwargs):
        """Ask ``question`` about ``img`` and return the raw pipeline output.

        Args:
            img: Document image, passed to the pipeline unchanged.
            question: Question to ask.
            is_debug: When true, print a separator line and the raw output.
            **kwargs: Only ``top_k`` (default 1) and ``max_answer_len``
                (default 15) are read; a value of ``None`` selects the
                default. Other keys are ignored.

        Returns:
            Whatever the pipeline returns, or ``None`` if no model is loaded.
        """
        top_k = kwargs.get('top_k')
        if top_k is None:
            top_k = 1
        max_answer_len = kwargs.get('max_answer_len')
        if max_answer_len is None:
            max_answer_len = 15

        answers = None
        if self.pipeline is not None:
            answers = self.pipeline(
                img,
                question,
                top_k=top_k,
                max_answer_len=max_answer_len,
            )

        if is_debug:
            print('--------------------')
            print(answers)
        return answers

    def answer_the_question(self, img, question: str, is_debug=False):
        """Return the highest-scoring answer text for ``question``.

        Args:
            img: Document image, passed to the pipeline unchanged.
            question: Question to ask.
            is_debug: When true, print the question, the chosen answer with
                its score, and the raw pipeline output.

        Returns:
            The ``'answer'`` of the candidate with the largest ``'score'``,
            or ``'-'`` when no model is loaded, nothing was returned, or no
            candidate scored above 0.
        """
        score = 0
        answer = '-'
        answers = None
        if self.pipeline is not None:
            answers = self.pipeline(img, question)

        # Normalise first: iterating ``None`` raises, and iterating a bare
        # dict would yield its keys instead of answer records.
        for candidate in self._as_answer_list(answers):
            if candidate['score'] > score:
                score = candidate['score']
                answer = candidate['answer']

        if is_debug:
            print('--------------------')
            print(f'Q: {question}\nA: {answer} (acc:{score:.2f})\n')
            print(answers)
        return answer

    def inference(self, img, is_debug=False):
        """Ask the fixed set of invoice questions about ``img``.

        Args:
            img: Document image, passed to the pipeline unchanged.
            is_debug: Forwarded to every question; when true, a
                ``label: value`` summary is also printed at the end.

        Returns:
            A ``pandas.DataFrame`` with a ``'Data'`` column (field labels)
            and a ``'Value'`` column (answers converted with ``str``), one
            row per field.
        """
        results = [
            (label, self.answer_the_question(img, question, is_debug=is_debug))
            for label, question in self._INVOICE_FIELDS
        ]

        if is_debug:
            for label, value in results:
                print(f'{label}: {value}')

        return pd.DataFrame({
            'Data': [label for label, _ in results],
            'Value': [str(value) for _, value in results],
        })