mt0-xxl / handler.py
AIDSC's picture
Create handler.py
45b837c verified
raw
history blame
1.81 kB
from typing import Dict, List, Any
import torch
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
def softmax(x, axis=-1):
    """Return the softmax of *x* computed along *axis*.

    Args:
        x: Array-like of scores (NumPy array, list, or nested list).
        axis: Axis along which the probabilities are normalised. The
            default (last axis) matches the previous behaviour for 1-D
            input.

    Returns:
        A NumPy array with the same shape as *x* whose entries along
        *axis* are non-negative and sum to 1.
    """
    scores = np.asarray(x)
    # Subtract the maximum before exponentiating so large scores do not
    # overflow in np.exp; the shift cancels out in the ratio below.
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    exponentials = np.exp(shifted)
    return exponentials / np.sum(exponentials, axis=axis, keepdims=True)
class EndpointHandler():
    """Endpoint handler that scores the answer letters A, B, C and D.

    For a given prompt the handler runs one decoder step of a
    sequence-to-sequence model and compares the logits of the four
    answer-letter tokens.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer found at *path*.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                model and the tokenizer.
        """
        self.accelerator = Accelerator()
        self.device = self.accelerator.device
        # NOTE(review): trust_remote_code=True allows Python code shipped
        # with the checkpoint to run; only use with trusted model repos.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            path, trust_remote_code=True, device_map="auto"
        )
        self.model = self.accelerator.prepare(self.model)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # First token id produced for each letter; assumes every letter
        # encodes to a single leading piece -- TODO confirm for the
        # tokenizer in use.
        self.options_tokens = [
            self.tokenizer.encode(choice)[0] for choice in ["A", "B", "C", "D"]
        ]
        # Token the decoder is primed with. Prefer the model's configured
        # start token; fall back to the tokenizer's pad id (presumably the
        # same value for T5-style models -- verify for other checkpoints).
        config = getattr(self.model, "config", None)
        start_id = getattr(config, "decoder_start_token_id", None)
        if start_id is None:
            start_id = self.tokenizer.pad_token_id
        self.decoder_start_token_id = start_id

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score the four answer letters for one prompt.

        Args:
            data: Request payload. The prompt text is read from the
                ``"prompt"`` key, or from ``"inputs"`` when ``"prompt"``
                is absent. The dict is not modified.

        Returns:
            A one-element list holding a dict with:
            ``"pred"`` -- index (0-3) of the highest-scoring letter in
            the order A, B, C, D, as a plain ``int``;
            ``"conf"`` -- list of four floats, the softmax over the four
            letter logits.

        Raises:
            KeyError: If the payload contains neither ``"prompt"`` nor
                ``"inputs"``.
        """
        prompt = data.get("prompt", data.get("inputs"))
        if prompt is None:
            raise KeyError("prompt")

        with torch.no_grad():
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            batch_size = inputs["input_ids"].size(0)
            # Feed exactly one start token per sequence. Tokenizing the
            # literal "<pad>" string may append an end-of-sequence token,
            # which would make the last-position logits depend on it
            # instead of being the first decoding step.
            decoder_input_ids = torch.full(
                (batch_size, 1),
                self.decoder_start_token_id,
                dtype=torch.long,
                device=self.device,
            )
            outputs = self.model(**inputs, decoder_input_ids=decoder_input_ids)
            first_step_logits = outputs.logits[:, -1, :]

        # .float() keeps the NumPy conversion working when the logits are
        # in a reduced-precision dtype.
        option_logits = (
            first_step_logits[:, self.options_tokens].detach().float().cpu().numpy()
        )
        # Only the first sequence of the batch is reported.
        conf = softmax(option_logits[0])
        pred = np.argmax(option_logits[0])
        # Convert NumPy values to built-in types so the response can be
        # serialized by any JSON encoder.
        return [{"pred": int(pred), "conf": np.asarray(conf).tolist()}]