from transformers import pipeline
import pandas as pd


class LayoutLM:
    """Thin wrapper around a ``document-question-answering`` pipeline.

    The wrapper asks a fixed set of invoice/receipt questions about an image
    and collects the best answer for each into a two-column DataFrame.
    A model must be selected with :meth:`set_model` before any question can
    be answered; until then ``self.pipeline`` is ``None``.
    """

    # Ordered (label, question) pairs asked by ``inference``.  The label is
    # used both for the debug summary and for the 'Data' column of the result.
    _INVOICE_FIELDS = (
        ('Merchant ID', 'What is merchant ID?'),
        ('Merchant name', 'What is merchant name?'),
        ('Merchant address', 'What is merchant address?'),
        ('Merchant branch', 'What is branch of merchant?'),
        ('Invoice no.', 'What is invoice number or INV?'),
        ('Products', 'What are buy products?'),
        ('Product codes', 'What are code of buy products?'),
        ('POS no.', 'What is POS number?'),
        ('Net price', 'What is the net-price?'),
        ('Date/Time', 'What date, year and time of the invoice?'),
    )

    def __init__(self, save_pretrained_fpath: str = None) -> None:
        """Prepare the wrapper; optionally export a pipeline to disk.

        Args:
            save_pretrained_fpath: When not ``None``, a pipeline for the task
                is built without naming a model and saved to this path via
                ``save_pretrained``.  The built pipeline is not kept on the
                instance.
        """
        self.pipeline_category = 'document-question-answering'
        self.tf_pipeline = pipeline
        self.pipeline = None
        if save_pretrained_fpath is not None:
            pipe = self.tf_pipeline(self.pipeline_category)
            pipe.save_pretrained(save_pretrained_fpath)
        self.default_model = 'impira/layoutlm-invoices'
        self.default_ex_answer = {'score': 0, 'answer': '-'}

    def set_model(self, model: str):
        """Build the question-answering pipeline for ``model``.

        Args:
            model: Model identifier passed to the pipeline factory.  ``None``
                selects ``self.default_model``.
        """
        if model is None:
            model = self.default_model
        self.pipeline = self.tf_pipeline(self.pipeline_category, model=model)

    @staticmethod
    def _as_answer_list(answers):
        """Normalise a pipeline result into an iterable of answer dicts."""
        if answers is None:
            return []
        # NOTE(review): a lone answer may come back as a bare dict instead of
        # a one-element list depending on the installed transformers version
        # -- confirm.  Wrapping it keeps the scoring loop uniform either way.
        if isinstance(answers, dict):
            return [answers]
        return answers

    def answer_the_question_without_filter(self, img, question: str, is_debug=False, **kwargs):
        """Return the raw pipeline output for ``question`` about ``img``.

        Args:
            img: Document image handed to the pipeline unchanged.
            question: Question text.
            is_debug: When true, print a separator and the raw answers.
            **kwargs: ``top_k`` (default 1) and ``max_answer_len``
                (default 15) are forwarded to the pipeline; a value of
                ``None`` selects the default.  Other keys are ignored.

        Returns:
            Whatever the pipeline returns, or ``None`` when no model has been
            set.
        """
        top_k = kwargs.get('top_k')
        if top_k is None:
            top_k = 1
        max_answer_len = kwargs.get('max_answer_len')
        if max_answer_len is None:
            max_answer_len = 15

        answers = None
        if self.pipeline is not None:
            answers = self.pipeline(img, question, top_k=top_k, max_answer_len=max_answer_len)
        if is_debug:
            print('--------------------')
            print(answers)
        return answers

    def answer_the_question(self, img, question: str, is_debug=False):
        """Return the highest-scoring answer text for ``question``.

        Args:
            img: Document image handed to the pipeline unchanged.
            question: Question text.
            is_debug: When true, print the question, the chosen answer with
                its score, and the raw pipeline output.

        Returns:
            The ``'answer'`` of the candidate with the largest ``'score'``
            strictly greater than 0, or ``'-'`` when there is no such
            candidate (including when no model has been set).
        """
        score = 0
        answer = '-'
        answers = None
        if self.pipeline is not None:
            answers = self.pipeline(img, question)

        # Strict '>' keeps the first of equally-scored candidates and leaves
        # the '-' placeholder in place for zero-score results.
        for candidate in self._as_answer_list(answers):
            if candidate['score'] > score:
                score = candidate['score']
                answer = candidate['answer']

        if is_debug:
            print('--------------------')
            print(f'Q: {question}\nA: {answer} (acc:{score:.2f})\n')
            print(answers)
        return answer

    def inference(self, img, is_debug=False):
        """Ask every invoice question about ``img`` and tabulate the answers.

        Args:
            img: Document image handed to the pipeline unchanged.
            is_debug: Forwarded to each question; additionally prints a
                ``label: value`` summary once all questions are answered.

        Returns:
            A ``pandas.DataFrame`` with a ``'Data'`` column of field labels
            and a ``'Value'`` column holding each answer converted to ``str``.
        """
        results = [
            (label, self.answer_the_question(img, question, is_debug=is_debug))
            for label, question in self._INVOICE_FIELDS
        ]

        if is_debug:
            for label, value in results:
                print(f'{label}: {value}')

        return pd.DataFrame({
            'Data': [label for label, _ in results],
            'Value': [str(value) for _, value in results],
        })