Commit e289356
Parent(s): 7a89f67
Fix style

Files changed:
- app.py +41 -21
- translate_answer_mapping.py +4 -3
- utils.py +6 -5
app.py
CHANGED
@@ -1,26 +1,26 @@
-from io import BytesIO
-import streamlit as st
-import pandas as pd
 import json
 import os
+from io import BytesIO
+
+import matplotlib.pyplot as plt
 import numpy as np
-
+import pandas as pd
+import streamlit as st
+from mtranslate import translate
 from PIL import Image
+from streamlit.elements import markdown
+
 from model.flax_clip_vision_bert.modeling_clip_vision_bert import (
     FlaxCLIPVisionBertForSequenceClassification,
 )
+from session import _get_state
 from utils import (
-    get_transformed_image,
     get_text_attributes,
     get_top_5_predictions,
+    get_transformed_image,
     plotly_express_horizontal_bar_plot,
     translate_labels,
 )
-import matplotlib.pyplot as plt
-from mtranslate import translate
-
-
-from session import _get_state
 
 state = _get_state()
 
@@ -74,9 +74,9 @@ st.write(
     "[Gunjan Chhablani](https://huggingface.co/gchhablani), [Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)"
 )
 
-image_col, intro_col = st.beta_columns([3,8])
-image_col.image("./misc/mvqa-logo-white.png", use_column_width=
-intro_col.write(read_markdown(
+image_col, intro_col = st.beta_columns([3, 8])
+image_col.image("./misc/mvqa-logo-white.png", use_column_width="always")
+intro_col.write(read_markdown("intro.md"))
 with st.beta_expander("Usage"):
     st.write(read_markdown("usage.md"))
 
@@ -85,7 +85,8 @@ with st.beta_expander("Article"):
     st.write(read_markdown("caveats.md"))
     st.write("## Methodology")
     st.image(
-        "./misc/Multilingual-VQA.png",
+        "./misc/Multilingual-VQA.png",
+        caption="Masked LM model for Image-text Pretraining.",
     )
     st.markdown(read_markdown("pretraining.md"))
     st.markdown(read_markdown("finetuning.md"))
@@ -110,7 +111,10 @@ if state.image_file is None:
 
 col1, col2 = st.beta_columns([6, 4])
 
-if col2.button(
+if col2.button(
+    "Get a random example",
+    help="Get a random example from the 100 `seeded` image-text pairs.",
+):
     sample = dummy_data.sample(1).reset_index()
     state.image_file = sample.loc[0, "image_file"]
     state.question = sample.loc[0, "question"].strip("- ")
@@ -124,9 +128,15 @@ if col2.button("Get a random example", help="Get a random example from the 100 `
 
 col2.write("OR")
 
-uploaded_file = col2.file_uploader(
+uploaded_file = col2.file_uploader(
+    "Upload your image",
+    type=["png", "jpg", "jpeg"],
+    help="Upload a file of your choosing.",
+)
 if uploaded_file is not None:
-    st.error(
+    st.error(
+        "Uploading files does not work on HuggingFace spaces. This app only supports random examples for now."
+    )
     # state.image_file = os.path.join("images/val2014", uploaded_file.name)
     # state.image = np.array(Image.open(uploaded_file))
 
@@ -135,9 +145,13 @@ transformed_image = get_transformed_image(state.image)
 # Display Image
 col1.image(state.image, use_column_width="auto")
 
-new_col1, new_col2 = st.beta_columns([5,5])
+new_col1, new_col2 = st.beta_columns([5, 5])
 # Display Question
-question = new_col1.text_input(
+question = new_col1.text_input(
+    label="Question",
+    value=state.question,
+    help="Type your question regarding the image above in one of the four languages.",
+)
 new_col1.markdown(
     f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}"""
 )
@@ -151,11 +165,17 @@ state.answer_lang_id = new_col2.selectbox(
     index=options.index(state.answer_lang_id),
     options=options,
     format_func=lambda x: code_to_name[x],
-    help="The language to be used to show the top-5 labels."
+    help="The language to be used to show the top-5 labels.",
 )
 
 actual_answer = answer_reverse_mapping[str(state.answer_label)]
-new_col2.markdown(
+new_col2.markdown(
+    "**Actual Answer**: "
+    + translate_labels([actual_answer], state.answer_lang_id)[0]
+    + " ("
+    + actual_answer
+    + ")"
+)
 
 # Display Top-5 Predictions
 with st.spinner("Loading model..."):
translate_answer_mapping.py
CHANGED
@@ -1,9 +1,10 @@
-from mtranslate.core import translate
 import json
-from tqdm import tqdm
-import ray
 from asyncio import Event
+
+import ray
+from mtranslate.core import translate
 from ray.actor import ActorHandle
+from tqdm import tqdm
 
 ray.init()
 from typing import Tuple
utils.py
CHANGED
@@ -1,12 +1,13 @@
-
-
+import json
+
 import numpy as np
+import plotly.express as px
+import torch
+from PIL import Image
+from torchvision.io import ImageReadMode, read_image
 from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
 from torchvision.transforms.functional import InterpolationMode
 from transformers import BertTokenizerFast
-import plotly.express as px
-import json
-from PIL import Image
 
 
 class Transform(torch.nn.Module):