EvgenyK commited on
Commit
5f011f5
1 Parent(s): 88248fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -1
app.py CHANGED
@@ -1 +1,54 @@
1
- a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import requests
3
+ import numpy as np
4
+ import pandas as pd
5
+ import gradio as gr
6
+ from io import BytesIO
7
+ from PIL import Image as PILIMAGE
8
+ from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
9
+
10
+ #Selecting device based on availability of GPUs
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+
13
+ #Defining model, processor and tokenizer
14
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
15
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
16
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
17
+
18
+
19
+ #Loading the data
20
+ photos = pd.read_csv("./items_data.csv")
21
+ photo_features = np.load("./features.npy")
22
+ photo_ids = pd.read_csv("./photo_ids.csv")
23
+ photo_ids = list(photo_ids['Uniq Id'])
24
+
25
+ def find_best_matches(text):
26
+
27
+ #Inference
28
+ with torch.no_grad():
29
+ # Encode and normalize the description using CLIP
30
+ inputs = tokenizer([text], padding=True, return_tensors="pt")
31
+ inputs = processor(text=[text], images=None, return_tensors="pt", padding=True)
32
+ text_encoded = model.get_text_features(**inputs).detach().numpy()
33
+
34
+
35
+ # Finding Cosine similarity
36
+ similarities = list((text_encoded @ photo_features.T).squeeze(0))
37
+
38
+ #Block of code for displaying top 3 best matches (images)
39
+ matched_images = []
40
+ for i in range(3):
41
+ idx = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)[i][1]
42
+ photo_id = photo_ids[idx]
43
+ photo_data = photos[photos["photo_id"] == photo_id].iloc[0]
44
+ response = requests.get(photo_data["photo_image_url"] + "?w=640")
45
+ img = PILIMAGE.open(BytesIO(response.content))
46
+ matched_images.append(img)
47
+ return matched_images
48
+
49
+
50
+ #Gradio app
51
+ iface = gr.Interface(fn=find_best_matches, inputs=[gr.inputs.Textbox(lines=1, label="Text query", placeholder="Introduce the search text...",)],
52
+ theme = "dark",
53
+ outputs=gr.outputs.Carousel([gr.outputs.Image(type="pil"), gr.outputs.Image(type="pil"), gr.outputs.Image(type="pil")]),
54
+ enable_queue=True).launch()