ariG23498 HF staff ceyda commited on
Commit
fc4cd1f
1 Parent(s): 85a0494

add file uploader (#1)

Browse files

- add file uploader (19e4e13690edc6892f4f5c2cdd8ffb996a24ebc8)
- cache the model so it doesn't reload for every image (bc751867842b5ed8f9a682576791de8c80c2f834)


Co-authored-by: Ceyda Cinarel <[email protected]>

Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -4,6 +4,14 @@ from PIL import Image
4
  import streamlit as st
5
  import tensorflow as tf
6
 
 
 
 
 
 
 
 
 
7
  # Inputs
8
  st.title("Input your image")
9
  image_url = st.text_input(
@@ -11,6 +19,7 @@ image_url = st.text_input(
11
  value="https://dl.fbaipublicfiles.com/dino/img.png",
12
  placeholder="https://your-favourite-image.png"
13
  )
 
14
 
15
  # Outputs
16
  st.title("Original Image from URL")
@@ -20,12 +29,12 @@ image, preprocessed_image = utils.load_image_from_url(
20
  image_url,
21
  model_type="dino"
22
  )
 
 
 
 
23
  st.image(image, caption="Original Image")
24
 
25
- # Load the DINO model
26
- with st.spinner("Loading the model..."):
27
- dino = from_pretrained_keras("probing-vits/vit-dino-base16")
28
-
29
  with st.spinner("Generating the attention scores..."):
30
  # Get the attention scores
31
  _, attention_score_dict = dino.predict(preprocessed_image)
 
4
  import streamlit as st
5
  import tensorflow as tf
6
 
7
+ st.cache(show_spinner=True)
8
+ def load_model():
9
+ # Load the DINO model
10
+ dino = from_pretrained_keras("probing-vits/vit-dino-base16")
11
+ return dino
12
+
13
+ dino=load_model()
14
+
15
  # Inputs
16
  st.title("Input your image")
17
  image_url = st.text_input(
 
19
  value="https://dl.fbaipublicfiles.com/dino/img.png",
20
  placeholder="https://your-favourite-image.png"
21
  )
22
+ uploaded_files = st.file_uploader("or an image file", type =["jpg","jpeg"])
23
 
24
  # Outputs
25
  st.title("Original Image from URL")
 
29
  image_url,
30
  model_type="dino"
31
  )
32
+ if uploaded_file:
33
+ image = Image.open(im)
34
+ preprocessed_image = utils.preprocess_image(image, model_type)
35
+
36
  st.image(image, caption="Original Image")
37
 
 
 
 
 
38
  with st.spinner("Generating the attention scores..."):
39
  # Get the attention scores
40
  _, attention_score_dict = dino.predict(preprocessed_image)