import logging
from dataclasses import dataclass
from os import environ
from subprocess import run
from typing import Optional

import numpy as np
import pdf2image
import torch
from transformers import (
    LayoutLMv2FeatureExtractor,
    LayoutLMv2ForQuestionAnswering,
    LayoutLMv2ImageProcessor,
    LayoutLMv2Processor,
    LayoutLMv2TokenizerFast,
)
from transformers.modeling_outputs import (
    QuestionAnsweringModelOutput as QuestionAnsweringModelOutputBase,
)
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
from transformers.utils import TensorType

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Requires the tesseract-ocr binary and the pytesseract package, e.g.:
# run("apt install -y tesseract-ocr", shell=True, check=True)

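# LayoutLMv2FeatureExtractor defaults to apply_ocr=True, so calling it runs
# pytesseract over each page image and returns the recognized words plus
# bounding boxes normalized to a 0-1000 scale.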
feature_extractor = LayoutLMv2FeatureExtractor()

# @dataclass
# class QuestionAnsweringModelOutput(QuestionAnsweringModelOutputBase):
#     token_logits: Optional[torch.FloatTensor] = None

class NoOCRReaderFound(Exception):
    def __init__(self, e):
        self.e = e

    def __str__(self):
        return f"Could not load OCR Reader: {self.e}"

def pdf_to_image(b: bytes):
    # Convert each PDF page to an RGB image, then let the feature extractor
    # run OCR over it.
    # TODO: pdf2image requires poppler, which is not present everywhere.
    # We should look into alternatives, or gracefully fall back to _only_
    # extracted text when it is missing.
    images = [page.convert("RGB") for page in pdf2image.convert_from_bytes(b)]
    encoded_inputs = feature_extractor(images)
    logger.debug('feature_extractor keys: %s', list(encoded_inputs.keys()))
    return {
        'image': encoded_inputs.pixel_values,
        'words': encoded_inputs.words,
        'boxes': encoded_inputs.boxes,
    }


def setup_logger(which_logger: Optional[str] = None):
    lib_level = logging.DEBUG  # default level for this module's logger
    root_level = logging.INFO
    log_format = '%(asctime)s - %(process)d - %(levelname)s - %(funcName)s - %(message)s'
    logging.basicConfig(
        # Taken from a local .env file, not set in settings.py.
        filename=environ.get('LOG_FILE_PATH_LAYOUTLM_V2'),
        format=log_format,
        datefmt='%d-%b-%y %H:%M:%S',
        level=root_level,
        force=True
    )
    log = logging.getLogger(which_logger)
    log.setLevel(lib_level)
    return log

logger = setup_logger(__name__)


class Funcs:
    # Helper to map a bbox from LayoutLMv2's 0-1000 normalized scale back to
    # pixel coordinates for drawing onto the image.
    @staticmethod
    def unnormalize_box(bbox, width, height):
        return [
            width * (bbox[0] / 1000),
            height * (bbox[1] / 1000),
            width * (bbox[2] / 1000),
            height * (bbox[3] / 1000),
        ]

    @staticmethod
    def num_spans(encoding: BatchEncoding) -> int:
        return len(encoding["input_ids"])

    @staticmethod
    def p_mask(num_spans: int, encoding: BatchEncoding) -> list:
        # Mask out every token that is not part of the document context
        # (sequence id 1); those positions cannot contain the answer.
        return [
            [tok != 1 for tok in encoding.sequence_ids(span_id)]
            for span_id in range(num_spans)
        ]

    @staticmethod
    def token_start_end(encoding, tokenizer):
        sequence_ids = encoding.sequence_ids()

        # Start token index of the current span in the text.
        token_start_index = 0
        while sequence_ids[token_start_index] != 1:
            token_start_index += 1

        # End token index of the current span in the text.
        token_end_index = len(encoding.input_ids) - 1
        while sequence_ids[token_end_index] != 1:
            token_end_index -= 1

        print("Token start index:", token_start_index)
        print("Token end index:", token_end_index)
        print('token_start_end: ', tokenizer.decode(encoding.input_ids[token_start_index:token_end_index+1]))
        return token_start_index, token_end_index

    @staticmethod
    def reconstruct_answer(word_idx_start, word_idx_end, token_start_index,
                           token_end_index, encoding, tokenizer):
        # token_start_index/token_end_index are the span bounds for the
        # document context, e.g. as returned by token_start_end().
        # Map word-level answer indices to token positions: scan the span's
        # word ids from the front for the start, from the back for the end.
        word_ids = encoding.word_ids()[token_start_index:token_end_index + 1]
        logger.debug('Word ids: %s', word_ids)

        for id in word_ids:
            if id == word_idx_start:
                start_position = token_start_index
                break
            token_start_index += 1

        for id in word_ids[::-1]:
            if id == word_idx_end:
                end_position = token_end_index
                break
            token_end_index -= 1

        logger.debug(
            'Reconstructed answer: %s',
            tokenizer.decode(encoding.input_ids[start_position:end_position + 1]),
        )
        return start_position, end_position

    @staticmethod
    def sigmoid(_outputs):
        return 1.0 / (1.0 + np.exp(-_outputs))

    @staticmethod
    def softmax(_outputs):
        # Subtract the per-row max before exponentiating for numerical stability.
        maxes = np.max(_outputs, axis=-1, keepdims=True)
        shifted_exp = np.exp(_outputs - maxes)
        return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)

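# Custom handler for Hugging Face Inference Endpoints: the service imports
# EndpointHandler from handler.py, constructs it with the local model path,
# and calls it once per request with the deserialized payload.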
class EndpointHandler:
    def __init__(self, path="./"):
        # self.model = LayoutLMv2ForQuestionAnswering.from_pretrained(path).to(device)
        self.model = LayoutLMv2ForQuestionAnswering.from_pretrained(path)
        self.tokenizer = LayoutLMv2TokenizerFast.from_pretrained(path)
        # self.image_processor = LayoutLMv2ImageProcessor()  # apply_ocr is True by default
        self.processor = LayoutLMv2Processor.from_pretrained(
            path,
            # image_processor=self.image_processor,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: dict[str, bytes]):
        """
        Args:
            data (:obj:`dict`):
                Request payload; the "inputs" key holds the raw PDF as bytes.
        """
        pdf_bytes = data.pop("inputs", data)

        # image = pdf_to_image(pdf_bytes)
        images = [page.convert("RGB") for page in pdf2image.convert_from_bytes(pdf_bytes)]
        # TODO: The question is hard-coded for now; it should come from the
        # request payload.
        question = "what is the bill date"
        answers = []  # one predicted answer per page
        with torch.no_grad():
            for image in images:
                # max_seq_len = min(self.tokenizer.model_max_length, 512)
                # doc_stride = min(max_seq_len // 2, 256)
                encoding = self.processor(
                    image, 
                    question, 
                    # max_length=max_seq_len,
                    # stride=doc_stride,
                    truncation=True,
                    # truncation=TruncationStrategy.ONLY_SECOND,
                    # return_offsets_mapping=True,
                    # return_token_type_ids=True,
                    # return_overflowing_tokens=True,
                    return_tensors=TensorType.PYTORCH
                )
                logger.debug('encoding keys: %s', list(encoding.keys()))

                # for k, v in encoding.items():
                #     encoding[k] = v.to(self.model.device)

                # num_spans = Funcs.num_spans(encoding)
                # p_mask = Funcs.p_mask(num_spans, encoding)
                # offset_mapping = encoding.pop('offset_mapping')
                # sample_mapping = encoding.pop('overflow_to_sample_mapping')

                outputs = self.model(**encoding)
                # print('model outputs: ', outputs.keys())
                start_logits = outputs.start_logits
                end_logits = outputs.end_logits

                predicted_start_idx = start_logits.argmax(-1).item()
                predicted_end_idx = end_logits.argmax(-1).item()

                predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
                predicted_answer = self.processor.tokenizer.decode(predicted_answer_tokens)
                # print('answer: ', predicted_answer)
                # Second forward pass with hard-coded target positions;
                # passing start/end positions makes the model also return a
                # loss for those targets.
                target_start_index = torch.tensor([7])
                target_end_index = torch.tensor([14])
                outputs = self.model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
                predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
                predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
                # print(predicted_answer_span_start, predicted_answer_span_end)
                logger.info(f'''
    START
    predicted_start_idx: {predicted_start_idx}
    predicted_end_idx: {predicted_end_idx}
    ---
    answer: {predicted_answer}

    END''')
                answers.append(predicted_answer)
        return {'answers': answers}
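

# Minimal local smoke test (a sketch): assumes a LayoutLMv2 QA checkpoint in
# the current working directory and a "sample.pdf" next to it -- both are
# placeholder names, not part of the endpoint contract.
if __name__ == "__main__":
    handler = EndpointHandler(path="./")
    with open("sample.pdf", "rb") as f:
        result = handler({"inputs": f.read()})
    print(result)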