Spaces:
Runtime error
Runtime error
File size: 8,013 Bytes
ee0cae7 0d3a066 ee0cae7 d3a50f6 dbf8f9d 971e64d 4150f63 971e64d 54beb65 ee0cae7 f7512c4 ee0cae7 7b2aaba ee0cae7 f7512c4 ee0cae7 971e64d 54beb65 971e64d ee0cae7 741aba6 ee0cae7 46228a6 ee0cae7 bee7306 31a5db6 bf143f5 dbf8f9d 4150f63 54beb65 4150f63 dbf8f9d 7b2aaba 0d38b2c 54beb65 0d38b2c 4150f63 0d38b2c a3cfe93 a801546 0d38b2c 7b2aaba 0d38b2c bf143f5 6218a98 2e7f326 ee0cae7 |
|
import gradio as gr
from datasets import load_dataset
import random
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from os import environ
import clip
import pickle
import requests
import torch
# Preferred torch device string. NOTE(review): nothing in the active code
# below moves `model` (or its inputs) to this device, so as far as this file
# shows, inference runs on the CPU even when CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pre-trained CLIP ViT-B/32 checkpoint; the processor handles both text
# tokenization and image preprocessing for the model.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
#orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)

# Unsplash photo metadata; all 25K rows live in the single "train" split.
dataset = load_dataset(path="jamescalam/unsplash-25k-photos", split="train")

# Height (pixels) used when requesting and displaying images.
height = 256
def predict(image, labels):
    """Zero-shot classify an image against a list of text labels with CLIP.

    Each label is wrapped in the prompt "a photo of {label}"; the model's
    image-text similarity logits are softmaxed across the labels.

    Args:
        image: Image value delivered by the bound `gr.Image` component, or
            None when no image is present.
        labels: Sequence of label strings (typically from `set_labels`).

    Returns:
        dict mapping each label to its probability as a float. Returns an
        empty dict when there is no image or no label to score, since the
        model needs both inputs.
    """
    # The label state starts as [] and the image may be unset; there is
    # nothing meaningful to score, so skip the model call entirely.
    if image is None or not labels:
        return {}
    prompts = [f"a photo of {c}" for c in labels]
    inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
    # Pure inference: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image has one row per image; softmax over labels (dim=1).
    probs = outputs.logits_per_image.softmax(dim=1)
    return {label: float(p) for label, p in zip(labels, probs[0])}
# def predict2(image, labels):
# image = orig_clip_processor(image).unsqueeze(0).to(device)
# text = clip.tokenize(labels).to(device)
# with torch.no_grad():
# image_features = orig_clip_model.encode_image(image)
# text_features = orig_clip_model.encode_text(text)
# logits_per_image, logits_per_text = orig_clip_model(image, text)
# probs = logits_per_image.softmax(dim=-1).cpu().numpy()
# return {k: float(v) for k, v in zip(labels, probs[0])}
def rand_image():
    """Pick a random photo from `dataset` and return its URL sized to `height`.

    Unsplash serves resized images on request, so the target height is
    passed as the `h` query parameter rather than resizing locally.
    """
    index = random.randrange(dataset.num_rows)
    photo_url = dataset[index]["photo_image_url"]
    return photo_url + f"?h={height}"
def set_labels(text):
    """Parse a comma-separated label string into a list of clean labels.

    Whitespace around each label is stripped and empty entries (from an
    empty box, doubled commas, or a trailing comma) are dropped, so
    "spring, summer," yields ["spring", "summer"].

    Args:
        text: Raw contents of the labels textbox; None is treated as empty.

    Returns:
        list of non-empty, stripped label strings (possibly empty).
    """
    if not text:
        return []
    parts = (part.strip() for part in text.split(","))
    return [part for part in parts if part]
# Remote captioning Space loaded as a callable (invoked by generate_text).
# The access token comes from the `api_key` environment variable (e.g. a Space
# secret). environ.get yields None when it is unset, so an anonymous load is
# attempted instead of the whole app failing at import with a bare KeyError.
get_caption = gr.load("ryaalbr/caption", src="spaces", hf_token=environ.get("api_key"))
def generate_text(image, model_name):
    """Caption an image by forwarding it to the remote `get_caption` Space.

    Args:
        image: Image value from the captioning tab (a file path, given the
            `type='filepath'` image component it is wired to).
        model_name: Captioning model choice, e.g. "COCO" or
            "Conceptual captions".

    Returns:
        Whatever the remote Space returns; it is displayed in a Textbox,
        so presumably the caption string.
    """
    caption_text = get_caption(image, model_name)
    return caption_text
# get_images = gr.load("ryaalbr/ImageSearch", src="spaces", hf_token=environ["api_key"])
# def search_images(text):
# return get_images(text, api_name="images")
# Precomputed image-search index bundled with the app:
#   id2url     - mapping from photo id to its Unsplash URL
#   img_names  - image file names, indexed in step with the rows of img_emb
#   img_emb    - image embedding matrix (one row per image)
# NOTE: pickle.load can execute arbitrary code; acceptable only because this
# file ships with the app and is trusted. Never point it at user uploads.
emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
with open(emb_filename, 'rb') as emb_file:
    id2url, img_names, img_emb = pickle.load(emb_file)
