# File size: 3,368 Bytes
# 2bbf92c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from io import BytesIO
import streamlit as st
import pandas as pd
import json
import os
import numpy as np
from model.flax_clip_vision_bert.modeling_clip_vision_bert import FlaxCLIPVisionBertForSequenceClassification
from utils import get_transformed_image, get_text_attributes, get_top_5_predictions, plotly_express_horizontal_bar_plot, translate_labels
import matplotlib.pyplot as plt
from mtranslate import translate
from PIL import Image


@st.cache(allow_output_mutation=True)
def load_model(ckpt):
    """Load the CLIP-Vision-BERT sequence classifier from checkpoint ``ckpt``.

    The call is memoised by ``st.cache`` keyed on ``ckpt``, so reruns of the
    script reuse the already-loaded model.

    ``allow_output_mutation=True`` stops Streamlit from hashing the returned
    model on every cache hit (its mutation check). For a large model object
    that check is costly, and it may fail on members Streamlit cannot hash.

    Args:
        ckpt: path (or identifier) accepted by ``from_pretrained``.

    Returns:
        The loaded ``FlaxCLIPVisionBertForSequenceClassification`` instance.
    """
    return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)

def softmax(logits, axis=0):
    """Convert ``logits`` into probabilities along ``axis``.

    The per-slice maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but keeps ``np.exp`` from overflowing to
    ``inf`` (which would turn the ratio into ``nan``) for large logits.

    Args:
        logits: array-like of scores.
        axis: axis to normalise over. The default of 0 matches the original
            behaviour for 1-D input.

    Returns:
        ``np.ndarray`` with the same shape as ``logits`` whose entries along
        ``axis`` sum to 1. Empty input is returned as an empty array.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # np.max would raise on an empty array; mirror the unshifted result.
        return np.exp(logits)
    # keepdims keeps the reduced axis so broadcasting works for any ``axis``.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

# Model checkpoint(s) offered by the demo; the app loads checkpoints[0].
checkpoints = ['./ckpt/ckpt-60k-5999']  # TODO: Maybe add more checkpoints?
# Example rows (image_file, question, answer_label, lang_id) used to seed the UI.
dummy_data = pd.read_csv('dummy_vqa_multilingual.tsv', sep='\t')
# Explicit UTF-8 so decoding does not depend on the platform's locale encoding.
# NOTE(review): presumably maps class indices to answer strings; it is passed
# to get_top_5_predictions — confirm against utils.
with open('answer_reverse_mapping.json', encoding='utf-8') as f:
    answer_reverse_mapping = json.load(f)

def _load_example(row):
    """Copy one ``dummy_data`` row into ``st.session_state`` and load its image.

    Sets ``image_file``, ``question`` and ``answer_label``. The example's
    ``lang_id`` is used as both the question language and the initial answer
    language. The image is read with ``plt.imread`` from
    ``images/<image_file>`` and stored in ``st.session_state.image``.
    """
    st.session_state.image_file = row['image_file']
    st.session_state.question = row['question']
    st.session_state.answer_label = row['answer_label']
    st.session_state.question_lang_id = row['lang_id']
    st.session_state.answer_lang_id = row['lang_id']

    image_path = os.path.join('images', st.session_state.image_file)
    st.session_state.image = plt.imread(image_path)


# Init Session State: seed with the first example only on the first run.
if 'image_file' not in st.session_state:
    _load_example(dummy_data.loc[0])

col1, col2 = st.beta_columns([5, 5])
if col1.button('Get a Random Example'):
    # .iloc[0] takes the single sampled row regardless of its original label.
    _load_example(dummy_data.sample(1).iloc[0])


uploaded_file = col2.file_uploader('Upload your image', type=['png','jpg','jpeg'])
if uploaded_file is not None:
    # Only the name is recorded; the upload itself is not written to disk here.
    st.session_state.image_file = os.path.join('images/val2014', uploaded_file.name)
    # Normalise to 3-channel RGB. PNG uploads can be RGBA or palette-based, and
    # greyscale images would otherwise decode to a 2-D array.
    # NOTE(review): assumes get_transformed_image expects an (H, W, 3) image —
    # confirm against utils.
    st.session_state.image = np.array(Image.open(uploaded_file).convert('RGB'))
    

# Preprocess whichever image (example or upload) is currently selected.
transformed_image = get_transformed_image(st.session_state.image)

# Display Image
st.image(st.session_state.image, use_column_width='always')

# Display Question; show an English rendering when the example is not English.
question = st.text_input(label="Question", value=st.session_state.question)
if st.session_state.question_lang_id == "en":
    english_question = question
else:
    english_question = translate(question, 'en')
st.markdown(f"""**English Translation**: {english_question}""")
question_inputs = get_text_attributes(question)

# Select Language: preselect the current answer language, store the choice.
options = ['en', 'de', 'es', 'fr']
selected_index = options.index(st.session_state.answer_lang_id)
st.session_state.answer_lang_id = st.selectbox('Answer Language', index=selected_index, options=options)

# Display Top-5 Predictions
with st.spinner('Loading model...'):
    model = load_model(checkpoints[0])
with st.spinner('Predicting...'):
    predictions = model(pixel_values=transformed_image, **question_inputs)
# First model output, first batch item.
# NOTE(review): presumably the class logits for this single image/question pair.
logits = np.array(predictions[0][0])
probs = softmax(logits)
labels, values = get_top_5_predictions(probs, answer_reverse_mapping)
translated_labels = translate_labels(labels, st.session_state.answer_lang_id)
st.plotly_chart(plotly_express_horizontal_bar_plot(values, translated_labels))