invex / apis /layoutlm.py
patharanor's picture
feat: host list file detection
42a09ad verified
from transformers import pipeline
import pandas as pd
class LayoutLM:
    """Wrapper around a ``transformers`` document-question-answering pipeline.

    The wrapper asks a fixed list of invoice-related questions about an input
    and collects the best answer to each one.  No model is loaded until
    ``set_model`` is called; until then every query falls back to the
    placeholder answer ``'-'``.
    """

    # (row label, question) pairs asked by ``inference``, in output-row order.
    # Keeping them in one table guarantees the questions, the debug output and
    # the DataFrame rows cannot drift apart.
    _INVOICE_FIELDS = (
        ('Merchant ID', 'What is merchant ID?'),
        ('Merchant name', 'What is merchant name?'),
        ('Merchant address', 'What is merchant address?'),
        ('Merchant branch', 'What is branch of merchant?'),
        ('Invoice no.', 'What is invoice number or INV?'),
        ('Products', 'What are buy products?'),
        ('Product codes', 'What are code of buy products?'),
        ('POS no.', 'What is POS number?'),
        ('Net price', 'What is the net-price?'),
        ('Date/Time', 'What date, year and time of the invoice?'),
    )

    def __init__(self, save_pretrained_fpath: str = None) -> None:
        """Set up defaults and optionally save a pipeline to disk.

        Args:
            save_pretrained_fpath: When not ``None``, a pipeline for the
                document-question-answering task is created and written to
                this path via ``save_pretrained``.
        """
        self.pipeline_category = 'document-question-answering'
        self.tf_pipeline = pipeline
        self.pipeline = None
        # Defaults are assigned before the optional save step so they exist
        # regardless of whether a save path was supplied.
        self.default_model = 'impira/layoutlm-invoices'
        self.default_ex_answer = {'score': 0, 'answer': '-'}

        if save_pretrained_fpath is not None:
            # NOTE(review): no ``model=`` is passed here, so what gets saved is
            # whatever the task resolves to by default, not ``default_model``
            # -- confirm this is intended.
            pipe = self.tf_pipeline(self.pipeline_category)
            pipe.save_pretrained(save_pretrained_fpath)

    def set_model(self, model: str = None):
        """Create the question-answering pipeline for ``model``.

        Args:
            model: Model identifier handed to the pipeline factory.  ``None``
                (or omitting the argument) selects ``self.default_model``.
        """
        if model is None:
            model = self.default_model
        self.pipeline = self.tf_pipeline(self.pipeline_category, model=model)

    def answer_the_question_without_filter(self, img, question: str, is_debug=False, **kwargs):
        """Return the raw pipeline output for ``question``.

        Args:
            img: Input forwarded unchanged to the pipeline.
            question: Question forwarded unchanged to the pipeline.
            is_debug: When true, print the raw result.
            **kwargs: Only ``top_k`` (default 1) and ``max_answer_len``
                (default 15) are read; a value of ``None`` selects the
                default.  Any other keyword is ignored.

        Returns:
            Whatever the pipeline returns, or ``None`` when no pipeline has
            been set.
        """
        top_k = kwargs.get('top_k')
        if top_k is None:
            top_k = 1
        max_answer_len = kwargs.get('max_answer_len')
        if max_answer_len is None:
            max_answer_len = 15

        answers = None
        if self.pipeline is not None:
            answers = self.pipeline(
                img,
                question,
                top_k=top_k,
                max_answer_len=max_answer_len,
            )

        if is_debug:
            print('--------------------')
            print(answers)
        return answers

    def answer_the_question(self, img, question: str, is_debug=False):
        """Return the highest-scoring answer text for ``question``.

        Args:
            img: Input forwarded unchanged to the pipeline.
            question: Question forwarded unchanged to the pipeline.
            is_debug: When true, print the question, the chosen answer with
                its score, and the raw pipeline result.

        Returns:
            The ``'answer'`` of the candidate with the largest ``'score'``.
            The placeholder from ``self.default_ex_answer`` (``'-'``) is
            returned when no pipeline is set, the pipeline yields nothing, or
            no candidate scores above the placeholder score.
        """
        score = self.default_ex_answer['score']
        answer = self.default_ex_answer['answer']

        answers = None
        if self.pipeline is not None:
            answers = self.pipeline(img, question)

        # Normalise the result to a list: tolerate a missing result and a
        # single answer handed back as a bare dict rather than a list.
        if answers is None:
            candidates = []
        elif isinstance(answers, dict):
            candidates = [answers]
        else:
            candidates = answers

        for candidate in candidates:
            # Strict comparison: the first of several equally scored
            # candidates wins, and a score equal to the placeholder is ignored.
            if candidate['score'] > score:
                score = candidate['score']
                answer = candidate['answer']

        if is_debug:
            print('--------------------')
            print(f'Q: {question}\nA: {answer} (acc:{score:.2f})\n')
            print(answers)
        return answer

    def inference(self, img, is_debug=False):
        """Ask every invoice question about ``img`` and tabulate the answers.

        Args:
            img: Input forwarded unchanged to the pipeline for each question.
            is_debug: When true, print per-question details followed by a
                ``label: value`` summary line for each field.

        Returns:
            A ``pandas.DataFrame`` with a ``'Data'`` column holding the field
            labels and a ``'Value'`` column holding the answers as strings,
            one row per field.
        """
        labels = [label for label, _ in self._INVOICE_FIELDS]
        values = [
            self.answer_the_question(img, question, is_debug=is_debug)
            for _, question in self._INVOICE_FIELDS
        ]

        if is_debug:
            for label, value in zip(labels, values):
                print(f'{label}: {value}')

        return pd.DataFrame({
            'Data': labels,
            'Value': [str(value) for value in values],
        })