Javi commited on
Commit
2d2805b
·
1 Parent(s): 714cf07

Bugfixes and content

Browse files
Files changed (1) hide show
  1. streamlit_app.py +49 -26
streamlit_app.py CHANGED
@@ -16,20 +16,20 @@ BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
16
 
17
  IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
18
  "https://cdn.pixabay.com/photo/2019/09/06/04/25/beach-4455433_960_720.jpg",
19
- "https://cdn.pixabay.com/photo/2019/10/19/12/21/hot-air-balloons-4561264_960_720.jpg",
20
- "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg",
21
- "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg",
22
- "https://cdn.pixabay.com/photo/2020/12/28/22/48/buddha-5868759_960_720.jpg",
23
  "https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg",
24
  "https://cdn.pixabay.com/photo/2020/11/04/15/29/coffee-beans-5712780_960_720.jpg",
25
  "https://cdn.pixabay.com/photo/2020/03/24/20/42/namibia-4965457_960_720.jpg",
26
  "https://cdn.pixabay.com/photo/2020/08/27/07/31/restaurant-5521372_960_720.jpg",
27
- "https://cdn.pixabay.com/photo/2020/08/28/06/13/building-5523630_960_720.jpg",
28
  "https://cdn.pixabay.com/photo/2020/08/24/21/41/couple-5515141_960_720.jpg",
29
  "https://cdn.pixabay.com/photo/2020/01/31/07/10/billboards-4807268_960_720.jpg",
30
  "https://cdn.pixabay.com/photo/2017/07/31/20/48/shell-2560930_960_720.jpg",
31
  "https://cdn.pixabay.com/photo/2020/08/13/01/29/koala-5483931_960_720.jpg",
32
- "https://cdn.pixabay.com/photo/2016/11/29/04/52/architecture-1867411_960_720.jpg",
33
  ]
34
 
35
  @st.cache # Cache this so that it doesn't change every time something changes in the page
@@ -54,12 +54,32 @@ class Sections:
54
  def header():
55
  st.markdown("# CLIP playground")
56
  st.markdown("### Try OpenAI's CLIP model in your browser")
57
- st.markdown(" ");
58
  st.markdown(" ")
59
  with st.beta_expander("What is CLIP?"):
60
- st.markdown("Nice CLIP explaination")
61
- st.markdown(" ");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  st.markdown(" ")
 
 
 
63
 
64
  @staticmethod
65
  def image_uploader(state: SessionState, accept_multiple_files: bool):
@@ -75,23 +95,26 @@ class Sections:
75
 
76
 
77
  @staticmethod
78
- def image_picker(state: SessionState):
79
  col1, col2, col3 = st.beta_columns(3)
80
  with col1:
81
  default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
82
  st.image(default_image_1, use_column_width=True)
83
  if st.button("Select image 1"):
84
  state.images = [default_image_1]
 
85
  with col2:
86
  default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
87
  st.image(default_image_2, use_column_width=True)
88
  if st.button("Select image 2"):
89
  state.images = [default_image_2]
 
90
  with col3:
91
  default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
92
  st.image(default_image_3, use_column_width=True)
93
  if st.button("Select image 3"):
94
  state.images = [default_image_3]
 
95
 
96
  @staticmethod
97
  def dataset_picker(state: SessionState):
@@ -105,17 +128,19 @@ class Sections:
105
  image_idx += 1
106
  if st.button("Select random dataset"):
107
  state.images = state.dataset
 
108
 
109
  @staticmethod
110
  def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
111
- raw_classes = st.text_input(input_label)
 
112
  if raw_classes:
113
  state.prompts = [prompt_prefix + class_name for class_name in raw_classes.split(";") if len(class_name) > 1]
114
- state.prompt_prefix = prompt_prefix
115
 
116
  @staticmethod
117
  def single_image_input_preview(state: SessionState):
118
- col1, col2 = st.beta_columns([2, 1])
 
119
  with col1:
120
  st.markdown("Image to classify")
121
  if state.images is not None:
@@ -127,7 +152,7 @@ class Sections:
127
  st.markdown("Labels to choose from")
