pleonard committed on
Commit
abff508
·
verified ·
1 Parent(s): b81b56e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import gradio as gr
4
+ import torch
5
+ import clip
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
# Prefer Apple's Metal (MPS) backend when available, otherwise fall back to
# CPU.  NOTE(review): CUDA is never considered — presumably this was written
# for an Apple-silicon host; confirm before deploying elsewhere.
device = "mps" if torch.backends.mps.is_available() else "cpu"
# Load the CLIP ViT-B/32 model and its matching image preprocessing
# transform onto the chosen device.
model, preprocess = clip.load("ViT-B/32", device=device)
print('Using ' + device)

# Directory holding the precomputed photo feature matrix and its metadata.
features_path = 'features/'

# Precomputed CLIP embeddings, one row per photo.  Assumed to be
# L2-normalized so a dot product against a normalized query is a cosine
# similarity — TODO confirm against the script that produced features.npy.
photo_features = np.load(features_path + "features.npy")
# CSV mapping each row of the feature matrix to a photo file name and a
# human-readable description; presumably in the same row order as
# `photo_features` — verify, since clip_search indexes both by position.
photo_ids = pd.read_csv(features_path+ "updated_file.csv")
descriptions = list(photo_ids['description'])
photo_filenames = list(photo_ids['photo_id'])
21
+
22
+
23
+
24
def clip_search(search_string, num_results=30):
    """Return the photos most similar to a free-text query.

    Encodes ``search_string`` with CLIP, scores it against the precomputed
    ``photo_features`` matrix by cosine similarity, and returns up to
    ``num_results`` entries formatted for a Gradio gallery.

    Parameters:
        search_string: the text query to encode.
        num_results: maximum number of matches to return (default 30,
            matching the previous hard-coded behavior).

    Returns:
        A list of ``[image_path, caption]`` pairs, best match first.
    """
    with torch.no_grad():
        # Encode and L2-normalize the query so the dot product below is a
        # cosine similarity (the photo features are assumed normalized too).
        text_encoded = model.encode_text(clip.tokenize(search_string).to(device))
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
        text_features = text_encoded.cpu().numpy()

    # Similarity of the query against every photo, as a 1-D array.
    similarities = (text_features @ photo_features.T).squeeze(0)

    # Indices of the best matches, highest similarity first.  argsort on the
    # negated scores avoids materializing (score, index) pairs.
    ranked = np.argsort(-similarities)

    # Clamp so a library smaller than num_results can't raise IndexError
    # (the old code always read 30 entries unconditionally).
    top = ranked[:min(num_results, len(ranked))]

    images = []
    for idx in top:
        photo_id = photo_filenames[idx]
        caption = descriptions[idx]
        images.append([('images/' + str(photo_id)), caption])
    return images
48
+
49
# Hide the Gradio footer and collapse the container's minimum height.
css = "footer {display: none !important;} .gradio-container {min-height: 0px !important;}"
with gr.Blocks(css = css) as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            # Free-text query box; pressing Enter triggers a search
            # (wired up via search_string.submit below).
            search_string = gr.Textbox(
                label="Evocative Search",
                show_label=True,
                max_lines=1,
                placeholder="Type something abstruse, or click a suggested search below.",
            ).style(
                container=False,
            )
            btn = gr.Button("Retrieve Images", variant="primary").style(full_width=False)
        with gr.Row(variant="compact"):
            # Canned example queries.  Each button is also used as the
            # *input* of its own click event, which (in this Gradio
            # version) passes the button's label text to the handler —
            # NOTE(review): confirm this behavior on upgrade.
            suggest1 = gr.Button("rococo", variant="secondary").style(size="sm")
            suggest2 = gr.Button("brutalism", variant="secondary").style(size="sm")
            suggest3 = gr.Button("classical", variant="secondary").style(size="sm")
            suggest4 = gr.Button("gothic", variant="secondary").style(size="sm")
            suggest5 = gr.Button("foliate", variant="secondary").style(size="sm")
        # Results grid; clip_search returns [image_path, caption] pairs,
        # which Gallery renders as captioned thumbnails, six per row.
        gallery = gr.Gallery(
            label=False, show_label=False, elem_id="gallery"
        ).style(grid=[6], height="100%",)

    # Wire every trigger to the same handler: each suggestion button sends
    # its own label as the query; the main button and the textbox's
    # Enter-submit both send the textbox contents.
    suggest1.click(clip_search, inputs=suggest1, outputs=gallery)
    suggest2.click(clip_search, inputs=suggest2, outputs=gallery)
    suggest3.click(clip_search, inputs=suggest3, outputs=gallery)
    suggest4.click(clip_search, inputs=suggest4, outputs=gallery)
    suggest5.click(clip_search, inputs=suggest5, outputs=gallery)
    btn.click(clip_search, inputs=search_string, outputs=gallery)
    search_string.submit(clip_search, search_string, gallery)
79
+
80
+
81
+
82
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from other machines on
    # the network; share=False keeps it off Gradio's public share tunnel.
    demo.launch(share=False, server_name='0.0.0.0')
    # NOTE(review): launch() presumably blocks until the server stops, so
    # close() only runs at shutdown — confirm against the Gradio version used.
    demo.close()