# Acknowledgments:
# This project is inspired by:
#   1. https://github.com/haltakov/natural-language-image-search by Vladimir Haltakov
#   2. DrishtiSharma/Text-to-Image-search-using-CLIP


import torch
import requests
import numpy as np
import pandas as pd
import gradio as gr
from io import BytesIO
from PIL import Image as PILIMAGE
from transformers import CLIPProcessor, CLIPModel

# Select the device based on GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
    
# Define the model and processor (the processor wraps CLIP's tokenizer and image preprocessor)
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    
# Load the product metadata and the precomputed CLIP image features
photos = pd.read_csv("./items_data.csv")
photo_features = np.load("./features.npy")
photo_ids = pd.read_csv("./photo_ids.csv")
photo_ids = list(photo_ids["photo_id"])
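
# The features loaded above are assumed to come from an offline indexing pass that
# runs every catalog image through CLIP's image encoder. The sketch below shows one
# way such a pass could look; build_index and its arguments are illustrative
# assumptions (not part of the original app), though the output file names match
# the files loaded above. It is defined but never called here.
def build_index(image_paths, ids):
    all_features = []
    for path in image_paths:
        image = PILIMAGE.open(path)
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            features = model.get_image_features(**inputs).cpu().numpy()
        # Normalize so that dot products against these vectors are cosine similarities
        all_features.append(features / np.linalg.norm(features, axis=1, keepdims=True))
    np.save("./features.npy", np.concatenate(all_features))
    pd.DataFrame({"photo_id": ids}).to_csv("./photo_ids.csv", index=False)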
    
def find_best_matches(text):
    # Encode the query text with CLIP; the processor handles tokenization,
    # so a separate tokenizer call is unnecessary
    inputs = processor(text=[text], images=None, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_encoded = model.get_text_features(**inputs).cpu().numpy()

    # Normalize the text embedding so that the dot product below is a cosine similarity
    text_encoded = text_encoded / np.linalg.norm(text_encoded, axis=1, keepdims=True)

    # Cosine similarity between the query and every precomputed image feature
    similarities = (text_encoded @ photo_features.T).squeeze(0)
    
    # Fetch and return the images for the top 3 matches (sort indices once, best first)
    matched_images = []
    best_idxs = np.argsort(similarities)[::-1][:3]
    for idx in best_idxs:
        photo_id = photo_ids[idx]
        photo_data = photos[photos["Uniq Id"] == photo_id].iloc[0]
        response = requests.get(photo_data["Image"] + "?w=640")
        img = PILIMAGE.open(BytesIO(response.content))
        matched_images.append(img)
    return matched_images
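
# Example usage (hypothetical query), independent of the Gradio UI below:
#   images = find_best_matches("red running shoes")  # -> list of three PIL images, best match first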
    
    
# Gradio app
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Search product",
                show_label=False,
                max_lines=1,
                placeholder="Type product",
                container=False,
            )
            btn = gr.Button("Search", scale=0)

        gallery = gr.Gallery(
            label="Products", show_label=False, elem_id="gallery",
            columns=3, height="auto",
        )

    btn.click(find_best_matches, inputs=text, outputs=gallery)

demo.launch(show_api=False)
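
# launch() blocks here by default; passing share=True (a standard Gradio option)
# would also expose the demo at a temporary public URL.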