# Import libraries
from pathlib import Path
import pickle

import torch
from PIL import Image
import gradio as gr
from sentence_transformers import SentenceTransformer, util

# Imports used only by the earlier transformers-based CLIP pipeline below,
# kept commented out for reference:
# import os
# import numpy as np
# import pandas as pd
# import requests
# from io import BytesIO
# from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer

# Check whether CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
# Earlier pipeline: load OpenAI's CLIP model and the pre-computed photo
# features directly via transformers (kept for reference):
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
# photo_ids = list(pd.read_csv("./photo_ids.csv")["photo_id"])
# photos = pd.read_csv("./photos.tsv000", sep="\t", header=0)
# photo_features = np.load("./features.npy")
IMAGES_DIR = Path("./photos/")
# Earlier helper: fetch matched photos from Unsplash by ID (kept for reference):
# def show_output_image(matched_images):
#     images = []
#     for photo_id in matched_images:
#         photo_image_url = f"https://unsplash.com/photos/{photo_id}/download?w=280"
#         response = requests.get(photo_image_url, stream=True)
#         img = Image.open(BytesIO(response.content)).convert("RGB")
#         # Alternatively, load from the local folder instead of downloading:
#         # img = Image.open(os.path.join(IMAGES_DIR, photo_id + ".jpg"))
#         images.append(img)
#     return images
# Earlier helper: encode the search query using CLIP's text tower:
# def encode_search_query(search_query, model, device):
#     with torch.no_grad():
#         inputs = tokenizer([search_query], padding=True, return_tensors="pt")
#         text_features = model.get_text_features(**inputs).cpu().numpy()
#     return text_features
# Earlier helper: find the best-matching photos for a query feature vector:
# def find_matches(features, photo_features, photo_ids, results_count=4):
#     # Cosine similarity between the query and each photo (features are pre-normalized)
#     similarities = (photo_features @ features.T).squeeze(1)
#     # Sort the photos by their similarity score, best first
#     best_photo_idx = (-similarities).argsort()
#     # Return the photo IDs of the best matches
#     return [photo_ids[i] for i in best_photo_idx[:results_count]]
# Load the CLIP model through sentence-transformers
model = SentenceTransformer('clip-ViT-B-32', device=device)
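# CLIP maps text and images into the same embedding space, which is what makes
# both search modes below possible. A minimal sanity-check sketch (the file
# name is hypothetical):
# text_vec = model.encode(["two dogs playing in the snow"], convert_to_tensor=True)
# image_vec = model.encode([Image.open("photos/example.jpg")], convert_to_tensor=True)
# print(util.cos_sim(text_vec, image_vec))  # one similarity score in roughly [-1, 1]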
# Pre-computed embeddings for the Unsplash 25k photo set
emb_filename = 'unsplash-25k-photos-embeddings.pkl'
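# The pickle is assumed to sit next to this script. If it is missing, it can
# be fetched first; the sbert.net URL below is the one used in the
# sentence-transformers image-search example (an assumption about hosting):
# from sentence_transformers.util import http_get
# if not Path(emb_filename).exists():
#     http_get("http://sbert.net/datasets/unsplash-25k-photos-embeddings.pkl", emb_filename)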
with open(emb_filename, 'rb') as fIn:
    img_names, img_emb = pickle.load(fIn)
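# Quick shape check (sketch): img_names should be a list of ~25k file names and
# img_emb a matrix with one 512-dimensional CLIP vector per image.
# print(len(img_names), getattr(img_emb, "shape", None))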
def display_matches(indices):
    # util.semantic_search returns dicts whose 'corpus_id' indexes into img_names
    return [Image.open(IMAGES_DIR / img_names[hit['corpus_id']]) for hit in indices]
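# Usage sketch for display_matches with util.semantic_search, an alternative
# to the manual top-k ranking used in image_search below:
# hits = util.semantic_search(query_emb, img_emb, top_k=4)[0]  # query_emb from model.encode
# best_images = display_matches(hits)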
def image_search(search_text, search_image, option):
    if option == "Text-To-Image":
        # Embed the text query and rank every image by cosine similarity
        text_emb = model.encode([search_text], convert_to_tensor=True)
        similarity = util.cos_sim(img_emb, text_emb)
        # Return the top 4 highest-ranked images
        top_k = torch.topk(similarity.squeeze(1), k=4).indices
        return [Image.open(IMAGES_DIR / img_names[idx.item()]) for idx in top_k]
    elif option == "Image-To-Image":
        # Embed the query image with the same CLIP model; the Gradio input is
        # already a PIL image, so it can be passed to encode() directly
        image_emb = model.encode([search_image], convert_to_tensor=True)
        similarity = util.cos_sim(img_emb, image_emb)
        # Return the top 4 highest-ranked images
        top_k = torch.topk(similarity.squeeze(1), k=4).indices
        return [Image.open(IMAGES_DIR / img_names[idx.item()]) for idx in top_k]
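# Smoke-test sketch, assuming the photos/ folder and the embeddings pickle are
# in place (the query string is just an example):
# if __name__ == "__main__":
#     matches = image_search("a cat sleeping on a sofa", None, "Text-To-Image")
#     print(f"{len(matches)} matches returned")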
gr.Interface(fn=image_search,
             inputs=[gr.inputs.Textbox(lines=7, label="Input Text"),
                     gr.inputs.Image(type="pil", optional=True),
                     gr.inputs.Dropdown(["Text-To-Image", "Image-To-Image"])],
             # Each carousel item is a single image; image_search returns a list of four
             outputs=gr.outputs.Carousel(gr.outputs.Image(type="pil")),
             enable_queue=True
             ).launch(debug=True, share=True)