File size: 5,795 Bytes
b12986e c0a3632 b12986e c0a3632 a77ffa1 c0a3632 a77ffa1 c0a3632 edeeba1 c0a3632 edeeba1 c0a3632 b12986e edeeba1 c0a3632 b12986e c0a3632 b12986e c0a3632 b12986e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
from typing import Dict, Any
from transformers import BertForQuestionAnswering, BertTokenizer
import torch
# from scipy.special import softmax
# set device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# def print_tokens_with_ids(tokenizer, input_ids):
# # BERT only needs the token IDs, but for the purpose of inspecting the
# # tokenizer's behavior, let's also get the token strings and display them.
# tokens = tokenizer.convert_ids_to_tokens(input_ids)
# # For each token and its id...
# for token, id in zip(tokens, input_ids):
# # If this is the [SEP] token, add some space around it to make it stand out.
# if id == tokenizer.sep_token_id:
# print('')
# # Print the token string and its ID in two columns.
# print('{:<12} {:>6,}'.format(token, id))
# if id == tokenizer.sep_token_id:
# print('')
def get_segment_ids_aka_token_type_ids(tokenizer, input_ids):
    """Build BERT segment (token-type) IDs for an encoded question/context pair.

    Every token up to and including the first ``[SEP]`` belongs to segment A
    (id 0); every token after it belongs to segment B (id 1).

    Args:
        tokenizer: Object exposing ``sep_token_id`` (such as the
            ``BertTokenizer`` used elsewhere in this module).
        input_ids: List of token IDs, e.g. as returned by ``tokenizer.encode``.

    Returns:
        A list of 0s and 1s with exactly one entry per token in ``input_ids``.

    Raises:
        ValueError: If ``input_ids`` contains no ``[SEP]`` token.
    """
    try:
        # Position of the first [SEP], which closes segment A (the question).
        sep_index = input_ids.index(tokenizer.sep_token_id)
    except ValueError:
        raise ValueError(
            'input_ids contains no [SEP] token '
            f'(sep_token_id={tokenizer.sep_token_id})') from None
    # The number of segment A tokens includes the [SEP] token itself.
    num_seg_a = sep_index + 1
    # The remainder are segment B; by construction the result has one
    # segment id per input token, so no length check is needed.
    num_seg_b = len(input_ids) - num_seg_a
    return [0] * num_seg_a + [1] * num_seg_b
def to_model(
    model: "BertForQuestionAnswering",
    input_ids,
    segment_ids=None,
) -> tuple:
    """Run one encoded question/context pair through a QA model.

    Args:
        model: A ``BertForQuestionAnswering`` (or any callable with the same
            call convention whose output has ``start_logits`` and
            ``end_logits`` attributes).
        input_ids: Token IDs for a single, un-batched example.
        segment_ids: Optional segment (token-type) IDs, one per token in
            ``input_ids``, distinguishing the question from the context.
            When ``None``, no ``token_type_ids`` argument is passed.

    Returns:
        ``(start_logits, end_logits)`` taken from the model output.
    """
    model_kwargs = {}
    if segment_ids is not None:
        # The segment IDs differentiate the question from the answer text;
        # wrapped in a list to add a batch dimension of size 1.
        model_kwargs['token_type_ids'] = torch.tensor([segment_ids])
    # The tokens representing our input text, as a batch of one.
    output = model(torch.tensor([input_ids]), **model_kwargs)
    return output.start_logits, output.end_logits
def get_answer(
    start_scores,
    end_scores,
    input_ids,
    tokenizer: "BertTokenizer",
) -> str:
    """Decode the highest-scoring answer span back into text.

    Start and end positions are first chosen independently by argmax over the
    (flattened) logits. If that yields an end position before the start
    position, the span with the highest combined ``start + end`` score among
    all spans with ``end >= start`` is used instead.

    Args:
        start_scores: Start-position logits for one example (assumed batch
            size 1, e.g. the ``start_logits`` returned by ``to_model``).
        end_scores: End-position logits; assumed same length as
            ``start_scores``.
        input_ids: Token IDs that were fed to the model.
        tokenizer: Tokenizer providing ``convert_ids_to_tokens``.

    Returns:
        The answer string, with WordPiece ``##`` continuation tokens merged
        back into the preceding token and other tokens separated by spaces.
    """
    # Flatten so indices refer to token positions (torch.argmax without a
    # dim also flattens, so this matches the naive argmax choice).
    start = start_scores.reshape(-1)
    end = end_scores.reshape(-1)
    answer_start = int(torch.argmax(start))
    answer_end = int(torch.argmax(end))

    if answer_end < answer_start:
        # Independent argmaxes produced an inverted span: score every
        # (start, end) pair, mask out pairs with end < start, take the best.
        num_tokens = end.numel()
        positions = torch.arange(num_tokens)
        valid = positions[None, :] >= positions[:, None]  # rows: start, cols: end
        pair_scores = (start[:, None] + end[None, :]).masked_fill(
            ~valid, float('-inf'))
        answer_start, answer_end = divmod(
            int(torch.argmax(pair_scores)), num_tokens)

    # Get the string versions of the input tokens.
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    answer = tokens[answer_start]
    for token in tokens[answer_start + 1:answer_end + 1]:
        if token.startswith('##'):
            # Subword token: recombine it with the previous token.
            answer += token[2:]
        else:
            # Otherwise, add a space then the token.
            answer += ' ' + token
    return answer
# def reconstruct_words(tokens, answer_start, answer_end):
# '''reconstruct any words that got broken down into subwords.
# '''
# # Start with the first token.
# answer = tokens[answer_start]
# # Select the remaining answer tokens and join them with whitespace.
# for i in range(answer_start + 1, answer_end + 1):
# # If it's a subword token, then recombine it with the previous token.
# if tokens[i][0:2] == '##':
# answer += tokens[i][2:]
# # Otherwise, add a space then the token.
# else:
# answer += ' ' + tokens[i]
# print('Answer: "' + answer + '"')
class EndpointHandler:
def __init__(self, path=""):
    """Load the question-answering model and its tokenizer.

    Args:
        path: Location handed to both ``from_pretrained`` calls, so the
            model weights and the tokenizer come from the same source.
    """
    self.model = BertForQuestionAnswering.from_pretrained(
        pretrained_model_name_or_path=path)
    self.tokenizer = BertTokenizer.from_pretrained(
        pretrained_model_name_or_path=path)
# def __call__(self, data: Dict[str, Any]):
# def __call__(self, data: dict[str, Any]) -> dict[str, list[Any]]:
def __call__(self, data: dict[str, Any]):
data (:obj:):
includes the context and question
if 'inputs' not in data:
raise ValueError('no inputs key in data')
i = data.pop("inputs", data)
question = i.pop("question", False)
context = i.pop("context", False)
if question is False and context is False:
raise ValueError(
f'No question and/or context: question: {question} - context: {context}')
input_ids = self.tokenizer.encode(question, context)
# print('The input has a total of {:} tokens.'.format(len(input_ids)))
segment_ids = get_segment_ids_aka_token_type_ids(
# run prediction
with torch.inference_mode():
start_scores, end_scores = to_model(
answer = get_answer(
return answer
except Exception as e: