Spaces:
Runtime error
Runtime error
Javi
commited on
Commit
·
2d2805b
1
Parent(s):
714cf07
Bugfixes and content
Browse files- streamlit_app.py +49 -26
streamlit_app.py
CHANGED
@@ -16,20 +16,20 @@ BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
|
|
16 |
|
17 |
IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
|
18 |
"https://cdn.pixabay.com/photo/2019/09/06/04/25/beach-4455433_960_720.jpg",
|
19 |
-
"https://cdn.pixabay.com/photo/2019/10/19/12/21/hot-air-balloons-4561264_960_720.jpg",
|
20 |
-
"https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg",
|
21 |
-
"https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg",
|
22 |
-
"https://cdn.pixabay.com/photo/2020/12/28/22/48/buddha-5868759_960_720.jpg",
|
23 |
"https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg",
|
24 |
"https://cdn.pixabay.com/photo/2020/11/04/15/29/coffee-beans-5712780_960_720.jpg",
|
25 |
"https://cdn.pixabay.com/photo/2020/03/24/20/42/namibia-4965457_960_720.jpg",
|
26 |
"https://cdn.pixabay.com/photo/2020/08/27/07/31/restaurant-5521372_960_720.jpg",
|
27 |
-
"https://cdn.pixabay.com/photo/2020/08/28/06/13/building-5523630_960_720.jpg",
|
28 |
"https://cdn.pixabay.com/photo/2020/08/24/21/41/couple-5515141_960_720.jpg",
|
29 |
"https://cdn.pixabay.com/photo/2020/01/31/07/10/billboards-4807268_960_720.jpg",
|
30 |
"https://cdn.pixabay.com/photo/2017/07/31/20/48/shell-2560930_960_720.jpg",
|
31 |
"https://cdn.pixabay.com/photo/2020/08/13/01/29/koala-5483931_960_720.jpg",
|
32 |
-
"https://cdn.pixabay.com/photo/2016/11/29/04/52/architecture-1867411_960_720.jpg",
|
33 |
]
|
34 |
|
35 |
@st.cache # Cache this so that it doesn't change every time something changes in the page
|
@@ -54,12 +54,32 @@ class Sections:
|
|
54 |
def header():
|
55 |
st.markdown("# CLIP playground")
|
56 |
st.markdown("### Try OpenAI's CLIP model in your browser")
|
57 |
-
st.markdown(" ")
|
58 |
st.markdown(" ")
|
59 |
with st.beta_expander("What is CLIP?"):
|
60 |
-
st.markdown("
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
st.markdown(" ")
|
|
|
|
|
|
|
63 |
|
64 |
@staticmethod
|
65 |
def image_uploader(state: SessionState, accept_multiple_files: bool):
|
@@ -75,23 +95,26 @@ class Sections:
|
|
75 |
|
76 |
|
77 |
@staticmethod
|
78 |
-
def image_picker(state: SessionState):
|
79 |
col1, col2, col3 = st.beta_columns(3)
|
80 |
with col1:
|
81 |
default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
|
82 |
st.image(default_image_1, use_column_width=True)
|
83 |
if st.button("Select image 1"):
|
84 |
state.images = [default_image_1]
|
|
|
85 |
with col2:
|
86 |
default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
|
87 |
st.image(default_image_2, use_column_width=True)
|
88 |
if st.button("Select image 2"):
|
89 |
state.images = [default_image_2]
|
|
|
90 |
with col3:
|
91 |
default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
|
92 |
st.image(default_image_3, use_column_width=True)
|
93 |
if st.button("Select image 3"):
|
94 |
state.images = [default_image_3]
|
|
|
95 |
|
96 |
@staticmethod
|
97 |
def dataset_picker(state: SessionState):
|
@@ -105,17 +128,19 @@ class Sections:
|
|
105 |
image_idx += 1
|
106 |
if st.button("Select random dataset"):
|
107 |
state.images = state.dataset
|
|
|
108 |
|
109 |
@staticmethod
|
110 |
def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
|
111 |
-
raw_classes = st.text_input(input_label
|
|
|
112 |
if raw_classes:
|
113 |
state.prompts = [prompt_prefix + class_name for class_name in raw_classes.split(";") if len(class_name) > 1]
|
114 |
-
state.prompt_prefix = prompt_prefix
|
115 |
|
116 |
@staticmethod
|
117 |
def single_image_input_preview(state: SessionState):
|
118 |
-
|
|
|
119 |
with col1:
|
120 |
st.markdown("Image to classify")
|
121 |
if state.images is not None:
|
@@ -127,7 +152,7 @@ class Sections:
|
|
127 |
st.markdown("Labels to choose from")
|
128 |
if state.prompts is not None:
|
129 |
for prompt in state.prompts:
|
130 |
-
st.markdown(f"* {prompt
|
131 |
if len(state.prompts) < 2:
|
132 |
st.warning("At least two prompts/classes are needed")
|
133 |
else:
|
@@ -135,6 +160,7 @@ class Sections:
|
|
135 |
|
136 |
@staticmethod
|
137 |
def multiple_images_input_preview(state: SessionState):
|
|
|
138 |
st.markdown("Images to classify")
|
139 |
col1, col2, col3 = st.beta_columns(3)
|
140 |
if state.images is not None:
|
@@ -148,27 +174,24 @@ class Sections:
|
|
148 |
else:
|
149 |
col1.warning("Select an image")
|
150 |
|
151 |
-
|
152 |
with col3:
|
153 |
st.markdown("Query prompt")
|
154 |
if state.prompts is not None:
|
155 |
for prompt in state.prompts:
|
156 |
-
st.write(prompt
|
157 |
else:
|
158 |
st.warning("Enter the prompt to classify")
|
159 |
|
160 |
@staticmethod
|
161 |
def classification_output(state: SessionState):
|
162 |
# Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
|
163 |
-
if st.button("
|
164 |
with st.spinner("Predicting..."):
|
165 |
if isinstance(state.images[0], str):
|
166 |
-
print("Regular call!")
|
167 |
clip_response = booste.clip(BOOSTE_API_KEY,
|
168 |
prompts=state.prompts,
|
169 |
images=state.images)
|
170 |
else:
|
171 |
-
print("Hacky call!")
|
172 |
images_mocker.calculate_image_id2image_lookup(state.images)
|
173 |
images_mocker.start_mocking()
|
174 |
clip_response = booste.clip(BOOSTE_API_KEY,
|
@@ -178,7 +201,7 @@ class Sections:
|
|
178 |
st.markdown("### Results")
|
179 |
# st.write(clip_response)
|
180 |
if len(state.images) == 1:
|
181 |
-
simplified_clip_results = [(prompt
|
182 |
list(results.values())[0]["probabilityRelativeToPrompts"])
|
183 |
for prompt, results in clip_response.items()]
|
184 |
simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
|
@@ -205,25 +228,26 @@ class Sections:
|
|
205 |
col2.markdown(f"### ")
|
206 |
|
207 |
|
208 |
-
task_name: str = st.sidebar.radio("Task", options=["
|
209 |
session_state = get_state()
|
|
|
210 |
if task_name == "Image classification":
|
211 |
-
Sections.header()
|
212 |
Sections.image_uploader(session_state, accept_multiple_files=False)
|
213 |
if session_state.images is None:
|
214 |
st.markdown("or choose one from")
|
215 |
-
Sections.image_picker(session_state)
|
216 |
input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
|
217 |
Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
|
218 |
limit_number_images(session_state)
|
219 |
Sections.single_image_input_preview(session_state)
|
220 |
Sections.classification_output(session_state)
|
221 |
elif task_name == "Prompt ranking":
|
222 |
-
Sections.header()
|
223 |
Sections.image_uploader(session_state, accept_multiple_files=False)
|
224 |
if session_state.images is None:
|
225 |
st.markdown("or choose one from")
|
226 |
-
Sections.image_picker(session_state
|
|
|
|
|
227 |
input_label = "Enter the prompts to choose from separated by a semi-colon. " \
|
228 |
"(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
|
229 |
Sections.prompts_input(session_state, input_label)
|
@@ -231,7 +255,6 @@ elif task_name == "Prompt ranking":
|
|
231 |
Sections.single_image_input_preview(session_state)
|
232 |
Sections.classification_output(session_state)
|
233 |
elif task_name == "Image ranking":
|
234 |
-
Sections.header()
|
235 |
Sections.image_uploader(session_state, accept_multiple_files=True)
|
236 |
if session_state.images is None or len(session_state.images) < 2:
|
237 |
st.markdown("or use this random dataset")
|
|
|
16 |
|
17 |
IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
|
18 |
"https://cdn.pixabay.com/photo/2019/09/06/04/25/beach-4455433_960_720.jpg",
|
19 |
+
# "https://cdn.pixabay.com/photo/2019/10/19/12/21/hot-air-balloons-4561264_960_720.jpg",
|
20 |
+
# "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg",
|
21 |
+
# "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg",
|
22 |
+
# "https://cdn.pixabay.com/photo/2020/12/28/22/48/buddha-5868759_960_720.jpg",
|
23 |
"https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg",
|
24 |
"https://cdn.pixabay.com/photo/2020/11/04/15/29/coffee-beans-5712780_960_720.jpg",
|
25 |
"https://cdn.pixabay.com/photo/2020/03/24/20/42/namibia-4965457_960_720.jpg",
|
26 |
"https://cdn.pixabay.com/photo/2020/08/27/07/31/restaurant-5521372_960_720.jpg",
|
27 |
+
# "https://cdn.pixabay.com/photo/2020/08/28/06/13/building-5523630_960_720.jpg",
|
28 |
"https://cdn.pixabay.com/photo/2020/08/24/21/41/couple-5515141_960_720.jpg",
|
29 |
"https://cdn.pixabay.com/photo/2020/01/31/07/10/billboards-4807268_960_720.jpg",
|
30 |
"https://cdn.pixabay.com/photo/2017/07/31/20/48/shell-2560930_960_720.jpg",
|
31 |
"https://cdn.pixabay.com/photo/2020/08/13/01/29/koala-5483931_960_720.jpg",
|
32 |
+
# "https://cdn.pixabay.com/photo/2016/11/29/04/52/architecture-1867411_960_720.jpg",
|
33 |
]
|
34 |
|
35 |
@st.cache # Cache this so that it doesn't change every time something changes in the page
|
|
|
54 |
def header():
|
55 |
st.markdown("# CLIP playground")
|
56 |
st.markdown("### Try OpenAI's CLIP model in your browser")
|
57 |
+
st.markdown(" ")
|
58 |
st.markdown(" ")
|
59 |
with st.beta_expander("What is CLIP?"):
|
60 |
+
st.markdown("CLIP is a machine learning model that computes similarity between text "
|
61 |
+
"(also called prompts) and images. It has been trained on a dataset with millions of diverse"
|
62 |
+
" image-prompt pairs, which allows it to generalize to unseen examples."
|
63 |
+
" <br /> Check out [OpenAI's blogpost](https://openai.com/blog/clip/) for more details",
|
64 |
+
unsafe_allow_html=True)
|
65 |
+
col1, col2 = st.beta_columns(2)
|
66 |
+
col1.image("https://openaiassets.blob.core.windows.net/$web/clip/draft/20210104b/overview-a.svg")
|
67 |
+
col2.image("https://openaiassets.blob.core.windows.net/$web/clip/draft/20210104b/overview-b.svg")
|
68 |
+
with st.beta_expander("What can CLIP do?"):
|
69 |
+
st.markdown("#### Prompt ranking")
|
70 |
+
st.markdown("Given different prompts and an image it will rank the different prompts based on the similarity"
|
71 |
+
" with what the image represents")
|
72 |
+
st.markdown("#### Image ranking")
|
73 |
+
st.markdown("Given different images and a prompt it will rank the different images based on the similarity"
|
74 |
+
" with what the prompt expresses")
|
75 |
+
st.markdown("#### Image classification")
|
76 |
+
st.markdown("Similar to prompt ranking, given a set of classes it can classify an image between them. "
|
77 |
+
"Think of [Hotdog/ Not hotdog](https://www.youtube.com/watch?v=pqTntG1RXSY&ab_channel=tvpromos) without any training. ")
|
78 |
+
st.markdown(" ")
|
79 |
st.markdown(" ")
|
80 |
+
st.sidebar.markdown(" "); st.sidebar.markdown(" ")
|
81 |
+
st.sidebar.markdown("Created by [@JavierFnts](https://twitter.com/JavierFnts)")
|
82 |
+
st.sidebar.markdown("[How was CLIP playground created?](https://twitter.com/JavierFnts)")
|
83 |
|
84 |
@staticmethod
|
85 |
def image_uploader(state: SessionState, accept_multiple_files: bool):
|
|
|
95 |
|
96 |
|
97 |
@staticmethod
|
98 |
+
def image_picker(state: SessionState, default_text_input: str):
|
99 |
col1, col2, col3 = st.beta_columns(3)
|
100 |
with col1:
|
101 |
default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
|
102 |
st.image(default_image_1, use_column_width=True)
|
103 |
if st.button("Select image 1"):
|
104 |
state.images = [default_image_1]
|
105 |
+
state.default_text_input = default_text_input
|
106 |
with col2:
|
107 |
default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
|
108 |
st.image(default_image_2, use_column_width=True)
|
109 |
if st.button("Select image 2"):
|
110 |
state.images = [default_image_2]
|
111 |
+
state.default_text_input = default_text_input
|
112 |
with col3:
|
113 |
default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
|
114 |
st.image(default_image_3, use_column_width=True)
|
115 |
if st.button("Select image 3"):
|
116 |
state.images = [default_image_3]
|
117 |
+
state.default_text_input = default_text_input
|
118 |
|
119 |
@staticmethod
|
120 |
def dataset_picker(state: SessionState):
|
|
|
128 |
image_idx += 1
|
129 |
if st.button("Select random dataset"):
|
130 |
state.images = state.dataset
|
131 |
+
state.default_text_input = "A sign that says 'SLOW DOWN'"
|
132 |
|
133 |
@staticmethod
|
134 |
def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
|
135 |
+
raw_classes = st.text_input(input_label,
|
136 |
+
value=state.default_text_input if state.default_text_input is not None else "")
|
137 |
if raw_classes:
|
138 |
state.prompts = [prompt_prefix + class_name for class_name in raw_classes.split(";") if len(class_name) > 1]
|
|
|
139 |
|
140 |
@staticmethod
|
141 |
def single_image_input_preview(state: SessionState):
|
142 |
+
st.markdown("### Preview")
|
143 |
+
col1, col2 = st.beta_columns([1, 2])
|
144 |
with col1:
|
145 |
st.markdown("Image to classify")
|
146 |
if state.images is not None:
|
|
|
152 |
st.markdown("Labels to choose from")
|
153 |
if state.prompts is not None:
|
154 |
for prompt in state.prompts:
|
155 |
+
st.markdown(f"* {prompt}")
|
156 |
if len(state.prompts) < 2:
|
157 |
st.warning("At least two prompts/classes are needed")
|
158 |
else:
|
|
|
160 |
|
161 |
@staticmethod
|
162 |
def multiple_images_input_preview(state: SessionState):
|
163 |
+
st.markdown("### Preview")
|
164 |
st.markdown("Images to classify")
|
165 |
col1, col2, col3 = st.beta_columns(3)
|
166 |
if state.images is not None:
|
|
|
174 |
else:
|
175 |
col1.warning("Select an image")
|
176 |
|
|
|
177 |
with col3:
|
178 |
st.markdown("Query prompt")
|
179 |
if state.prompts is not None:
|
180 |
for prompt in state.prompts:
|
181 |
+
st.write(prompt)
|
182 |
else:
|
183 |
st.warning("Enter the prompt to classify")
|
184 |
|
185 |
@staticmethod
|
186 |
def classification_output(state: SessionState):
|
187 |
# Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
|
188 |
+
if st.button("PREDICT 🚀"):
|
189 |
with st.spinner("Predicting..."):
|
190 |
if isinstance(state.images[0], str):
|
|
|
191 |
clip_response = booste.clip(BOOSTE_API_KEY,
|
192 |
prompts=state.prompts,
|
193 |
images=state.images)
|
194 |
else:
|
|
|
195 |
images_mocker.calculate_image_id2image_lookup(state.images)
|
196 |
images_mocker.start_mocking()
|
197 |
clip_response = booste.clip(BOOSTE_API_KEY,
|
|
|
201 |
st.markdown("### Results")
|
202 |
# st.write(clip_response)
|
203 |
if len(state.images) == 1:
|
204 |
+
simplified_clip_results = [(prompt,
|
205 |
list(results.values())[0]["probabilityRelativeToPrompts"])
|
206 |
for prompt, results in clip_response.items()]
|
207 |
simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
|
|
|
228 |
col2.markdown(f"### ")
|
229 |
|
230 |
|
231 |
+
task_name: str = st.sidebar.radio("Task", options=["Prompt ranking", "Image ranking", "Image classification"])
|
232 |
session_state = get_state()
|
233 |
+
Sections.header()
|
234 |
if task_name == "Image classification":
|
|
|
235 |
Sections.image_uploader(session_state, accept_multiple_files=False)
|
236 |
if session_state.images is None:
|
237 |
st.markdown("or choose one from")
|
238 |
+
Sections.image_picker(session_state, default_text_input="banana; boat; bird")
|
239 |
input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
|
240 |
Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
|
241 |
limit_number_images(session_state)
|
242 |
Sections.single_image_input_preview(session_state)
|
243 |
Sections.classification_output(session_state)
|
244 |
elif task_name == "Prompt ranking":
|
|
|
245 |
Sections.image_uploader(session_state, accept_multiple_files=False)
|
246 |
if session_state.images is None:
|
247 |
st.markdown("or choose one from")
|
248 |
+
Sections.image_picker(session_state, default_text_input="A calm afternoon in the Mediterranean; "
|
249 |
+
"A beautiful creature;"
|
250 |
+
" Something that grows in tropical regions")
|
251 |
input_label = "Enter the prompts to choose from separated by a semi-colon. " \
|
252 |
"(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
|
253 |
Sections.prompts_input(session_state, input_label)
|
|
|
255 |
Sections.single_image_input_preview(session_state)
|
256 |
Sections.classification_output(session_state)
|
257 |
elif task_name == "Image ranking":
|
|
|
258 |
Sections.image_uploader(session_state, accept_multiple_files=True)
|
259 |
if session_state.images is None or len(session_state.images) < 2:
|
260 |
st.markdown("or use this random dataset")
|