128
  if state.prompts is not None:
129
  for prompt in state.prompts:
130
- st.markdown(f"* {prompt[len(state.prompt_prefix):]}")
131
  if len(state.prompts) < 2:
132
  st.warning("At least two prompts/classes are needed")
133
  else:
@@ -135,6 +160,7 @@ class Sections:
135
 
136
  @staticmethod
137
  def multiple_images_input_preview(state: SessionState):
 
138
  st.markdown("Images to classify")
139
  col1, col2, col3 = st.beta_columns(3)
140
  if state.images is not None:
@@ -148,27 +174,24 @@ class Sections:
148
  else:
149
  col1.warning("Select an image")
150
 
151
-
152
  with col3:
153
  st.markdown("Query prompt")
154
  if state.prompts is not None:
155
  for prompt in state.prompts:
156
- st.write(prompt[len(state.prompt_prefix):])
157
  else:
158
  st.warning("Enter the prompt to classify")
159
 
160
  @staticmethod
161
  def classification_output(state: SessionState):
162
  # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
163
- if st.button("Predict"):
164
  with st.spinner("Predicting..."):
165
  if isinstance(state.images[0], str):
166
- print("Regular call!")
167
  clip_response = booste.clip(BOOSTE_API_KEY,
168
  prompts=state.prompts,
169
  images=state.images)
170
  else:
171
- print("Hacky call!")
172
  images_mocker.calculate_image_id2image_lookup(state.images)
173
  images_mocker.start_mocking()
174
  clip_response = booste.clip(BOOSTE_API_KEY,
@@ -178,7 +201,7 @@ class Sections:
178
  st.markdown("### Results")
179
  # st.write(clip_response)
180
  if len(state.images) == 1:
181
- simplified_clip_results = [(prompt[len(state.prompt_prefix):],
182
  list(results.values())[0]["probabilityRelativeToPrompts"])
183
  for prompt, results in clip_response.items()]
184
  simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
@@ -205,25 +228,26 @@ class Sections:
205
  col2.markdown(f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200)")
206
 
207
 
208
- task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
209
  session_state = get_state()
 
210
  if task_name == "Image classification":
211
- Sections.header()
212
  Sections.image_uploader(session_state, accept_multiple_files=False)
213
  if session_state.images is None:
214
  st.markdown("or choose one from")
215
- Sections.image_picker(session_state)
216
  input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
217
  Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
218
  limit_number_images(session_state)
219
  Sections.single_image_input_preview(session_state)
220
  Sections.classification_output(session_state)
221
  elif task_name == "Prompt ranking":
222
- Sections.header()
223
  Sections.image_uploader(session_state, accept_multiple_files=False)
224
  if session_state.images is None:
225
  st.markdown("or choose one from")
226
- Sections.image_picker(session_state)
 
 
227
  input_label = "Enter the prompts to choose from separated by a semi-colon. " \
228
  "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
229
  Sections.prompts_input(session_state, input_label)
@@ -231,7 +255,6 @@ elif task_name == "Prompt ranking":
231
  Sections.single_image_input_preview(session_state)
232
  Sections.classification_output(session_state)
233
  elif task_name == "Image ranking":
234
- Sections.header()
235
  Sections.image_uploader(session_state, accept_multiple_files=True)
236
  if session_state.images is None or len(session_state.images) < 2:
237
  st.markdown("or use this random dataset")
 
16
 
17
  IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
18
  "https://cdn.pixabay.com/photo/2019/09/06/04/25/beach-4455433_960_720.jpg",
19
+ # "https://cdn.pixabay.com/photo/2019/10/19/12/21/hot-air-balloons-4561264_960_720.jpg",
20
+ # "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg",
21
+ # "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg",
22
+ # "https://cdn.pixabay.com/photo/2020/12/28/22/48/buddha-5868759_960_720.jpg",
23
  "https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg",
24
  "https://cdn.pixabay.com/photo/2020/11/04/15/29/coffee-beans-5712780_960_720.jpg",
25
  "https://cdn.pixabay.com/photo/2020/03/24/20/42/namibia-4965457_960_720.jpg",
26
  "https://cdn.pixabay.com/photo/2020/08/27/07/31/restaurant-5521372_960_720.jpg",
27
+ # "https://cdn.pixabay.com/photo/2020/08/28/06/13/building-5523630_960_720.jpg",
28
  "https://cdn.pixabay.com/photo/2020/08/24/21/41/couple-5515141_960_720.jpg",
29
  "https://cdn.pixabay.com/photo/2020/01/31/07/10/billboards-4807268_960_720.jpg",
30
  "https://cdn.pixabay.com/photo/2017/07/31/20/48/shell-2560930_960_720.jpg",
31
  "https://cdn.pixabay.com/photo/2020/08/13/01/29/koala-5483931_960_720.jpg",
32
+ # "https://cdn.pixabay.com/photo/2016/11/29/04/52/architecture-1867411_960_720.jpg",
33
  ]
