VQArt / src /models /visual_qa.py
nouman-10's picture
Upload 35 files
ccba2d5
from transformers import pipeline
class VisualQA:
    """Answer natural-language questions about images via a Hugging Face pipeline.

    The ``transformers`` pipeline (task ``'vqa'``) is built eagerly when the
    instance is constructed and stored on ``self.model``.
    """

    def __init__(self, model_name='nflechas/VQArt', tokenizer_name='dandelin/vilt-b32-finetuned-vqa'):
        """Remember the checkpoint identifiers and build the pipeline.

        Args:
            model_name: Checkpoint forwarded to ``pipeline(model=...)``.
            tokenizer_name: Checkpoint forwarded to ``pipeline(tokenizer=...)``.
                Kept separate from ``model_name``, presumably because the
                fine-tuned model repo does not ship its own tokenizer — TODO confirm.
        """
        self.model_name, self.tokenizer_name = model_name, tokenizer_name
        self.__load_model()

    def __load_model(self):
        """Instantiate the VQA pipeline from the stored checkpoint names."""
        options = {'model': self.model_name, 'tokenizer': self.tokenizer_name}
        self.model = pipeline('vqa', **options)

    def answer_question(self, query, image):
        """Return the single top answer the pipeline produces for ``query``.

        Args:
            query: The question text.
            image: Anything the transformers VQA pipeline accepts as an image
                (per its docs: a PIL image, a local path, or a URL) — not
                validated here.

        Returns:
            The ``'answer'`` field of the first (and, with ``top_k=1``, only)
            prediction returned by the pipeline.
        """
        predictions = self.model(question=query, image=image, top_k=1)
        best = predictions[0]
        return best['answer']