# Streamlit demo script: pick a sample from vqa_samples.json, show its image,
# let the user ask a question about it, and display the selected model's
# answer next to the sample's stored label.
import sys

# Extend the import path with the current directory
# (presumably so that the local `model_loader` module resolves -- confirm).
sys.path.append(".")

import streamlit as st
import pandas as pd
from PIL import Image
from model_loader import *
from datasets import load_dataset

# load dataset
#ds = load_dataset("test")
# ds = load_dataset("HuggingFaceM4/VQAv2", split="validation", cache_dir="cache", streaming=False)
df = pd.read_json('vqa_samples.json', orient="columns")
if df.empty:
    # Without rows there is no valid image id (the selector range would be
    # 0..-1), so report the problem and end this script run.
    st.error("vqa_samples.json contains no samples.")
    st.stop()

# define selector
model_name = st.sidebar.selectbox(
    "Select a model: ",
    ('vilt', 'git', 'blip', 'vbert')
)
# number_input bounds are inclusive, so the largest valid positional index is
# len(df) - 1; allowing len(df) would make df.iloc raise IndexError.
image_selector_unspecific = st.number_input(
    "Select an image id: ",
    min_value=0,
    max_value=len(df) - 1,
)

# select and display
#sample = ds[image_selector_unspecific]
sample = df.iloc[image_selector_unspecific]
img_path = sample['img_path']
image = Image.open(f'images/{img_path}.jpg')
st.image(image, channels="RGB")
# NOTE(review): the original read 'ques' here but 'question' in the prompt
# below; 'ques' is assumed to be the real column name -- confirm against
# vqa_samples.json.
example_question = sample['ques']
label = sample['label']

# inference
question = st.text_input(f"Ask the model a question related to the image: \n"
                         f"(e.g. \"{example_question}\")")
if question:
    # Only load and run the model once the user has actually typed a question;
    # the text box is empty on the first render.
    args = load_model(model_name)  # TODO: cache
    answer = get_answer(args, image, question, model_name)
    st.text(f"Answer by {model_name}: {answer}")
st.text(f"Ground truth: {label}")