34
 
35
  @st.cache # Cache this so that it doesn't change every time something changes in the page
 
54
  def header():
55
  st.markdown("# CLIP playground")
56
  st.markdown("### Try OpenAI's CLIP model in your browser")
57
+ st.markdown(" ")
58
  st.markdown(" ")
59
  with st.beta_expander("What is CLIP?"):
60
+ st.markdown("CLIP is a machine learning model that computes similarity between text "
61
+ "(also called prompts) and images. It has been trained on a dataset with millions of diverse"
62
+ " image-prompt pairs, which allows it to generalize to unseen examples."
63
+ " <br /> Check out [OpenAI's blogpost](https://openai.com/blog/clip/) for more details",
64
+ unsafe_allow_html=True)
65
+ col1, col2 = st.beta_columns(2)
66
+ col1.image("https://openaiassets.blob.core.windows.net/$web/clip/draft/20210104b/overview-a.svg")
67
+ col2.image("https://openaiassets.blob.core.windows.net/$web/clip/draft/20210104b/overview-b.svg")
68
+ with st.beta_expander("What can CLIP do?"):
69
+ st.markdown("#### Prompt ranking")
70
+ st.markdown("Given different prompts and an image it will rank the different prompts based on the similarity"
71
+ " with what the image represents")
72
+ st.markdown("#### Image ranking")
73
+ st.markdown("Given different images and a prompt it will rank the different images based on the similarity"
74
+ " with what the prompt expresses")
75
+ st.markdown("#### Image classification")
76
+ st.markdown("Similar to prompt ranking, given a set of classes it can classify an image between them. "
77
+ "Think of [Hotdog/ Not hotdog](https://www.youtube.com/watch?v=pqTntG1RXSY&ab_channel=tvpromos) without any training. ")
78
+ st.markdown(" ")
79
  st.markdown(" ")
80
+ st.sidebar.markdown(" "); st.sidebar.markdown(" ")
81
+ st.sidebar.markdown("Created by [@JavierFnts](https://twitter.com/JavierFnts)")
82
+ st.sidebar.markdown("[How was CLIP playground created?](https://twitter.com/JavierFnts)")
83
 
84
  @staticmethod
85
  def image_uploader(state: SessionState, accept_multiple_files: bool):
 
95
 
96
 
97
  @staticmethod
98
+ def image_picker(state: SessionState, default_text_input: str):
99
  col1, col2, col3 = st.beta_columns(3)
100
  with col1:
101
  default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
102
  st.image(default_image_1, use_column_width=True)
103
  if st.button("Select image 1"):
104
  state.images = [default_image_1]
105
+ state.default_text_input = default_text_input
106
  with col2:
107
  default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
108
  st.image(default_image_2, use_column_width=True)
109
  if st.button("Select image 2"):
110
  state.images = [default_image_2]
111
+ state.default_text_input = default_text_input
112
  with col3:
113
  default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
114
  st.image(default_image_3, use_column_width=True)
115
  if st.button("Select image 3"):
116
  state.images = [default_image_3]
117
+ state.default_text_input = default_text_input
118
 
119
  @staticmethod
120
  def dataset_picker(state: SessionState):
 
