import torch
from typing import Any, Optional
from transformers import LayoutLMv2ForQuestionAnswering
from transformers import LayoutLMv2Processor
from transformers import LayoutLMv2FeatureExtractor
from transformers import LayoutLMv2ImageProcessor
from transformers import LayoutLMv2TokenizerFast
from transformers.tokenization_utils_base import BatchEncoding
from transformers.tokenization_utils_base import TruncationStrategy
from transformers.utils import TensorType
from transformers.modeling_outputs import (
QuestionAnsweringModelOutput as QuestionAnsweringModelOutputBase
)
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from subprocess import run
import pdf2image
from pprint import pprint
import logging
from os import environ
from dataclasses import dataclass
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# install tesseract-ocr and pytesseract
# run("apt install -y tesseract-ocr", shell=True, check=True)
feature_extractor = LayoutLMv2FeatureExtractor()
# @dataclass
# class QuestionAnsweringModelOutput(QuestionAnsweringModelOutputBase):
# token_logits: Optional[torch.FloatTensor] = None
class NoOCRReaderFound(Exception):
def __init__(self, e):
self.e = e
def __str__(self):
return f"Could not load OCR Reader: {self.e}"
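# A minimal sketch (assumption, not called anywhere in this handler yet): verify that
# the Tesseract binary pytesseract depends on is actually installed, wrapping any
# failure in NoOCRReaderFound so callers see a single error type.
def check_ocr_available() -> None:
    try:
        import pytesseract
        pytesseract.get_tesseract_version()
    except Exception as e:
        raise NoOCRReaderFound(e)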
def pdf_to_image(b: bytes):
    """Rasterize a PDF and run OCR, returning per-page pixel values, words and boxes."""
    # TODO: pdf2image requires poppler, which is not present everywhere.
    # We should look into alternatives. We could also gracefully handle this
    # and simply fall back to _only_ extracted text.
    images = [x.convert("RGB") for x in pdf2image.convert_from_bytes(b)]
    encoded_inputs = feature_extractor(images)
    print('feature_extractor: ', list(encoded_inputs.keys()))
    data = {}
    data['image'] = encoded_inputs.pixel_values
    data['words'] = encoded_inputs.words
    data['boxes'] = encoded_inputs.boxes
    return data
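# A minimal debugging sketch (assumption, not used by EndpointHandler): run pdf_to_image
# on a PDF from disk and report how many OCR words and boxes the feature extractor found
# per page. The pdf_path argument is a placeholder, not a real fixture.
def debug_pdf_ocr(pdf_path: str) -> None:
    with open(pdf_path, "rb") as f:
        data = pdf_to_image(f.read())
    for page, words in enumerate(data["words"]):
        logger.info("page %d: %d OCR words, %d boxes", page, len(words), len(data["boxes"][page]))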
def setup_logger(which_logger: Optional[str] = None):
lib_level = logging.DEBUG # Default level for your logger
root_level = logging.INFO
log_format = '%(asctime)s - %(process)d - %(levelname)s - %(funcName)s - %(message)s'
logging.basicConfig(
        filename=environ.get('LOG_FILE_PATH_LAYOUTLM_V2'),  # taken from the local .env file, not set in settings.py
format=log_format,
datefmt='%d-%b-%y %H:%M:%S',
level=root_level,
force=True
)
log = logging.getLogger(which_logger)
log.setLevel(lib_level)
return log
logger = setup_logger(__name__)
class Funcs:
# helper function to unnormalize bboxes for drawing onto the image
@staticmethod
def unnormalize_box(bbox, width, height):
return [
width * (bbox[0] / 1000),
height * (bbox[1] / 1000),
width * (bbox[2] / 1000),
height * (bbox[3] / 1000),
]
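    # A minimal sketch (assumption, not called anywhere yet): draw OCR boxes onto a page
    # image via unnormalize_box, useful for eyeballing what Tesseract detected. ImageDraw
    # is already imported at the top of this module.
    @staticmethod
    def draw_boxes(image: Image.Image, boxes: list, color: str = "red") -> Image.Image:
        width, height = image.size
        draw = ImageDraw.Draw(image)
        for box in boxes:
            draw.rectangle(Funcs.unnormalize_box(box, width, height), outline=color, width=2)
        return image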
@staticmethod
def num_spans(encoding: BatchEncoding) -> int:
return len(encoding["input_ids"])
@staticmethod
    def p_mask(num_spans: int, encoding: BatchEncoding) -> list:
        # Mask out every token that does not belong to the document text
        # (sequence id 1), producing one boolean mask per span.
        return [
            [tok != 1 for tok in encoding.sequence_ids(span_id)]
            for span_id in range(num_spans)
        ]
@staticmethod
def token_start_end(encoding, tokenizer):
sequence_ids = encoding.sequence_ids()
# Start token index of the current span in the text.
token_start_index = 0
while sequence_ids[token_start_index] != 1:
token_start_index += 1
# End token index of the current span in the text.
token_end_index = len(encoding.input_ids) - 1
while sequence_ids[token_end_index] != 1:
token_end_index -= 1
print("Token start index:", token_start_index)
print("Token end index:", token_end_index)
print('token_start_end: ', tokenizer.decode(encoding.input_ids[token_start_index:token_end_index+1]))
return token_start_index, token_end_index
@staticmethod
    def reconstruct_answer(word_idx_start, word_idx_end, encoding, tokenizer):
        # Map word-level answer indices back to token positions inside the span.
        token_start_index, token_end_index = Funcs.token_start_end(encoding, tokenizer)
        word_ids = encoding.word_ids()[token_start_index:token_end_index + 1]
        print("Word ids:", word_ids)
        start_position, end_position = token_start_index, token_end_index
        for word_id in word_ids:
            if word_id == word_idx_start:
                start_position = token_start_index
                break
            token_start_index += 1
        for word_id in word_ids[::-1]:
            if word_id == word_idx_end:
                end_position = token_end_index
                break
            token_end_index -= 1
        print("Reconstructed answer:",
              tokenizer.decode(encoding.input_ids[start_position:end_position + 1]))
        return start_position, end_position
@staticmethod
def sigmoid(_outputs):
return 1.0 / (1.0 + np.exp(-_outputs))
@staticmethod
def softmax(_outputs):
maxes = np.max(_outputs, axis=-1, keepdims=True)
shifted_exp = np.exp(_outputs - maxes)
return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
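# A minimal sketch (assumption, not wired into EndpointHandler): turn the model's
# start/end logits into a probability for the chosen span using Funcs.softmax, so a
# confidence score could be reported alongside the decoded answer. Pass numpy arrays,
# e.g. outputs.start_logits.detach().numpy().
def span_confidence(start_logits: np.ndarray, end_logits: np.ndarray, start_idx: int, end_idx: int) -> float:
    start_probs = Funcs.softmax(start_logits)
    end_probs = Funcs.softmax(end_logits)
    # Treat start and end positions as independent, as extractive QA pipelines commonly do.
    return float(start_probs[..., start_idx] * end_probs[..., end_idx])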
class EndpointHandler:
def __init__(self, path="./"):
# self.model = LayoutLMv2ForQuestionAnswering.from_pretrained(path).to(device)
self.model = LayoutLMv2ForQuestionAnswering.from_pretrained(path)
self.tokenizer = LayoutLMv2TokenizerFast.from_pretrained(path)
# self.image_processor = LayoutLMv2ImageProcessor() # apply_ocr is set to True by default
self.processor = LayoutLMv2Processor.from_pretrained(
path,
# image_processor=self.image_processor,
tokenizer=self.tokenizer)
    def __call__(self, data: dict[str, bytes]):
        """
        Args:
            data (:obj:`dict`):
                carries the raw PDF bytes under the "inputs" key; every page is
                rasterized with pdf2image before being run through the processor.
        """
        pdf_bytes = data.pop("inputs", data)
        # image = pdf_to_image(pdf_bytes)
        images = [x.convert("RGB") for x in pdf2image.convert_from_bytes(pdf_bytes)]
        # TODO: take the question from the request payload instead of hardcoding it
        question = "what is the bill date"
        answers = []
with torch.no_grad():
for image in images:
# max_seq_len = min(self.tokenizer.model_max_length, 512)
# doc_stride = min(max_seq_len // 2, 256)
encoding = self.processor(
image,
question,
# max_length=max_seq_len,
# stride=doc_stride,
truncation=True,
# truncation=TruncationStrategy.ONLY_SECOND,
# return_offsets_mapping=True,
# return_token_type_ids=True,
# return_overflowing_tokens=True,
return_tensors=TensorType.PYTORCH
)
print('encoding: ', encoding.keys())
# for k, v in encoding.items():
# encoding[k] = v.to(self.model.device)
# num_spans = Funcs.num_spans(encoding)
# p_mask = Funcs.p_mask(num_spans, encoding)
# offset_mapping = encoding.pop('offset_mapping')
                # sample_mapping = encoding.pop('overflow_to_sample_mapping')
outputs = self.model(**encoding)
# print('model outputs: ', outputs.keys())
start_logits = outputs.start_logits
end_logits = outputs.end_logits
predicted_start_idx = start_logits.argmax(-1).item()
predicted_end_idx = end_logits.argmax(-1).item()
predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
predicted_answer = self.processor.tokenizer.decode(predicted_answer_tokens)
                # print('answer: ', predicted_answer)
                answers.append(predicted_answer)
                # NOTE: these target indices are placeholder values; the extra forward
                # pass below only exercises the loss-computing path and does not change
                # the prediction above.
                target_start_index = torch.tensor([7])
                target_end_index = torch.tensor([14])
outputs = self.model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
# print(predicted_answer_span_start, predicted_answer_span_end)
logger.info(f'''
START
predicted_start_idx: {predicted_start_idx}
predicted_end_idx: {predicted_end_idx}
---
answer: {predicted_answer}
END''')
        return {'data': 'success', 'answers': answers}
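# A minimal local smoke test (assumption: the model files live in the current directory
# and "sample.pdf" is a placeholder path, not a real fixture shipped with this repo).
if __name__ == "__main__":
    handler = EndpointHandler(path="./")
    with open("sample.pdf", "rb") as f:
        payload = {"inputs": f.read()}
    pprint(handler(payload))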