Spaces:
Runtime error
Runtime error
File size: 7,756 Bytes
ee0cae7 0d3a066 ee0cae7 d3a50f6 dbf8f9d 971e64d ee0cae7 f7512c4 ee0cae7 0c337bd ee0cae7 f7512c4 ee0cae7 971e64d 2e7f326 971e64d ee0cae7 741aba6 ee0cae7 46228a6 ee0cae7 bee7306 31a5db6 bf143f5 dbf8f9d 0c337bd 0d38b2c e6254c7 0d38b2c 971e64d 0d38b2c 31a5db6 a801546 0d38b2c 0c337bd 0d38b2c bf143f5 6218a98 2e7f326 ee0cae7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import gradio as gr
from datasets import load_dataset
import random
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from os import environ
import clip
import pickle
import requests
import torch
# Prefer the GPU when CUDA is present; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# OpenAI CLIP ViT-B/32 model together with its image preprocessing transform.
orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)

# Unsplash photo dataset; all 25K images live in the "train" split.
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")

# Pixel height used when requesting/displaying images in the UI.
height = 384
# Hugging Face CLIP (model, processor) pair used only by predict(); created on
# first use so start-up does not pay for a second copy of CLIP.
_hf_clip = None


def predict(image, labels):
    """Zero-shot classify *image* against *labels* with Hugging Face CLIP.

    Each label is wrapped in the prompt "a photo of {label}" before encoding.

    Args:
        image: image handed to ``CLIPProcessor`` (presumably a PIL image), or
            None when nothing is displayed.
        labels: list of label strings.

    Returns:
        dict mapping each label to its softmax probability over *labels*;
        an empty dict when there is no image or no labels.
    """
    global _hf_clip
    if image is None or not labels:
        return {}
    if _hf_clip is None:
        # The module-level ``model``/``processor`` this function used to
        # reference are never defined (their loading is commented out), which
        # made every call fail with NameError; load them on demand instead.
        _hf_clip = (
            CLIPModel.from_pretrained("openai/clip-vit-base-patch32"),
            CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32"),
        )
    model, processor = _hf_clip
    inputs = processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(**inputs)
    # Image-text similarity scores -> probabilities over the labels.
    probs = outputs.logits_per_image.softmax(dim=1)
    return {k: float(v) for k, v in zip(labels, probs[0])}
def predict2(image, labels):
    """Zero-shot classify *image* against *labels* with OpenAI CLIP.

    Labels are tokenized verbatim (no prompt template).

    Args:
        image: image accepted by the CLIP preprocessing transform returned by
            ``clip.load`` (presumably a PIL image), or None when empty.
        labels: list of label strings.

    Returns:
        dict mapping each label to its softmax probability over *labels*;
        an empty dict when there is no image or no labels.
    """
    # The label state starts out as an empty list and the image component can
    # be empty, so there may be nothing to classify yet.
    if image is None or not labels:
        return {}
    image_input = orig_clip_processor(image).unsqueeze(0).to(device)
    # truncate=True: over-long labels are clipped to CLIP's context length
    # instead of raising.
    text_input = clip.tokenize(labels, truncate=True).to(device)
    with torch.no_grad():
        # The forward pass encodes image and text itself; separate
        # encode_image/encode_text calls would only repeat that work.
        logits_per_image, _ = orig_clip_model(image_input, text_input)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    return {k: float(v) for k, v in zip(labels, probs[0])}
def rand_image():
    """Return the URL of a uniformly chosen dataset photo at ``height`` pixels.

    Unsplash renders resized images on the fly from the ``h`` query parameter,
    so no local resizing is needed.
    """
    row_index = random.randrange(dataset.num_rows)
    photo_url = dataset[row_index]["photo_image_url"]
    return photo_url + f"?h={height}"
def set_labels(text):
    """Parse a comma-separated string into a list of classification labels.

    Whitespace around each label is stripped and empty entries (from an empty
    textbox, doubled commas or a trailing comma) are dropped, so
    ``"spring, summer,"`` becomes ``["spring", "summer"]``.

    >>> set_labels("spring, summer,  fall,")
    ['spring', 'summer', 'fall']
    """
    if not text:
        return []
    return [label.strip() for label in text.split(",") if label.strip()]
# Remote captioning Space loaded as a callable. The token is read with .get()
# so a missing "api_key" secret no longer aborts start-up with a bare KeyError;
# presumably a token is only required if that Space is private — confirm.
get_caption = gr.load("ryaalbr/caption", src="spaces", hf_token=environ.get("api_key"))


def generate_text(image, model_name):
    """Caption *image* with the remote captioning Space.

    Args:
        image: value of the caption tab's image component (configured with
            ``type='filepath'``), or None when no image is shown.
        model_name: captioning model choice, e.g. "COCO" or
            "Conceptual captions".

    Returns:
        the caption returned by the remote Space, or "" when there is no image.
    """
    if image is None:
        return ""
    return get_caption(image, model_name)
# get_images = gr.load("ryaalbr/ImageSearch", src="spaces", hf_token=environ["api_key"])
# def search_images(text):
# return get_images(text, api_name="images")
# Precomputed photo-embedding index, unpacked into: a photo-id -> URL mapping,
# the per-embedding image file names, and the embedding matrix searched by
# search().
# NOTE: pickle.load can execute arbitrary code; only load this trusted,
# locally shipped file.
emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
with open(emb_filename, mode='rb') as emb:
    embedding_index = pickle.load(emb)
