# handler.py: custom Hugging Face Inference Endpoints handler for
# ddemszky/uptake-model, which scores conversational "uptake" (how much one
# speaker's utterance builds on the previous speaker's utterance).
from typing import Any, Dict

import torch
from scipy.special import softmax
from transformers import BertTokenizer

from utils import (BertInputBuilder, MultiHeadModel, clean_str,
                   clean_str_nopunct, get_num_words)

# Hub checkpoint name; the weights themselves are loaded from ``path`` below.
MODEL_CHECKPOINT = "ddemszky/uptake-model"


class EndpointHandler:
    def __init__(self, path="."):
        print("Loading models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)
        # BERT with a single next-sentence-prediction style head ("nsp") that
        # emits two logits: no uptake vs. uptake.
        self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
        self.model.to(self.device)
        self.max_length = 120

    def get_clean_text(self, text, remove_punct=False):
        if remove_punct:
            return clean_str_nopunct(text)
        return clean_str(text)
    def get_prediction(self, instance):
        # Attend to all tokens; the flat mask becomes shape (1, seq_len) after
        # the unsqueeze below, matching input_ids and token_type_ids.
        instance["attention_mask"] = [1] * len(instance["input_ids"])
        for key in ["input_ids", "token_type_ids", "attention_mask"]:
            # Batch size = 1; keep every tensor on the same device as the model.
            instance[key] = torch.tensor(instance[key]).unsqueeze(0).to(self.device)
        output = self.model(input_ids=instance["input_ids"],
                            attention_mask=instance["attention_mask"],
                            token_type_ids=instance["token_type_ids"],
                            return_pooler_output=False)
        return output
    def get_uptake_score(self, textA, textB):
        textA = self.get_clean_text(textA, remove_punct=False)
        textB = self.get_clean_text(textB, remove_punct=False)
        instance = self.input_builder.build_inputs([textA], textB,
                                                   max_length=self.max_length,
                                                   input_str=True)
        output = self.get_prediction(instance)
        # Softmax over the two NSP-style logits; index 1 is the probability
        # that textB takes up (builds on) textA. .tolist() also moves the
        # logits off the GPU before the SciPy softmax.
        uptake_score = softmax(output["nsp_logits"][0].tolist())[1]
        return uptake_score
    def __call__(self, data: Dict[str, Any]) -> Dict[str, float]:
        """
        data args:
            inputs (:obj:`list`): utterances; each is a dict with
                ``id``, ``speaker``, and ``text`` keys.
            parameters (:obj:`dict`): must contain ``speaker_2`` (the
                speaker whose utterances get scored) and
                ``speaker_1_min_num_words`` (the minimum length of the
                preceding utterance for scoring to happen).
        Return:
            A :obj:`dict` mapping utterance ids to uptake scores; it will
            be serialized and returned. See the ``__main__`` block below
            for an example payload.
        """
        # Get inputs: the utterance list under "inputs" (or the raw payload
        # itself) and the required scoring parameters.
        utterances = data.pop("inputs", data)
        params = data.pop("parameters", None)

        print("EXAMPLES")
        for utt in utterances[:3]:
            print("speaker %s: %s" % (utt["speaker"], utt["text"]))
        print("Running inference on %d examples..." % len(utterances))

        self.model.eval()
        prev_num_words = 0
        prev_text = ""
        uptake_scores = {}
        with torch.no_grad():
            for utt in utterances:
                # Score an utterance only if it comes from speaker_2 and the
                # preceding utterance was long enough to respond to.
                if utt["speaker"] == params["speaker_2"] and prev_num_words >= params["speaker_1_min_num_words"]:
                    uptake_scores[str(utt["id"])] = self.get_uptake_score(textA=prev_text, textB=utt["text"])
                prev_num_words = get_num_words(utt["text"])
                prev_text = utt["text"]
        return uptake_scores
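
# A minimal local smoke test: a sketch assuming this file sits next to the
# repo's utils.py and the model weights are available in the current
# directory. The speaker labels, utterance ids, and parameter values below
# are made-up examples, not part of the endpoint's contract.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    payload = {
        "inputs": [
            {"id": 1, "speaker": "student",
             "text": "I think the answer is twelve because three times four is twelve."},
            {"id": 2, "speaker": "teacher",
             "text": "So you multiplied three by four to get twelve?"},
        ],
        "parameters": {
            "speaker_2": "teacher",        # score the teacher's turns...
            "speaker_1_min_num_words": 5,  # ...after a >= 5-word student turn
        },
    }
    # Expect one score keyed by the teacher utterance's id, e.g. {"2": 0.87}
    # (the exact value depends on the loaded weights).
    print(handler(payload))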