Commit 40497e6 · 1 Parent(s): 0454d20
Base functionality working again
Files changed:
- app.py +45 -47
- clip_model.py +16 -3
app.py CHANGED

@@ -230,8 +230,7 @@ class Sections:
 
     @staticmethod
     def classification_output(model: ClipModel):
-
-        if st.button("Predict") and is_valid_prediction_state():  # PREDICT
+        if st.button("Predict") and is_valid_prediction_state():
             with st.spinner("Predicting..."):
 
                 st.markdown("### Results")
@@ -247,7 +246,6 @@ class Sections:
                 st.markdown(f"### {st.session_state.prompts[0]}")
 
                 scores = model.compute_images_probabilities(st.session_state.images, st.session_state.prompts[0])
-                st.json(scores)
                 scored_images = [(image, score) for image, score in zip(st.session_state.images, scores)]
                 sorted_scored_images = sorted(scored_images, key=lambda x: x[1], reverse=True)
 
@@ -272,47 +270,47 @@ class Sections:
 #             " It can be whatever you can think of",
 #             unsafe_allow_html=True)
 
-
-Sections.header()
-col1, col2 = st.columns([1, 2])
-col1.markdown(" "); col1.markdown(" ")
-col1.markdown("#### Task selection")
-task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
-st.markdown("<br>", unsafe_allow_html=True)
-init_state()
-model = load_model()
-if task_name == "Image classification":
-    Sections.image_uploader(accept_multiple_files=False)
-    if st.session_state.images is None:
-        st.markdown("or choose one from")
-        Sections.image_picker(default_text_input="banana; boat; bird")
-    input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
-    Sections.prompts_input(input_label, prompt_prefix='A picture of a ')
-    limit_number_images()
-    Sections.single_image_input_preview()
-    Sections.classification_output(model)
-elif task_name == "Prompt ranking":
-    Sections.image_uploader(accept_multiple_files=False)
-    if st.session_state.images is None:
-        st.markdown("or choose one from")
-        Sections.image_picker(default_text_input="A calm afternoon in the Mediterranean; "
-                                                 "A beautiful creature;"
-                                                 " Something that grows in tropical regions")
-    input_label = "Enter the prompts to choose from separated by a semi-colon. " \
-                  "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
-    Sections.prompts_input(input_label)
-    limit_number_images()
-    Sections.single_image_input_preview()
-    Sections.classification_output(model)
-elif task_name == "Image ranking":
-    Sections.image_uploader(accept_multiple_files=True)
-    if st.session_state.images is None or len(st.session_state.images) < 2:
-        st.markdown("or use this random dataset")
-        Sections.dataset_picker()
-    Sections.prompts_input("Enter the prompt to query the images by")
-    limit_number_prompts()
-    Sections.multiple_images_input_preview()
-    Sections.classification_output(model)
-
-st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
-            "", unsafe_allow_html=True)
+if __name__ == "__main__":
+    Sections.header()
+    col1, col2 = st.columns([1, 2])
+    col1.markdown(" "); col1.markdown(" ")
+    col1.markdown("#### Task selection")
+    task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
+    st.markdown("<br>", unsafe_allow_html=True)
+    init_state()
+    model = load_model()
+    if task_name == "Image classification":
+        Sections.image_uploader(accept_multiple_files=False)
+        if st.session_state.images is None:
+            st.markdown("or choose one from")
+            Sections.image_picker(default_text_input="banana; boat; bird")
+        input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
+        Sections.prompts_input(input_label, prompt_prefix='A picture of a ')
+        limit_number_images()
+        Sections.single_image_input_preview()
+        Sections.classification_output(model)
+    elif task_name == "Prompt ranking":
+        Sections.image_uploader(accept_multiple_files=False)
+        if st.session_state.images is None:
+            st.markdown("or choose one from")
+            Sections.image_picker(default_text_input="A calm afternoon in the Mediterranean; "
+                                                     "A beautiful creature;"
+                                                     " Something that grows in tropical regions")
+        input_label = "Enter the prompts to choose from separated by a semi-colon. " \
+                      "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
+        Sections.prompts_input(input_label)
+        limit_number_images()
+        Sections.single_image_input_preview()
+        Sections.classification_output(model)
+    elif task_name == "Image ranking":
+        Sections.image_uploader(accept_multiple_files=True)
+        if st.session_state.images is None or len(st.session_state.images) < 2:
+            st.markdown("or use this random dataset")
+            Sections.dataset_picker()
+        Sections.prompts_input("Enter the prompt to query the images by")
+        limit_number_prompts()
+        Sections.multiple_images_input_preview()
+        Sections.classification_output(model)
+
+    st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
+                "", unsafe_allow_html=True)
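Most of the app.py hunk above is existing page-building code moving under an `if __name__ == "__main__":` guard. A minimal sketch of that pattern (illustrative only, not the repo's code; `load_default_dataset` is the only name taken from app.py):

    import streamlit as st

    def load_default_dataset():
        # importable helper: importing this module does not build any UI
        return []

    if __name__ == "__main__":
        # runs under `streamlit run app.py`, but is skipped when another module
        # (such as the smoke test added to clip_model.py below) imports app
        st.title("CLIP Playground")

With the guard in place, clip_model.py can import helpers from app.py without triggering the whole Streamlit page as a side effect.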
clip_model.py CHANGED

@@ -2,6 +2,8 @@ import clip
 from PIL.Image import Image
 import torch
 
+
+
 class ClipModel:
     def __init__(self, model_name: str = 'RN50') -> None:
         """
@@ -42,7 +44,7 @@ class ClipModel:
         preprocessed_images = [self._img_preprocess(image).unsqueeze(0) for image in images]
         tokenized_prompts = clip.tokenize(prompt)
         with torch.inference_mode():
-            image_features = self._model.encode_image(
+            image_features = torch.cat([self._model.encode_image(preprocessed_image) for preprocessed_image in preprocessed_images])
             text_features = self._model.encode_text(tokenized_prompts)
 
         # normalized features
@@ -51,8 +53,19 @@ class ClipModel:
 
         # cosine similarity as logits
         logit_scale = self._model.logit_scale.exp()
-        logits_per_image = logit_scale *
+        logits_per_image = logit_scale * text_features @ image_features.t()
 
         probs = list(logits_per_image.softmax(dim=-1).cpu().numpy()[0])
 
-        return probs
+        return probs
+
+
+if __name__ == "__main__":
+    from app import load_default_dataset
+
+    model = ClipModel()
+    images = load_default_dataset()
+    prompts = ['Hello', 'How are you', 'Goodbye']
+    prompts_scores = model.compute_prompts_probabilities(images[0], prompts)
+    images_scores = model.compute_images_probabilities(images, prompts[0])
+    print(f"Prompts scores: {prompts_scores}")
+    print(f"Images scores: {images_scores}")
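For reference, here is a standalone sketch of the scoring math that the fixed compute_images_probabilities performs, with random tensors standing in for real CLIP features (the shapes, the fixed logit-scale value, and the normalization step are assumptions based on the surrounding code, not part of the commit):

    import torch

    n_images, dim = 3, 512
    image_features = torch.randn(n_images, dim)  # one row per encoded image
    text_features = torch.randn(1, dim)          # features of the single prompt

    # normalize so the dot product becomes a cosine similarity
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    logit_scale = torch.tensor(100.0)  # stands in for self._model.logit_scale.exp()
    logits_per_image = logit_scale * text_features @ image_features.t()  # shape (1, n_images)

    # one probability per candidate image; the scores sum to 1
    probs = list(logits_per_image.softmax(dim=-1).cpu().numpy()[0])
    print(probs)

Sorting the images by these probabilities is what classification_output in app.py does with the scores it gets back.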