Commit f4963f2 · Parent(s): d823ba7
Add about and method in app

Files changed:
- app.py (+75 -31)
- session.py (+89 -0)
- utils.py (+2 -2)
app.py CHANGED

@@ -11,6 +11,10 @@ from mtranslate import translate
 from PIL import Image
 
 
+from session import _get_state
+
+state = _get_state()
+
 @st.cache
 def load_model(ckpt):
     return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)

@@ -24,7 +28,6 @@ with open('answer_reverse_mapping.json') as f:
     answer_reverse_mapping = json.load(f)
 
 
-
 st.set_page_config(
     page_title="Multilingual VQA",
     layout="wide",

@@ -34,58 +37,99 @@ st.set_page_config(
 
 st.title("Multilingual Visual Question Answering")
 
+
+
 with st.beta_expander("About"):
-
+    st.write("This project is focused on Multilingual Visual Question Answering. Most of the existing datasets and models for this task work with English-only image-text pairs. Our intention here is to provide a Proof-of-Concept with our simple ViT+BERT model, which can be trained on multilingual text checkpoints with pre-trained image encoders and still work well enough. Due to the lack of good-quality multilingual data, we translate subsets of the Conceptual 12M dataset into English (already in English), French, German and Spanish using the Marian models. We achieved 0.49 accuracy on the multilingual validation set we created. With better captions and hyperparameter tuning, we expect to see higher performance.")
 with st.beta_expander("Method"):
-    st.
-
+    col1, col2 = st.beta_columns([5,4])
+    col1.image("./misc/Multilingual-VQA.png")
+    col2.markdown("""
+## Pretraining
+We follow an approach similar to [VisualBERT](https://arxiv.org/abs/1908.03557). Instead of using a FasterRCNN to get image features, we use a ViT encoder.
+The task is text-only MLM (Masked Language Modeling). We mask only the text tokens and try to predict the masked tokens. The VisualBERT authors also use a sentence-image matching task where two captions are matched against an image, but we skip this for the sake of simplicity.
+### Dataset
+The dataset we use for pre-training is a cleaned version of [Conceptual 12M](https://github.com/google-research-datasets/conceptual-12m). The dataset is downloaded and then broken images are removed, which gives us about 10M images. Then we use the MBart50 `
+mbart-large-50-one-to-many-mmt` checkpoint to translate the dataset into four different languages - English, French, German, and Spanish, keeping 2.5 million examples of each language.
+""")
+
+    st.markdown("""
+### Model
+The model is shown in the image above. We create a custom model in Flax which integrates the ViT model inside the BERT embeddings. We also use custom configs and modules in order to accommodate these changes and to allow loading from BERT and ViT checkpoints. The image is fed to the ViT encoder and the text is fed to the word-embedding layers of the BERT model. We use the `bert-base-multilingual-uncased` and `openai/clip-vit-base-patch32` checkpoints for the BERT and ViT (actually CLIPVision) models, respectively. All our code is available on [GitHub](https://github.com/gchhablani/multilingual-vqa).
+## Fine-tuning
+
+### Dataset
+For fine-tuning, we use the [VQA 2.0](https://visualqa.org/) dataset - particularly, the `train` and `validation` sets. We translate all the questions into the four languages specified above using language-specific MarianMT models, because the MarianMT models produce better translations and are faster, and are hence better suited for fine-tuning. This gives us 4x the number of examples in each subset.
+### Model
+We use the `SequenceClassification` model as a reference to create our own sequence classification model. 3129 answer labels are chosen, as is the convention for the English VQA task; the label mapping can be found [here](https://github.com/gchhablani/multilingual-vqa/blob/main/answer_mapping.json). These are the same labels used in fine-tuning of the VisualBERT models. The outputs shown here have been translated using the [`mtranslate`](https://github.com/mouuff/mtranslate) Google Translate API library. We then take various pre-trained checkpoints and train the sequence classification model for varying numbers of steps.
+
+Checkpoints:
+- Pre-trained checkpoint: [multilingual-vqa](https://huggingface.co/flax-community/multilingual-vqa)
+- Fine-tuned on 45k pretrained checkpoint: [multilingual-vqa-pt-45k-ft](https://huggingface.co/flax-community/multilingual-vqa-pt-45k-ft)
+- Fine-tuned on 45k pretrained checkpoint with AdaFactor (others use AdamW): [multilingual-vqa-pt-45k-ft-adf](https://huggingface.co/flax-community/multilingual-vqa-pt-45k-ft-adf)
+- Fine-tuned on 60k pretrained checkpoint: [multilingual-vqa-pt-60k-ft](https://huggingface.co/flax-community/multilingual-vqa-pt-60k-ft)
+- Fine-tuned on 70k pretrained checkpoint: [multilingual-vqa-pt-70k-ft](https://huggingface.co/flax-community/multilingual-vqa-pt-70k-ft)
+- From scratch (without pre-training) model: [multilingual-vqa-ft](https://huggingface.co/flax-community/multilingual-vqa-ft)
+
+**Caveat**: The best fine-tuned model only achieves 0.49 accuracy on the multilingual validation data that we create. This could be due to the not-so-great quality of the translations, sub-optimal hyperparameters, and lack of ample training. In the future, we hope to improve this model by addressing these concerns.
+""")
+
+with st.beta_expander("Cherry-Picked Results"):
+    pass
+
+with st.beta_expander("Conclusion"):
+    pass
+
+with st.beta_expander("Usage"):
     pass
 
 # Init Session State
-if
-
-
-
-
-
+if state.image_file is None:
+    state.image_file = dummy_data.loc[0,'image_file']
+    state.question = dummy_data.loc[0,'question'].strip('- ')
+    state.answer_label = dummy_data.loc[0,'answer_label']
+    state.question_lang_id = dummy_data.loc[0, 'lang_id']
+    state.answer_lang_id = dummy_data.loc[0, 'lang_id']
 
-image_path = os.path.join('images',
+image_path = os.path.join('images', state.image_file)
 image = plt.imread(image_path)
-
+state.image = image
 
 col1, col2 = st.beta_columns([5,5])
-
+
+# Display Image
+col1.image(state.image, use_column_width='always')
+
+if col2.button('Get a random example'):
     sample = dummy_data.sample(1).reset_index()
-
-
-
-
-
+    state.image_file = sample.loc[0,'image_file']
+    state.question = sample.loc[0,'question'].strip('- ')
+    state.answer_label = sample.loc[0,'answer_label']
+    state.question_lang_id = sample.loc[0, 'lang_id']
+    state.answer_lang_id = sample.loc[0, 'lang_id']
 
-    image_path = os.path.join('images',
+    image_path = os.path.join('images', state.image_file)
     image = plt.imread(image_path)
-
+    state.image = image
 
+st.write("OR")
 
 uploaded_file = col2.file_uploader('Upload your image', type=['png','jpg','jpeg'])
 if uploaded_file is not None:
-
-
-
+    state.image_file = os.path.join('images/val2014', uploaded_file.name)
+    state.image = np.array(Image.open(uploaded_file))
 
-transformed_image = get_transformed_image(st.session_state.image)
 
-
-st.image(st.session_state.image, use_column_width='always')
+transformed_image = get_transformed_image(state.image)
 
 # Display Question
-question = st.text_input(label="Question", value=
-st.markdown(f"""**English Translation**: {question if
+question = st.text_input(label="Question", value=state.question)
+st.markdown(f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}""")
 question_inputs = get_text_attributes(question)
 
 # Select Language
 options = ['en', 'de', 'es', 'fr']
-
+state.answer_lang_id = st.selectbox('Answer Language', index=options.index(state.answer_lang_id), options=options)
 # Display Top-5 Predictions
 with st.spinner('Loading model...'):
     model = load_model(checkpoints[0])

@@ -94,6 +138,6 @@ with st.spinner('Predicting...'):
     logits = np.array(predictions[0][0])
     logits = softmax(logits)
     labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
-    translated_labels = translate_labels(labels,
+    translated_labels = translate_labels(labels, state.answer_lang_id)
     fig = plotly_express_horizontal_bar_plot(values, translated_labels)
-    st.plotly_chart(fig)
+    st.plotly_chart(fig, use_container_width=True)
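For reference, this is a minimal sketch of the loading step that the cached `load_model` helper wraps above, pointed at one of the fine-tuned Hub checkpoints listed in the Method text. `FlaxCLIPVisionBertForSequenceClassification` is the project's own Flax model class (from the multilingual-vqa repository), not a stock `transformers` class, so the import path below is an assumption.

# Hedged sketch: load a fine-tuned checkpoint the same way load_model() does above.
# The import path is an assumption; the class lives in the project's own repo
# (https://github.com/gchhablani/multilingual-vqa), not in the transformers package.
from models.flax_clip_vision_bert import FlaxCLIPVisionBertForSequenceClassification

checkpoints = ["flax-community/multilingual-vqa-pt-45k-ft"]  # one of the checkpoints listed above

model = FlaxCLIPVisionBertForSequenceClassification.from_pretrained(checkpoints[0])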
session.py ADDED

@@ -0,0 +1,89 @@
+#
+# Code for managing session state, which is needed for multi-input forms.
+# See https://github.com/streamlit/streamlit/issues/1557
+#
+# This code is taken from
+# https://gist.github.com/okld/0aba4869ba6fdc8d49132e6974e2e662
+#
+from streamlit.hashing import _CodeHasher
+from streamlit.report_thread import get_report_ctx
+from streamlit.server.server import Server
+
+
+class _SessionState:
+    def __init__(self, session, hash_funcs):
+        """Initialize SessionState instance."""
+        self.__dict__["_state"] = {
+            "data": {},
+            "hash": None,
+            "hasher": _CodeHasher(hash_funcs),
+            "is_rerun": False,
+            "session": session,
+        }
+
+    def __call__(self, **kwargs):
+        """Initialize state data once."""
+        for item, value in kwargs.items():
+            if item not in self._state["data"]:
+                self._state["data"][item] = value
+
+    def __getitem__(self, item):
+        """Return a saved state value, None if item is undefined."""
+        return self._state["data"].get(item, None)
+
+    def __getattr__(self, item):
+        """Return a saved state value, None if item is undefined."""
+        return self._state["data"].get(item, None)
+
+    def __setitem__(self, item, value):
+        """Set state value."""
+        self._state["data"][item] = value
+
+    def __setattr__(self, item, value):
+        """Set state value."""
+        self._state["data"][item] = value
+
+    def clear(self):
+        """Clear session state and request a rerun."""
+        self._state["data"].clear()
+        self._state["session"].request_rerun()
+
+    def sync(self):
+        """
+        Rerun the app with all state values up to date from the beginning to
+        fix rollbacks.
+        """
+        data_to_bytes = self._state["hasher"].to_bytes(self._state["data"], None)
+
+        # Ensure to rerun only once to avoid infinite loops
+        # caused by a constantly changing state value at each run.
+        #
+        # Example: state.value += 1
+        if self._state["is_rerun"]:
+            self._state["is_rerun"] = False
+
+        elif self._state["hash"] is not None:
+            if self._state["hash"] != data_to_bytes:
+                self._state["is_rerun"] = True
+                self._state["session"].request_rerun()
+
+        self._state["hash"] = data_to_bytes
+
+
+def _get_session():
+    session_id = get_report_ctx().session_id
+    session_info = Server.get_current()._get_session_info(session_id)
+
+    if session_info is None:
+        raise RuntimeError("Couldn't get your Streamlit Session object.")
+
+    return session_info.session
+
+
+def _get_state(hash_funcs=None):
+    session = _get_session()
+
+    if not hasattr(session, "_custom_session_state"):
+        session._custom_session_state = _SessionState(session, hash_funcs)
+
+    return session._custom_session_state
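session.py is the session-state helper from the linked gist, written for Streamlit versions that predate the built-in `st.session_state`. As a rough usage sketch of the API it exposes (the `counter` attribute is purely illustrative; unset attributes read back as `None`, per `__getattr__` above, and the trailing `sync()` call follows the gist's recommended pattern rather than anything shown in app.py):

# Minimal usage sketch for the _get_state() helper defined above.
# 'counter' is a hypothetical attribute used only for illustration.
import streamlit as st
from session import _get_state

state = _get_state()

if state.counter is None:      # unset attributes come back as None
    state.counter = 0

if st.button("Increment"):
    state.counter += 1         # value survives the rerun triggered by the button

st.write("Counter:", state.counter)

state.sync()                   # keep state consistent across reruns, as the gist recommends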
utils.py CHANGED

@@ -45,7 +45,7 @@ def get_text_attributes(text):
 
 def get_top_5_predictions(logits, answer_reverse_mapping):
     indices = np.argsort(logits)[-5:]
-    values =
+    values = logits[indices]
     labels = [answer_reverse_mapping[str(i)] for i in indices]
     return labels, values
 

@@ -65,5 +65,5 @@ def translate_labels(labels, lang_id):
 
 
 def plotly_express_horizontal_bar_plot(values, labels):
-    fig = px.bar(x=values, y=labels, text=values, title="Top-5 Predictions", labels={"x": "Scores", "y": "Answers"}, orientation="h")
+    fig = px.bar(x=values, y=labels, text=[format(value, ".3%") for value in values], title="Top-5 Predictions", labels={"x": "Scores", "y": "Answers"}, orientation="h")
     return fig
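As a quick sanity check on the `get_top_5_predictions` fix, here is a small worked example with made-up logits; the mapping dict is a toy stand-in for the real `answer_reverse_mapping` loaded from JSON in app.py, and the last line mirrors the `.3%` formatting now used for the bar labels.

import numpy as np

# Toy stand-in for the answer_reverse_mapping JSON used in app.py (illustration only).
answer_reverse_mapping = {
    "0": "yes", "1": "no", "2": "red", "3": "two", "4": "dog", "5": "cat", "6": "bus",
}

logits = np.array([0.05, 0.30, 0.10, 0.02, 0.25, 0.08, 0.20])  # pretend softmax scores

indices = np.argsort(logits)[-5:]   # indices of the five largest scores, in ascending order
values = logits[indices]            # the fix above: scores gathered at those indices
labels = [answer_reverse_mapping[str(i)] for i in indices]

print(labels)                              # ['cat', 'red', 'bus', 'dog', 'no']
print([format(v, ".3%") for v in values])  # ['8.000%', '10.000%', '20.000%', '25.000%', '30.000%']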