Alexander Slessor commited on
Commit
c0a3632
·
1 Parent(s): 26bcc6f

added handler

Browse files
Files changed (3) hide show
  1. .gitignore +10 -0
  2. handler.py +141 -0
  3. invoice_example.png +0 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ *.ipynb
3
+ *.pdf
4
+
5
+ test_endpoint.py
6
+ test_handler_local.py
7
+
8
+ setup
9
+ upload_to_hf
10
+ requirements.txt
handler.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import BertForQuestionAnswering, BertTokenizer
3
+ import torch
4
+
5
+ # set device
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ # def print_tokens_with_ids(tokenizer, input_ids):
9
+ # # BERT only needs the token IDs, but for the purpose of inspecting the
10
+ # # tokenizer's behavior, let's also get the token strings and display them.
11
+ # tokens = tokenizer.convert_ids_to_tokens(input_ids)
12
+ # # For each token and its id...
13
+ # for token, id in zip(tokens, input_ids):
14
+ # # If this is the [SEP] token, add some space around it to make it stand out.
15
+ # if id == tokenizer.sep_token_id:
16
+ # print('')
17
+ # # Print the token string and its ID in two columns.
18
+ # print('{:<12} {:>6,}'.format(token, id))
19
+ # if id == tokenizer.sep_token_id:
20
+ # print('')
21
+
22
+ def get_segment_ids_aka_token_type_ids(tokenizer, input_ids):
23
+ # Search the input_ids for the first instance of the `[SEP]` token.
24
+ sep_index = input_ids.index(tokenizer.sep_token_id)
25
+ # The number of segment A tokens includes the [SEP] token istelf.
26
+ num_seg_a = sep_index + 1
27
+ # The remainder are segment B.
28
+ num_seg_b = len(input_ids) - num_seg_a
29
+ # Construct the list of 0s and 1s.
30
+ segment_ids = [0]*num_seg_a + [1]*num_seg_b
31
+ # There should be a segment_id for every input token.
32
+ assert len(segment_ids) == len(input_ids), \
33
+ 'There should be a segment_id for every input token.'
34
+ return segment_ids
35
+
36
+ def to_model(
37
+ model: BertForQuestionAnswering,
38
+ input_ids,
39
+ segment_ids
40
+ ) -> tuple:
41
+ # Run input through the model.
42
+ output = model(
43
+ torch.tensor([input_ids]), # The tokens representing our input text.
44
+ token_type_ids=torch.tensor([segment_ids])
45
+ )
46
+ # print(output)
47
+ # print(output.start_logits)
48
+ # print(output.end_logits)
49
+ # print(type(output))
50
+ # The segment IDs to differentiate question from answer_text
51
+ return output.start_logits, output.end_logits
52
+ #output.hidden_states
53
+ #output.attentions
54
+ #output.loss
55
+
56
+ def get_answer(
57
+ start_scores,
58
+ end_scores,
59
+ input_ids,
60
+ tokenizer: BertTokenizer
61
+ ) -> str:
62
+ '''Side Note:
63
+ - It’s a little naive to pick the highest scores for start and end–what if it predicts an end word that’s before the start word?!
64
+ - The correct implementation is to pick the highest total score for which end >= start.
65
+ '''
66
+ # Find the tokens with the highest `start` and `end` scores.
67
+ answer_start = torch.argmax(start_scores)
68
+ answer_end = torch.argmax(end_scores)
69
+
70
+ # Combine the tokens in the answer and print it out.
71
+ # answer = ' '.join(tokens[answer_start:answer_end + 1])
72
+ # Get the string versions of the input tokens.
73
+ tokens = tokenizer.convert_ids_to_tokens(input_ids)
74
+ # Start with the first token.
75
+ answer = tokens[answer_start]
76
+ # print('Answer: "' + answer + '"')
77
+ # Select the remaining answer tokens and join them with whitespace.
78
+ for i in range(answer_start + 1, answer_end + 1):
79
+ # If it's a subword token, then recombine it with the previous token.
80
+ if tokens[i][0:2] == '##':
81
+ answer += tokens[i][2:]
82
+ # Otherwise, add a space then the token.
83
+ else:
84
+ answer += ' ' + tokens[i]
85
+ return answer
86
+
87
+
88
+ # def resonstruct_words(tokens, answer_start, answer_end):
89
+ # '''reconstruct any words that got broken down into subwords.
90
+ # '''
91
+ # # Start with the first token.
92
+ # answer = tokens[answer_start]
93
+ # # Select the remaining answer tokens and join them with whitespace.
94
+ # for i in range(answer_start + 1, answer_end + 1):
95
+ # # If it's a subword token, then recombine it with the previous token.
96
+ # if tokens[i][0:2] == '##':
97
+ # answer += tokens[i][2:]
98
+ # # Otherwise, add a space then the token.
99
+ # else:
100
+ # answer += ' ' + tokens[i]
101
+ # print('Answer: "' + answer + '"')
102
+
103
+
104
+ class EndpointHandler:
105
+ def __init__(self, path=""):
106
+ self.model = BertForQuestionAnswering.from_pretrained(path).to(device)
107
+ self.tokenizer = BertTokenizer.from_pretrained(path)
108
+
109
+ def __call__(
110
+ self,
111
+ data: Dict[str, str | bytes]
112
+ ):
113
+ """
114
+ Args:
115
+ data (:obj:):
116
+ includes the deserialized image file as PIL.Image
117
+ """
118
+ question = data.pop("question", data)
119
+ context = data.pop("context", data)
120
+
121
+ input_ids = self.tokenizer.encode(question, context)
122
+ # print('The input has a total of {:} tokens.'.format(len(input_ids)))
123
+
124
+ segment_ids = get_segment_ids_aka_token_type_ids(
125
+ self.tokenizer,
126
+ input_ids
127
+ )
128
+ # run prediction
129
+ with torch.inference_mode():
130
+ start_scores, end_scores = to_model(
131
+ self.model,
132
+ input_ids,
133
+ segment_ids
134
+ )
135
+ answer = get_answer(
136
+ start_scores,
137
+ end_scores,
138
+ input_ids,
139
+ self.tokenizer
140
+ )
141
+ return answer
invoice_example.png ADDED