Spaces:
Runtime error
Runtime error
Commit
·
0cb8576
1
Parent(s):
f15eef4
Add models to state
Browse files- apps/mlm.py +5 -5
- apps/vqa.py +5 -4
apps/mlm.py
CHANGED
@@ -27,7 +27,7 @@ def app(state):
|
|
27 |
|
28 |
# @st.cache(persist=False) # TODO: Make this work with mlm_state. Currently not supported.
|
29 |
def predict(transformed_image, caption_inputs):
|
30 |
-
outputs = model(pixel_values=transformed_image, **caption_inputs)
|
31 |
indices = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)[
|
32 |
1
|
33 |
][0]
|
@@ -56,10 +56,10 @@ def app(state):
|
|
56 |
image = plt.imread(image_path)
|
57 |
mlm_state.mlm_image = image
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
|
64 |
if st.button(
|
65 |
"Get a random example",
|
|
|
27 |
|
28 |
# @st.cache(persist=False) # TODO: Make this work with mlm_state. Currently not supported.
|
29 |
def predict(transformed_image, caption_inputs):
|
30 |
+
outputs = mlm_state.model(pixel_values=transformed_image, **caption_inputs)
|
31 |
indices = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)[
|
32 |
1
|
33 |
][0]
|
|
|
56 |
image = plt.imread(image_path)
|
57 |
mlm_state.mlm_image = image
|
58 |
|
59 |
+
if mlm_state.model is None:
|
60 |
+
# Display Top-5 Predictions
|
61 |
+
with st.spinner("Loading model..."):
|
62 |
+
mlm_state.model = load_model(mlm_checkpoints[0])
|
63 |
|
64 |
if st.button(
|
65 |
"Get a random example",
|
apps/vqa.py
CHANGED
@@ -31,7 +31,7 @@ def app(state):
|
|
31 |
# @st.cache(persist=False)
|
32 |
def predict(transformed_image, question_inputs):
|
33 |
return np.array(
|
34 |
-
model(pixel_values=transformed_image, **question_inputs)[0][0]
|
35 |
)
|
36 |
|
37 |
# @st.cache(persist=False)
|
@@ -65,11 +65,12 @@ def app(state):
|
|
65 |
image = plt.imread(image_path)
|
66 |
vqa_state.vqa_image = image
|
67 |
|
68 |
-
|
|
|
|
|
69 |
|
70 |
# Display Top-5 Predictions
|
71 |
-
|
72 |
-
model = load_model(vqa_checkpoints[0])
|
73 |
|
74 |
if st.button(
|
75 |
"Get a random example",
|
|
|
31 |
# @st.cache(persist=False)
|
32 |
def predict(transformed_image, question_inputs):
|
33 |
return np.array(
|
34 |
+
vqa_state.model(pixel_values=transformed_image, **question_inputs)[0][0]
|
35 |
)
|
36 |
|
37 |
# @st.cache(persist=False)
|
|
|
65 |
image = plt.imread(image_path)
|
66 |
vqa_state.vqa_image = image
|
67 |
|
68 |
+
if vqa_state.model is None:
|
69 |
+
with st.spinner("Loading model..."):
|
70 |
+
vqa_state.model = load_model(vqa_checkpoints[0])
|
71 |
|
72 |
# Display Top-5 Predictions
|
73 |
+
|
|
|
74 |
|
75 |
if st.button(
|
76 |
"Get a random example",
|