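# Recovered from a web file-view page (page chrome such as avatar, "raw" and
# "history blame" links removed): author AuditEdge, commit 81e13bb
# ("initial commit"), original file size 11.2 kB.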
from .utils import load_model,load_processor,normalize_box,compare_boxes,adjacent
from .model_base_path import LAYOUTLMV2_BASE_PATH,LAYOUTLMV3_BASE_PATH
from .annotate_image import get_flattened_output,annotate_image
from PIL import Image,ImageDraw, ImageFont
import logging
import torch
import json
# Target device for model inputs: CUDA when torch reports it available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Logger named after this module.
logger = logging.getLogger(__name__)
class ModelHandler(object):
    """
    Handler for a LayoutLM-style token-classification model.

    Wraps an already-loaded model and processor, runs
    preprocess -> inference -> postprocess on a batch of page images, and then
    extracts document-type-specific fields (cheque, aadhar, pan, gst) from the
    predicted entity labels.
    """

    # Per document type: ordered (result_key, entity_label) pairs. The texts of
    # all entities carrying ``entity_label`` are concatenated into ``result_key``.
    _FIELD_SPECS = {
        "cheque": (("acc_num", "AN"), ("IFSC", "IFSC")),
        "aadhar": (("aadhar_num", "AN"),),
        "pan": (("pan_num", "PAN_VALUE"),),
        "gst": (("gstin_num", "GSTIN"),),
    }

    def __init__(self):
        self.initialized = False
        # Set by initialize(); declared here so the attributes always exist.
        self.name = None
        self.model = None
        self.preprocessor = None
        # Per-request state shared between preprocess() and postprocess().
        self._raw_input_data = None
        self._processed_data = None
        self._images_size = None

    def initialize(self, context, preprocessor, name):
        """
        Attach the model, its processor and the document type.

        :param context: the loaded model object. Despite the name, it is stored
            as ``self.model`` and called directly in ``inference``.
        :param preprocessor: processor called in ``preprocess`` with the
            images, words and boxes.
        :param name: document type; one of the keys of ``_FIELD_SPECS``.
        """
        logger.info("Loading transformer model")
        self.name = name
        self.model = context
        self.preprocessor = preprocessor
        self.initialized = True

    def preprocess(self, batch):
        """
        Turn a raw request into model-ready tensors.

        :param batch: dict with per-page lists under ``'image_path'``,
            ``'words'`` and ``'bboxes'`` (one list of boxes per page; presumably
            pixel coordinates, since they are normalized with the page size).
        :return: dict of processor outputs, each moved to ``device``.
        """
        inference_dict = batch
        logger.debug("inference_dict %s", inference_dict)
        self._raw_input_data = inference_dict
        processor = self.preprocessor

        # Open each file in a context manager so the handle is released.
        # convert() returns a new image object, so closing the source is safe.
        images = []
        for path in inference_dict['image_path']:
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        self._images_size = [img.size for img in images]

        words = inference_dict['words']
        # Each page's boxes are normalized using that page's (width, height).
        boxes = [
            [normalize_box(box, *self._images_size[i]) for box in doc]
            for i, doc in enumerate(inference_dict['bboxes'])
        ]
        encoded_inputs = processor(
            images, words, boxes=boxes, return_tensors="pt",
            padding="max_length", truncation=True)
        # Keep the un-moved encoding: postprocess() reads its 'bbox' tensor.
        self._processed_data = encoded_inputs
        encoded_inputs = {key: val.to(device) for key, val in encoded_inputs.items()}
        logger.debug("encoded_inputs %s", encoded_inputs)
        return encoded_inputs

    def load(self, model_dir):
        """Load a Hugging Face transformer model via ``load_model``.

        :param model_dir: path or identifier forwarded to ``load_model``.
        :return: whatever ``load_model`` returns.
        """
        # TODO model dir should be microsoft/layoutlmv2-base-uncased
        return load_model(model_dir)

    def inference(self, model_input):
        """
        Run the model and take the arg-max label id for every token.

        :param model_input: keyword tensors produced by ``preprocess``.
        :return: ``[results]``, where ``results[i]`` is
            ``{f'output_{i}': ids}`` and ``ids`` is the list of predicted
            label ids for batch item ``i``.
        """
        with torch.no_grad():
            inference_outputs = self.model(**model_input)
        predictions = inference_outputs.logits.argmax(-1).tolist()
        results = [{f'output_{i}': pred} for i, pred in enumerate(predictions)]
        return [results]

    def postprocess(self, inference_output):
        """
        Map token predictions back to words and group labelled words into spans.

        Words whose label resolves to 'other' are dropped.

        :param inference_output: the value returned by ``inference``.
        :return: one-element list holding a JSON string. The JSON is a list
            with one ``{'output': [span, ...]}`` entry per page, where each span
            is ``{'text', 'label', 'words': [{'id', 'box', 'text'}, ...]}``.
        """
        logger.debug("self._raw_input_data['words'] %s", self._raw_input_data['words'])
        logger.debug("inference_output %s", inference_output)
        id2label = self.model.config.id2label
        # Convert the whole tensor once instead of once per word.
        processed_boxes = self._processed_data['bbox'].tolist()

        docs = []
        word_id = 0  # running word id, unique across all pages
        for page, doc_words in enumerate(self._raw_input_data['words']):
            width, height = self._images_size[page]
            page_bboxes = self._raw_input_data['bboxes'][page]
            token_boxes = processed_boxes[page]
            token_preds = inference_output[0][page][f'output_{page}']

            doc_list = []
            for i, doc_word in enumerate(doc_words):
                word = {
                    'id': word_id,
                    'text': doc_word,
                    'pageNum': page + 1,
                    'box': page_bboxes[i],
                }
                word_id += 1
                normalized_box = normalize_box(page_bboxes[i], width, height)
                word['label'] = self._word_label(
                    normalized_box, token_boxes, token_preds, id2label)
                word['pageSize'] = {'width': width, 'height': height}
                if word['label'] != 'other':
                    doc_list.append(word)

            output_spans = [self._span_to_output(span)
                            for span in self._group_spans(doc_list)]
            docs.append({'output': output_spans})
        return [json.dumps(docs, ensure_ascii=False)]

    @staticmethod
    def _word_label(normalized_box, token_boxes, token_preds, id2label):
        """Return the label for one word from the tokens whose box matches it.

        Each matching token contributes ``id2label[pred][2:]``, or 'other' for
        'O'. The first contribution wins unless it is 'other', in which case the
        last one is used. Returns 'other' when no token box matches.
        """
        word_labels = []
        for j, box in enumerate(token_boxes):
            if compare_boxes(box, normalized_box):
                label = id2label[token_preds[j]]
                # [2:] drops a two-character prefix (presumably 'B-' / 'I-').
                word_labels.append(label[2:] if label != 'O' else 'other')
        if not word_labels:
            return 'other'
        return word_labels[0] if word_labels[0] != 'other' else word_labels[-1]

    @staticmethod
    def _group_spans(doc_list):
        """Group labelled words into spans using ``adjacent``.

        Words that are adjacent to no word of ``doc_list`` become single-word
        spans, emitted first and in order. The remaining words are then scanned
        in order, and each run in which ``adjacent(previous, next)`` holds is
        chained into one span.
        """
        linked = [any(adjacent(word, other) for other in doc_list) for word in doc_list]
        spans = [[word] for word, has_neighbour in zip(doc_list, linked) if not has_neighbour]
        remaining = [word for word, has_neighbour in zip(doc_list, linked) if has_neighbour]
        pos = 0
        while pos < len(remaining):
            span = [remaining[pos]]
            pos += 1
            while pos < len(remaining) and adjacent(span[-1], remaining[pos]):
                span.append(remaining[pos])
                pos += 1
            spans.append(span)
        return spans

    @staticmethod
    def _span_to_output(span):
        """Serialize a span: space-joined text, its first word's label, and its words."""
        return {
            "text": ' '.join(entity['text'] for entity in span),
            "label": span[0]['label'],
            "words": [
                {'id': entity['id'], 'box': entity['box'], 'text': entity['text']}
                for entity in span
            ],
        }

    @classmethod
    def _extract_fields(cls, name, flattened_output_list):
        """Build the result dict for document type ``name``.

        Only the first entry of ``flattened_output_list`` is inspected. An
        empty list yields empty strings for every field.
        """
        entities = flattened_output_list[0]['output'] if flattened_output_list else []
        return {
            key: "".join(item['text'] for item in entities if item['label'] == label)
            for key, label in cls._FIELD_SPECS[name]
        }

    def handle(self, data, context):
        """
        Run the full pipeline for one request and extract the document fields.

        Also calls ``annotate_image`` for every page in ``data['image_path']``.

        :param data: request dict (see ``preprocess``).
        :param context: only logged here.
        :return: dict of extracted fields for ``self.name`` (see ``_FIELD_SPECS``).
        :raises ValueError: if ``self.name`` is not a supported document type.
        """
        # Validate up front so an unsupported type does not run the model for nothing.
        if self.name not in self._FIELD_SPECS:
            raise ValueError(
                f"unsupported document type {self.name!r}; "
                f"expected one of {sorted(self._FIELD_SPECS)}")
        logger.debug("context %s", context)
        model_input = self.preprocess(data)
        logger.debug("model input %s", model_input)
        model_out = self.inference(model_input)
        logger.debug("model output %s", model_out)
        inference_out = self.postprocess(model_out)[0]
        logger.debug("processed output %s", inference_out)
        flattened_output_list = get_flattened_output(json.loads(inference_out))
        logger.debug("flattened_output_list %s", flattened_output_list)
        result = self._extract_fields(self.name, flattened_output_list)
        for i, flattened_output in enumerate(flattened_output_list):
            annotate_image(data['image_path'][i], flattened_output)
        return result
# Module-level handler instance shared by every call to ``handle``.
_service = ModelHandler()


def handle(data, context, processor, name):
    """
    Module entry point: configure the shared handler and run it on ``data``.

    The shared handler is re-initialised on every call with the model
    (``context``), ``processor`` and document type ``name``. Its
    ``handle(data, context)`` result is returned unchanged.

    NOTE(review): initialisation is deliberately not guarded by
    ``initialized``, so each call overwrites the shared handler's model and
    name. This presumably lets one process serve several document types; confirm.
    """
    service = _service
    service.initialize(context, processor, name)
    return service.handle(data, context)