# Import libraries
from pathlib import Path
import pandas as pd
import numpy as np
import torch
from PIL import Image
from io import BytesIO
import requests
import gradio as gr
import os
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
# Use CUDA if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load OpenAI's CLIP model, processor, and tokenizer, and move the model to the device
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
model.eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
# Load the Unsplash photo IDs
photo_ids = pd.read_csv("./photo_ids.csv")
photo_ids = list(photo_ids["photo_id"])

# Load the photo metadata
photos = pd.read_csv("./photos.tsv000", sep="\t", header=0)

# Load the precomputed CLIP feature vectors, assumed to be L2-normalized
# so that dot products with a normalized query give cosine similarities
photo_features = np.load("./features.npy")

IMAGES_DIR = "./photos"
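# The feature matrix is assumed to be precomputed offline with the same CLIP
# model, with rows aligned to photo_ids.csv. A minimal sketch (batching and
# error handling omitted):
#
#   imgs = [Image.open(p) for p in sorted(Path(IMAGES_DIR).glob("*.jpg"))]
#   with torch.no_grad():
#       pixels = processor(images=imgs, return_tensors="pt")["pixel_values"]
#       feats = model.get_image_features(pixels.to(device))
#       feats /= feats.norm(dim=-1, keepdim=True)
#   np.save("./features.npy", feats.cpu().numpy())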
# Load the matched images from the local photo directory
def show_output_image(matched_images):
    images = []
    for photo_id in matched_images:
        # Alternatively, the image can be fetched from Unsplash:
        #   url = f"https://unsplash.com/photos/{photo_id}/download?w=280"
        #   img = Image.open(BytesIO(requests.get(url, stream=True).content)).convert("RGB")
        img = Image.open(os.path.join(IMAGES_DIR, photo_id + ".jpg"))
        images.append(img)
    return images
# Encode and normalize the search query using CLIP
def encode_search_query(search_query, model, device):
    with torch.no_grad():
        inputs = tokenizer([search_query], padding=True, return_tensors="pt").to(device)
        text_features = model.get_text_features(**inputs)
        # Normalize so the dot product with the photo features is a cosine similarity
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features.cpu().numpy()
# Find the photos that best match the query features
def find_matches(text_features, photo_features, photo_ids, results_count=4):
    # Compute the cosine similarity between the query and each photo
    # (a plain dot product, since both sides are normalized)
    text_features = np.array(text_features)
    similarities = (photo_features @ text_features.T).squeeze(1)
    # Sort the photos by similarity score, best first
    best_photo_idx = (-similarities).argsort()
    # Return the photo IDs of the best matches
    return [photo_ids[i] for i in best_photo_idx[:results_count]]
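# For an (N, 512) photo_features matrix and a (1, 512) query, the product above
# yields an (N,) score vector; with unit-norm rows every score lies in [-1, 1].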
def image_search(search_text, search_image, option):
    if option == "Text-To-Image":
        # Encode the text query, then find and show the best matches
        text_features = encode_search_query(search_text, model, device)
        matched_images = find_matches(text_features, photo_features, photo_ids, 4)
        return show_output_image(matched_images)
    elif option == "Image-To-Image":
        # The Image input is declared with type="pil", so search_image is
        # already a PIL image and needs no conversion
        with torch.no_grad():
            processed_image = processor(images=search_image, return_tensors="pt")["pixel_values"]
            image_feature = model.get_image_features(processed_image.to(device))
            # Normalize so the dot product is a cosine similarity
            image_feature /= image_feature.norm(dim=-1, keepdim=True)
            image_feature = image_feature.cpu().numpy()
        matched_images = find_matches(image_feature, photo_features, photo_ids, 4)
        return show_output_image(matched_images)
gr.Interface(
    fn=image_search,
    inputs=[
        gr.inputs.Textbox(lines=7, label="Input Text"),
        gr.inputs.Image(type="pil", optional=True),
        gr.inputs.Dropdown(["Text-To-Image", "Image-To-Image"]),
    ],
    # A single Image component, so each of the four returned images
    # becomes one carousel item
    outputs=gr.outputs.Carousel(gr.outputs.Image(type="pil")),
    enable_queue=True,
).launch(debug=True, share=True)
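# Quick local sanity check of the search pipeline (a sketch, assuming the data
# files above are present; run before launching the interface, since launch blocks):
#
#   feats = encode_search_query("two dogs playing on the beach", model, device)
#   print(find_matches(feats, photo_features, photo_ids, results_count=4))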