# Acknowledgments:
# This project is inspired by:
# 1. https://github.com/haltakov/natural-language-image-search by Vladimir Haltakov
# 2. DrishtiSharma/Text-to-Image-search-using-CLIP
import logging
from io import BytesIO

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image as PILIMAGE
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer

logger = logging.getLogger(__name__)

# Upper bound (seconds) on each image download so one slow host cannot hang
# a search request indefinitely.
IMAGE_REQUEST_TIMEOUT = 10

# Selecting device based on availability of GPUs
device = "cuda" if torch.cuda.is_available() else "cpu"

# Defining model, processor and tokenizer
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Kept as a module-level name for backward compatibility; text encoding below
# goes through `processor`, which is what the original code effectively used.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Loading the data
photos = pd.read_csv("./items_data.csv")
photo_features = np.load("./features.npy")
photo_ids = pd.read_csv("./photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])


def _encode_text(text):
    """Return the L2-normalized CLIP embedding of *text* as a (1, dim) NumPy array."""
    with torch.no_grad():
        # truncation=True keeps over-long queries within the model's context
        # window instead of failing inside the text encoder.
        inputs = processor(
            text=[text],
            images=None,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # The model lives on `device`; its inputs must be on the same device.
        inputs = inputs.to(device)
        text_features = model.get_text_features(**inputs)
        # Unit-normalize so the dot product below is a cosine similarity
        # (assumes the rows of features.npy are unit-normalized too -- TODO
        # confirm). Scaling a single query vector does not change the ranking.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # .cpu() is required before .numpy() when the model runs on a GPU.
    return text_features.cpu().numpy()


def _fetch_image(url):
    """Download *url* and return it as a PIL image, or None if that fails."""
    try:
        response = requests.get(url, timeout=IMAGE_REQUEST_TIMEOUT)
        response.raise_for_status()
        return PILIMAGE.open(BytesIO(response.content))
    except (requests.RequestException, OSError) as err:
        # OSError covers PIL's "cannot identify image file" error.
        logger.warning("Could not load image %s: %s", url, err)
        return None


def find_best_matches(text, top_k=3):
    """Return images of the products that best match a text query.

    Args:
        text: Free-text search query.
        top_k: Maximum number of images to return (default 3, matching the
            three-column gallery below).

    Returns:
        A list of at most ``top_k`` PIL images ordered from best to worst
        match. The list is shorter when the query is empty, when fewer
        photos are indexed, or when an image cannot be looked up or
        downloaded.
    """
    if not text or not text.strip():
        return []

    text_encoded = _encode_text(text)

    # Similarity of the query against every indexed photo.
    similarities = (text_encoded @ photo_features.T).squeeze(0)

    # Rank once, best first. A stable sort keeps ties in index order, which
    # is the order the original sorted(..., reverse=True) produced.
    best_indices = np.argsort(-similarities, kind="stable")[:max(top_k, 0)]

    matched_images = []
    for idx in best_indices:
        photo_id = photo_ids[idx]
        rows = photos[photos["Uniq Id"] == photo_id]
        if rows.empty:
            logger.warning("No catalogue entry for photo id %r", photo_id)
            continue
        img = _fetch_image(rows.iloc[0]["Image"] + "?w=640")
        if img is not None:
            matched_images.append(img)
    return matched_images


# Gradio app
# NOTE(review): the `.style(...)` calls below are the Gradio 3.x styling API;
# confirm the pinned Gradio version before upgrading.
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Search product",
                show_label=False,
                max_lines=1,
                placeholder="Type product",
            ).style(
                container=False,
            )
            btn = gr.Button("Search").style(full_width=False)

        gallery = gr.Gallery(
            label="Products", show_label=False, elem_id="gallery"
        ).style(grid=[3], height="auto")

    btn.click(find_best_matches, inputs=text, outputs=gallery)

demo.launch(show_api=False)