id2url, img_names, img_emb = embedding_index
def search(search_query):
    """Return up to 5 Unsplash photo URLs that best match a text description.

    The query is encoded with CLIP, L2-normalised and scored against the
    precomputed image embeddings ``img_emb`` by dot product (cosine similarity
    provided the stored embeddings are normalised too — assumed, not checked
    here). The 15 highest-scoring photos are walked best-first, skipping any
    whose id has no URL in ``id2url``, until 5 URLs are collected.

    Args:
        search_query: free-text description of the desired image.

    Returns:
        list of image URLs (each requesting a 512px-high rendition), best match
        first; fewer than 5 when candidates lack URLs.
    """
    with torch.no_grad():
        # tokenize() builds a CPU tensor; move it to the model's device so the
        # search also works on CUDA. truncate=True clips queries longer than
        # CLIP's context window instead of raising.
        tokens = clip.tokenize(search_query, truncate=True).to(device)
        text_encoded = orig_clip_model.encode_text(tokens)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    text_features = text_encoded.cpu().numpy()

    similarities = (text_features @ img_emb.T).squeeze(0)
    # Over-fetch 15 candidates so photos without a known URL can be skipped.
    best_photos = similarities.argsort()[::-1][:15]
    best_photo_ids = img_names[best_photos]

    imgs = []
    for file_name in best_photo_ids:
        # Strip only the final extension; a two-way split would raise on names
        # containing zero or several dots.
        photo_id = file_name.rsplit('.', 1)[0]
        url = id2url.get(photo_id, "")
        if not url:
            continue
        imgs.append(url + "?h=512")
        if len(imgs) == 5:
            break
    return imgs
# Three-tab Gradio UI: zero-shot classification, image captioning and
# text-to-image search. The css rule targets the caption Textbox (elem_id
# "caption-text") so captions render in a large font.
# NOTE(review): `.style(...)` was deprecated in Gradio 3.x and removed in
# Gradio 4.0 (its options moved to constructor arguments); on Gradio >= 4 these
# calls raise AttributeError at start-up — confirm the pinned gradio version.
# NOTE(review): nesting reconstructed from a flattened copy; placement of
# gr.Examples, cf, caption and gallery relative to their rows is a best guess.
with gr.Blocks(css="#caption-text {font-size: 40px !important;}") as demo:
    with gr.Tab("Zero-Shot Classification"):
        # Hidden state holding the parsed label list; it is an input to
        # predict2 and is updated from label_text via set_labels. Starts empty.
        labels = gr.State([]) # creates hidden component that can store a value and can be used as input/output; here, initial value is an empty list
        instructions = """## Instructions:
1. Enter list of labels separated by commas (or select one of the examples below)
2. Click **Get Random Image** to grab a random image from dataset and analyze it against the labels
3. Click **Re-Classify Image** to re-run classification on current image after changing labels"""
        gr.Markdown(instructions)
        with gr.Row(variant="compact"):
            label_text = gr.Textbox(show_label=False, placeholder="Enter classification labels").style(container=False)
            #submit_btn = gr.Button("Submit").style(full_width=False)
        # Example label lists; selecting one fills label_text.
        gr.Examples(["spring, summer, fall, winter",
                     "mountain, city, beach, ocean, desert, forest, valley",
                     "red, blue, green, white, black, purple, brown",
                     "person, animal, landscape, something else",
                     "day, night, dawn, dusk"], inputs=label_text)
        with gr.Row():
            with gr.Column(variant="panel"):
                im = gr.Image(interactive=False, type="pil").style(height=height)
                with gr.Row():
                    get_btn = gr.Button("Get Random Image").style(full_width=False)
                    reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
            # Shows the label -> probability dict returned by predict2.
            cf = gr.Label()
        #submit_btn.click(fn=set_labels, inputs=label_text)
        # Re-parse the label list on every edit, on blur and on Enter.
        label_text.change(fn=set_labels, inputs=label_text, outputs=labels) # parse list if changed
        label_text.blur(fn=set_labels, inputs=label_text, outputs=labels) # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
        label_text.submit(fn=set_labels, inputs=label_text, outputs=labels) # parse list if user hits enter; ensures that list is fully parsed before classification
        # Load a random dataset image; the change handler then classifies it.
        get_btn.click(fn=rand_image, outputs=im)
        im.change(predict2, inputs=[im, labels], outputs=cf)
        # Re-run classification on the current image with the current labels.
        reclass_btn.click(predict2, inputs=[im, labels], outputs=cf)
    with gr.Tab("Image Captioning"):
        with gr.Row():
            with gr.Column(variant="panel"):
                im_cap = gr.Image(interactive=False, type='filepath').style(height=height)
                model_name = gr.Radio(choices=["COCO","Conceptual captions"], type="value", value="COCO", label="Model").style(container=True, item_container = False)
                with gr.Row():
                    get_btn_cap = gr.Button("Get Random Image").style(full_width=False)
                    caption_btn = gr.Button("Create Caption").style(full_width=False)
            caption = gr.Textbox(label='Caption', elem_id="caption-text")
        get_btn_cap.click(fn=rand_image, outputs=im_cap)
        # Auto-captioning on image change is disabled; captions are produced
        # only when the "Create Caption" button is clicked.
        #im_cap.change(generate_text, inputs=im_cap, outputs=caption)
        caption_btn.click(generate_text, inputs=[im_cap, model_name], outputs=caption)
    with gr.Tab("Image Search"):
        with gr.Column(variant="panel"):
            desc = gr.Textbox(show_label=False, placeholder="Enter description").style(container=False)
            search_btn = gr.Button("Find Images").style(full_width=False)
        gallery = gr.Gallery(show_label=False).style(grid=(2,2,3,5))
        # postprocess=False: presumably so the list of URLs returned by search
        # reaches the Gallery unmodified — confirm against Gradio docs.
        search_btn.click(search,inputs=desc, outputs=gallery, postprocess=False)
# Start the Gradio server for the UI defined above. (The stray trailing "|"
# on this line made it a SyntaxError.)
demo.launch()