def search(search_query):
    """Return URLs of the Unsplash photos that best match a text description.

    The query is embedded with the HF CLIP text encoder and scored against
    the precomputed image embeddings (`img_emb`) by dot product. The 15
    best-scoring candidates are scanned in order, and the first 5 whose id
    has a URL in `id2url` are returned with `?h=512` appended (Unsplash
    resizes on request).

    Args:
        search_query: Free-text description entered by the user.

    Returns:
        list of up to 5 image URL strings, best match first; an empty list
        for a blank query.
    """
    # A blank description carries nothing to rank by.
    if not search_query or not search_query.strip():
        return []
    with torch.no_grad():
        # Encode the description using CLIP (HF CLIP).
        inputs = processor(text=search_query, images=None, return_tensors="pt", padding=True)
        text_encoded = model.get_text_features(**inputs)
        # L2-normalize so the scores are cosine similarities, assuming the
        # rows of img_emb are unit-normalized too (TODO confirm). For a
        # single query this does not change the ranking.
        text_encoded = text_encoded / text_encoded.norm(dim=-1, keepdim=True)
    text_features = text_encoded.cpu().numpy()
    # Similarity between the description and every photo: shape (num_photos,).
    similarities = (text_features @ img_emb.T).squeeze(0)
    # Indices of the 15 highest-scoring photos, best first. More than 5 are
    # taken because some ids may have no URL in id2url and get skipped.
    best_photos = similarities.argsort()[::-1][:15]
    best_photo_ids = img_names[best_photos]
    imgs = []
    for file_name in best_photo_ids:
        # Drop the file extension to recover the photo id ("abc.jpg" -> "abc");
        # rsplit tolerates names without (or with extra) dots.
        photo_id = file_name.rsplit('.', 1)[0]
        url = id2url.get(photo_id, "")
        if not url:
            continue
        imgs.append(url + "?h=512")
        if len(imgs) == 5:
            break
    return imgs
# Three-tab Gradio UI. Inside a gr.Blocks context, components are laid out in
# the order they are created, so the nesting/order of statements below
# defines the on-screen layout. The css rule styles the caption Textbox,
# which opts in via elem_classes="caption-text".
with gr.Blocks(css=".caption-text {font-size: 40px !important;}") as demo:
    # --- Tab 1: zero-shot classification of a random dataset image ---
    with gr.Tab("Zero-Shot Classification"):
        labels = gr.State([]) # creates hidden component that can store a value and can be used as input/output; here, initial value is an empty list
        instructions = """## Instructions:
1. Enter list of labels separated by commas (or select one of the examples below)
2. Click **Get Random Image** to grab a random image from dataset and analyze it against the labels
3. Click **Re-Classify Image** to re-run classification on current image after changing labels"""
        gr.Markdown(instructions)
        with gr.Row(variant="compact"):
            # Raw comma-separated labels; parsed into `labels` by set_labels.
            label_text = gr.Textbox(show_label=False, placeholder="Enter classification labels").style(container=False)
            #submit_btn = gr.Button("Submit").style(full_width=False)
        # Clicking an example fills label_text, which fires its change event.
        gr.Examples(["spring, summer, fall, winter",
                     "mountain, city, beach, ocean, desert, forest, valley",
                     "red, blue, green, white, black, purple, brown",
                     "person, animal, landscape, something else",
                     "day, night, dawn, dusk"], inputs=label_text)
        with gr.Row():
            with gr.Column(variant="panel"):
                im = gr.Image(interactive=False).style(height=height)
                with gr.Row():
                    get_btn = gr.Button("Get Random Image").style(full_width=False)
                    reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
            # Shows the label -> probability dict returned by predict.
            cf = gr.Label()
        #submit_btn.click(fn=set_labels, inputs=label_text)
        label_text.change(fn=set_labels, inputs=label_text, outputs=labels) # parse list if changed
        label_text.blur(fn=set_labels, inputs=label_text, outputs=labels) # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
        label_text.submit(fn=set_labels, inputs=label_text, outputs=labels) # parse list if user hits enter; ensures that list is fully parsed before classification
        # rand_image returns a URL, which is loaded into `im`.
        get_btn.click(fn=rand_image, outputs=im)
        # Any new image in `im` (e.g. from Get Random Image) is classified automatically.
        im.change(predict, inputs=[im, labels], outputs=cf)
        # Re-run classification on the current image with the current labels.
        reclass_btn.click(predict, inputs=[im, labels], outputs=cf)
    # --- Tab 2: caption a random dataset image via the remote caption Space ---
    with gr.Tab("Image Captioning"):
        with gr.Row():
            with gr.Column(variant="panel"):
                # type='filepath': generate_text receives a path, not an array.
                im_cap = gr.Image(interactive=False, type='filepath').style(height=height)
                # Selected captioning model, forwarded to generate_text as a string.
                model_name = gr.Radio(choices=["COCO","Conceptual captions"], type="value", value="COCO", label="Model").style(container=True, item_container = False)
                with gr.Row():
                    get_btn_cap = gr.Button("Get Random Image").style(full_width=False)
                    caption_btn = gr.Button("Create Caption").style(full_width=False)
            caption = gr.Textbox(label='Caption', elem_classes="caption-text")
        get_btn_cap.click(fn=rand_image, outputs=im_cap)
        # Captioning is only triggered by the button (auto-caption on change is disabled).
        #im_cap.change(generate_text, inputs=im_cap, outputs=caption)
        caption_btn.click(generate_text, inputs=[im_cap, model_name], outputs=caption)
    # --- Tab 3: text-to-image search over the precomputed embedding index ---
    with gr.Tab("Image Search"):
        with gr.Column(variant="panel"):
            desc = gr.Textbox(show_label=False, placeholder="Enter description").style(container=False)
            search_btn = gr.Button("Find Images").style(full_width=False)
        gallery = gr.Gallery(show_label=False).style(grid=(2,2,3,5))
        # postprocess=False hands search()'s list of URLs to the gallery
        # without Gradio's postprocessing (presumably so the remote URLs are
        # passed through as-is rather than downloaded).
        search_btn.click(search,inputs=desc, outputs=gallery, postprocess=False)
# Start the Gradio server for the UI defined above. (The stray trailing " |"
# that made this line a SyntaxError has been removed.)
demo.launch()