Spaces:
Runtime error
Runtime error
File size: 4,372 Bytes
405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import functools
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from mtranslate import translate

from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForSequenceClassification,
)
from .utils import (
    get_text_attributes,
    get_top_5_predictions,
    get_transformed_image,
    plotly_express_horizontal_bar_plot,
    translate_labels,
)
def softmax(logits, axis=0):
    """Return the softmax of ``logits`` normalised along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so that
    large logits cannot overflow ``np.exp`` (which previously produced
    ``inf / inf = nan``). Softmax is shift-invariant, so the result is
    mathematically unchanged.

    Args:
        logits: array-like of raw scores.
        axis: axis to normalise over. Defaults to 0, matching the previous
            hard-coded behaviour.

    Returns:
        ``np.ndarray`` with the same shape as ``logits`` whose entries sum to
        1 along ``axis``.
    """
    logits = np.asarray(logits)
    # keepdims=True keeps the reduced axis so broadcasting works for any rank.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)
@functools.lru_cache(maxsize=4)
def _load_vqa_model(ckpt):
    """Load and return the VQA classifier for checkpoint ``ckpt``, once per process.

    Streamlit re-runs the page script on each widget interaction; without this
    cache the pretrained weights were re-loaded on every rerun (the original
    ``st.cache`` decorator had been commented out).
    """
    return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)


def _predict(model, transformed_image, question_inputs):
    """Run ``model`` on one image/question pair and return ``output[0][0]`` as an array.

    NOTE(review): ``[0][0]`` presumably selects the logits of the first
    (only) batch element -- confirm against the model's output structure.
    """
    return np.array(
        model(pixel_values=transformed_image, **question_inputs)[0][0]
    )


def _set_vqa_example(vqa_state, frame, index):
    """Copy row ``index`` of ``frame`` into ``vqa_state`` and load its image.

    ``frame`` must have ``image_file``, ``question``, ``answer_label`` and
    ``lang_id`` columns. The image is read from ``resized_images/``.
    """
    vqa_state.vqa_image_file = frame.loc[index, "image_file"]
    vqa_state.question = frame.loc[index, "question"].strip("- ")
    vqa_state.answer_label = frame.loc[index, "answer_label"]
    vqa_state.question_lang_id = frame.loc[index, "lang_id"]
    # The answer language initially follows the question's language.
    vqa_state.answer_lang_id = frame.loc[index, "lang_id"]
    image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
    vqa_state.vqa_image = plt.imread(image_path)


def app(state):
    """Render the multilingual visual-question-answering demo page.

    Args:
        state: session-state object persisted across reruns. Its
            ``vqa_image_file``, ``vqa_image``, ``question``, ``answer_label``,
            ``question_lang_id`` and ``answer_lang_id`` attributes are read and
            written here; ``vqa_image_file is None`` marks the first run.

    Reads ``dummy_vqa_multilingual.tsv``, ``answer_reverse_mapping.json`` and
    images under ``resized_images/`` relative to the working directory.
    """
    vqa_state = state

    vqa_checkpoints = [
        "flax-community/clip-vision-bert-vqa-ft-6k"
    ]  # TODO: Maybe add more checkpoints?
    dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")

    code_to_name = {
        "en": "English",
        "fr": "French",
        "de": "German",
        "es": "Spanish",
    }
    # Maps stringified answer label ids to answer text.
    with open("answer_reverse_mapping.json", encoding="utf-8") as f:
        answer_reverse_mapping = json.load(f)

    first_index = 20
    # Initialise the session state with a fixed example on the first run.
    if vqa_state.vqa_image_file is None:
        _set_vqa_example(vqa_state, dummy_data, first_index)

    # Load (or fetch from the process-level cache) the model.
    with st.spinner("Loading model..."):
        model = _load_vqa_model(vqa_checkpoints[0])

    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        # reset_index() makes the sampled row addressable as label 0.
        sample = dummy_data.sample(1).reset_index()
        _set_vqa_example(vqa_state, sample, 0)

    transformed_image = get_transformed_image(vqa_state.vqa_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display Image
    new_col1.image(vqa_state.vqa_image, use_column_width="always")

    # Display Question
    question = new_col2.text_input(
        label="Question",
        value=vqa_state.question,
        help="Type your question regarding the image above in one of the four languages.",
    )
    new_col2.markdown(
        f"""**English Translation**: {question if vqa_state.question_lang_id == "en" else translate(question, 'en')}"""
    )

    question_inputs = get_text_attributes(question)

    # Select Language
    options = ["en", "de", "es", "fr"]
    vqa_state.answer_lang_id = new_col2.selectbox(
        "Answer Language",
        index=options.index(vqa_state.answer_lang_id),
        options=options,
        format_func=lambda x: code_to_name[x],
        help="The language to be used to show the top-5 labels.",
    )

    actual_answer = answer_reverse_mapping[str(vqa_state.answer_label)]
    new_col2.markdown(
        "**Actual Answer**: "
        + translate_labels([actual_answer], vqa_state.answer_lang_id)[0]
        + " ("
        + actual_answer
        + ")"
    )

    # Display Top-5 Predictions
    with st.spinner("Predicting..."):
        logits = _predict(model, transformed_image, dict(question_inputs))
        logits = softmax(logits)
        labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
        translated_labels = translate_labels(labels, vqa_state.answer_lang_id)
    fig = plotly_express_horizontal_bar_plot(values, translated_labels)
    st.plotly_chart(fig, use_container_width=True)
|