128
  image_idx += 1
129
  if st.button("Select random dataset"):
130
  state.images = state.dataset
131
+ state.default_text_input = "A sign that says 'SLOW DOWN'"
132
 
133
  @staticmethod
134
  def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
135
+ raw_classes = st.text_input(input_label,
136
+ value=state.default_text_input if state.default_text_input is not None else "")
137
  if raw_classes:
138
  state.prompts = [prompt_prefix + class_name for class_name in raw_classes.split(";") if len(class_name) > 1]
 
139
 
140
  @staticmethod
141
  def single_image_input_preview(state: SessionState):
142
+ st.markdown("### Preview")
143
+ col1, col2 = st.beta_columns([1, 2])
144
  with col1:
145
  st.markdown("Image to classify")
146
  if state.images is not None:
 
152
  st.markdown("Labels to choose from")
153
  if state.prompts is not None:
154
  for prompt in state.prompts:
155
+ st.markdown(f"* {prompt}")
156
  if len(state.prompts) < 2:
157
  st.warning("At least two prompts/classes are needed")
158
  else:
 
160
 
161
  @staticmethod
162
  def multiple_images_input_preview(state: SessionState):
163
+ st.markdown("### Preview")
164
  st.markdown("Images to classify")
165
  col1, col2, col3 = st.beta_columns(3)
166
  if state.images is not None:
 
174
  else:
175
  col1.warning("Select an image")
176
 
 
177
  with col3:
178
  st.markdown("Query prompt")
179
  if state.prompts is not None:
180
  for prompt in state.prompts:
181
+ st.write(prompt)
182
  else:
183
  st.warning("Enter the prompt to classify")
184
 
185
  @staticmethod
186
  def classification_output(state: SessionState):
187
  # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
188
+ if st.button("PREDICT 🚀"):
189
  with st.spinner("Predicting..."):
190
  if isinstance(state.images[0], str):
 
191
  clip_response = booste.clip(BOOSTE_API_KEY,
192
  prompts=state.prompts,
193
  images=state.images)
194
  else:
 
195
  images_mocker.calculate_image_id2image_lookup(state.images)
196
  images_mocker.start_mocking()
197
  clip_response = booste.clip(BOOSTE_API_KEY,
 
201
  st.markdown("### Results")
202
  # st.write(clip_response)
203
  if len(state.images) == 1:
204
+ simplified_clip_results = [(prompt,
205
  list(results.values())[0]["probabilityRelativeToPrompts"])
206
  for prompt, results in clip_response.items()]
207
  simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
 
228
  col2.markdown(f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200)")
229
 
230
 
231
+ task_name: str = st.sidebar.radio("Task", options=["Prompt ranking", "Image ranking", "Image classification"])
232
  session_state = get_state()
233
+ Sections.header()
234
  if task_name == "Image classification":
 
235
  Sections.image_uploader(session_state, accept_multiple_files=False)
236
  if session_state.images is None:
237
  st.markdown("or choose one from")
238
+ Sections.image_picker(session_state, default_text_input="banana; boat; bird")
239
  input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
240
  Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
241
  limit_number_images(session_state)
242
  Sections.single_image_input_preview(session_state)
243
  Sections.classification_output(session_state)
244
  elif task_name == "Prompt ranking":
 
245
  Sections.image_uploader(session_state, accept_multiple_files=False)
246
  if session_state.images is None:
247
  st.markdown("or choose one from")
248
+ Sections.image_picker(session_state, default_text_input="A calm afternoon in the Mediterranean; "
249
+ "A beautiful creature;"
250
+ " Something that grows in tropical regions")
251
  input_label = "Enter the prompts to choose from separated by a semi-colon. " \
252
  "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
253
  Sections.prompts_input(session_state, input_label)
 
255
  Sections.single_image_input_preview(session_state)
256
  Sections.classification_output(session_state)
257
  elif task_name == "Image ranking":
 
258
  Sections.image_uploader(session_state, accept_multiple_files=True)
259
  if session_state.images is None or len(session_state.images) < 2:
260
  st.markdown("or use this random dataset")