from transformers import pipeline
class VisualQA:
    """Answer natural-language questions about an image.

    Wraps a Hugging Face ``transformers`` ``'vqa'`` pipeline. The pipeline is
    built once, when the instance is created, and reused for every call to
    :meth:`answer_question`.
    """

    def __init__(
        self,
        model_name='nflechas/VQArt',
        tokenizer_name='dandelin/vilt-b32-finetuned-vqa',
    ):
        """Store the model/tokenizer identifiers and build the pipeline.

        Args:
            model_name: Identifier passed to ``pipeline`` as ``model``.
            tokenizer_name: Identifier passed to ``pipeline`` as ``tokenizer``.
        """
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        # Eager load: construction pays the model-loading cost up front so
        # that answer_question() never has to.
        self.__load_model()

    def __load_model(self):
        """Create the ``'vqa'`` pipeline and keep it on ``self.model``."""
        self.model = pipeline(
            'vqa',
            model=self.model_name,
            tokenizer=self.tokenizer_name,
        )

    def answer_question(self, query, image):
        """Return the top-ranked answer to ``query`` about ``image``.

        Args:
            query: The question, forwarded to the pipeline as ``question``.
            image: The image, forwarded unchanged to the pipeline as
                ``image``. NOTE(review): accepted types (PIL image, path,
                URL, ...) are whatever the underlying pipeline supports --
                confirm against the transformers documentation.

        Returns:
            The ``'answer'`` field of the first result returned by the
            pipeline when called with ``top_k=1``.
        """
        results = self.model(question=query, image=image, top_k=1)
        return results[0]['answer']