File size: 896 Bytes
253f1f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from typing import Any, Dict, List

import torch
from transformers import ViltProcessor, ViltForQuestionAnswering


class EndpointHandler:
    """Custom inference handler serving a ViLT visual-question-answering model.

    The processor and model are loaded once at construction time. Each call
    runs a single forward pass and returns the highest-scoring answer label.
    """

    def __init__(self, path=""):
        """Load the ViLT processor and VQA model from ``path``.

        Args:
            path: Directory (or model id) handed to ``from_pretrained``.
        """
        # ViLT needs the combined processor (image preprocessing + tokenizer);
        # a bare tokenizer cannot produce the image inputs the model expects.
        self.processor = ViltProcessor.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Move the weights once so inference actually runs on the chosen device.
        self.model = ViltForQuestionAnswering.from_pretrained(path).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Answer a question about an image.

        Args:
            data: Request payload with an ``"image"`` entry (whatever the ViLT
                processor accepts -- presumably a PIL image; TODO confirm what
                the endpoint runtime delivers) and a ``"text"`` question.
                Any other keys (e.g. ``"parameters"``) are ignored.

        Returns:
            A one-element list ``[{"answer": label}]`` holding the label with
            the highest logit.

        Raises:
            ValueError: If ``"image"`` or ``"text"`` is missing from ``data``.
        """
        # Read without popping so the caller's payload is not mutated.
        image = data.get("image")
        text = data.get("text")
        if image is None or text is None:
            raise ValueError("Request must contain both 'image' and 'text' fields.")

        # Encoded tensors must live on the same device as the model weights.
        encoding = self.processor(image, text, return_tensors="pt").to(self.device)
        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**encoding)

        # Single example, so argmax over the label axis yields one index.
        idx = outputs.logits.argmax(-1).item()
        return [{"answer": self.model.config.id2label[idx]}]