from typing import Any, Dict

import torch
from transformers import ViltProcessor, ViltForQuestionAnswering
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for ViLT visual question
    answering: takes an image plus a question and returns the predicted
    answer label."""

    def __init__(self, path: str = ""):
        """Load the ViLT processor and VQA model.

        Args:
            path: local directory or hub id containing the model weights.
        """
        # Fix: original used AutoTokenizer (unimported, and text-only);
        # ViLT requires its multimodal processor to encode image + text.
        self.processor = ViltProcessor.from_pretrained(path)
        self.model = ViltForQuestionAnswering.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Fix: the model was never moved to the selected device, and was
        # left in training mode for inference.
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> list[dict[str, str]]:
        """Run visual question answering on a request payload.

        Args:
            data: request dict with keys "image" (PIL image or compatible
                input for the processor — verify against caller) and "text"
                (the question). "parameters" is accepted but currently unused.

        Returns:
            A single-element list: ``[{"answer": <predicted label>}]``.

        Raises:
            ValueError: if "image" or "text" is missing from the payload.
        """
        # Fix: popping with `data` itself as the default silently forwarded
        # the whole payload as the image/text; fail fast with a clear error.
        image = data.pop("image", None)
        text = data.pop("text", None)
        if image is None or text is None:
            raise ValueError("Request payload must contain 'image' and 'text'.")
        data.pop("parameters", None)  # accepted for API shape, not used yet

        # Fix: `processor`/`model` were referenced without `self.` (NameError);
        # also move the encoding to the same device as the model.
        encoding = self.processor(image, text, return_tensors="pt").to(self.device)
        with torch.no_grad():  # pure inference — skip autograd bookkeeping
            outputs = self.model(**encoding)

        # Postprocess: highest-logit class index -> human-readable label.
        idx = outputs.logits.argmax(-1).item()
        return [{"answer": self.model.config.id2label[idx]}]