atwang committed on
Commit
b383c02
·
1 Parent(s): aa175fc

update paths

Browse files
Files changed (1) hide show
  1. app.py +47 -36
app.py CHANGED
@@ -4,40 +4,50 @@ import numpy as np
4
  import h5py
5
  import faiss
6
  from PIL import Image
7
- import io
8
  import pickle
9
  import random
10
 
 
11
  def getRandID():
12
  indx = random.randrange(0, 396503)
13
  return indx_to_id_dict[indx], indx
14
 
 
15
  def chooseImageIndex(indexType):
16
- if (indexType == "FlatIP(default)"):
17
  return image_index_IP
18
- elif (indexType == "FlatL2"):
 
19
  return image_index_L2
20
- elif (indexType == "HNSWFlat"):
 
21
  return image_index_HNSW
22
- elif (indexType == "IVFFlat"):
 
23
  return image_index_IVF
24
- elif (indexType == "LSH"):
 
25
  return image_index_LSH
26
 
 
27
  def chooseDNAIndex(indexType):
28
- if (indexType == "FlatIP(default)"):
29
  return dna_index_IP
30
- elif (indexType == "FlatL2"):
 
31
  return dna_index_L2
32
- elif (indexType == "HNSWFlat"):
 
33
  return dna_index_HNSW
34
- elif (indexType == "IVFFlat"):
 
35
  return dna_index_IVF
36
- elif (indexType == "LSH"):
 
37
  return dna_index_LSH
38
 
39
 
40
-
41
  def searchEmbeddings(id, mod1, mod2, indexType):
42
  # variable and index initialization
43
  dim = 768
@@ -47,16 +57,15 @@ def searchEmbeddings(id, mod1, mod2, indexType):
47
  index = faiss.IndexFlatIP(dim)
48
 
49
  # get index
50
- if (mod2 == "Image"):
51
  index = chooseImageIndex(indexType)
52
- elif (mod2 == "DNA"):
53
  index = chooseDNAIndex(indexType)
54
-
55
 
56
  # search for query
57
- if (mod1 == "Image"):
58
  query = id_to_image_emb_dict[id]
59
- elif (mod1 == "DNA"):
60
  query = id_to_dna_emb_dict[id]
61
  query = query.astype(np.float32)
62
  D, I = index.search(query, num_neighbors)
@@ -66,25 +75,26 @@ def searchEmbeddings(id, mod1, mod2, indexType):
66
  for indx in I[0]:
67
  id = indx_to_id_dict[indx]
68
  id_list.append(id)
69
-
70
  return id_list
71
 
 
72
  with gr.Blocks() as demo:
73
 
74
  # for hf: change all file paths, indx_to_id_dict as well
75
 
76
  # load indexes
77
- image_index_IP = faiss.read_index("big_image_index_FlatIP.index")
78
- image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
79
- image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
80
- image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
81
- image_index_LSH = faiss.read_index("big_image_index_LSH.index")
82
-
83
- dna_index_IP = faiss.read_index("big_dna_index_FlatIP.index")
84
- dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
85
- dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
86
- dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
87
- dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")
88
 
89
  with open("dataset_processid_list.pickle", "rb") as f:
90
  dataset_processid_list = pickle.load(f)
@@ -109,14 +119,15 @@ with gr.Blocks() as demo:
109
  mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
110
  mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")
111
 
112
- indexType = gr.Radio(choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)")
 
 
113
  process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
114
- process_id_list = gr.Textbox(label="Closest 10 matches:" )
115
- search_btn = gr.Button("Search")
116
  id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
117
 
118
- search_btn.click(fn=searchEmbeddings, inputs=[process_id, mod1, mod2, indexType],
119
- outputs=[process_id_list])
120
-
121
 
122
- demo.launch()
 
4
  import h5py
5
  import faiss
6
  from PIL import Image
7
+ import io
8
  import pickle
9
  import random
10
 
11
+
12
  def getRandID():
13
  indx = random.randrange(0, 396503)
14
  return indx_to_id_dict[indx], indx
15
 
16
+
17
  def chooseImageIndex(indexType):
18
+ if indexType == "FlatIP(default)":
19
  return image_index_IP
20
+ elif indexType == "FlatL2":
21
+ raise NotImplementedError
22
  return image_index_L2
23
+ elif indexType == "HNSWFlat":
24
+ raise NotImplementedError
25
  return image_index_HNSW
26
+ elif indexType == "IVFFlat":
27
+ raise NotImplementedError
28
  return image_index_IVF
29
+ elif indexType == "LSH":
30
+ raise NotImplementedError
31
  return image_index_LSH
32
 
33
+
34
  def chooseDNAIndex(indexType):
35
+ if indexType == "FlatIP(default)":
36
  return dna_index_IP
37
+ elif indexType == "FlatL2":
38
+ raise NotImplementedError
39
  return dna_index_L2
40
+ elif indexType == "HNSWFlat":
41
+ raise NotImplementedError
42
  return dna_index_HNSW
43
+ elif indexType == "IVFFlat":
44
+ raise NotImplementedError
45
  return dna_index_IVF
46
+ elif indexType == "LSH":
47
+ raise NotImplementedError
48
  return dna_index_LSH
49
 
50
 
 
51
  def searchEmbeddings(id, mod1, mod2, indexType):
52
  # variable and index initialization
53
  dim = 768
 
57
  index = faiss.IndexFlatIP(dim)
58
 
59
  # get index
60
+ if mod2 == "Image":
61
  index = chooseImageIndex(indexType)
62
+ elif mod2 == "DNA":
63
  index = chooseDNAIndex(indexType)
 
64
 
65
  # search for query
66
+ if mod1 == "Image":
67
  query = id_to_image_emb_dict[id]
68
+ elif mod1 == "DNA":
69
  query = id_to_dna_emb_dict[id]
70
  query = query.astype(np.float32)
71
  D, I = index.search(query, num_neighbors)
 
75
  for indx in I[0]:
76
  id = indx_to_id_dict[indx]
77
  id_list.append(id)
78
+
79
  return id_list
80
 
81
+
82
  with gr.Blocks() as demo:
83
 
84
  # for hf: change all file paths, indx_to_id_dict as well
85
 
86
  # load indexes
87
+ image_index_IP = faiss.read_index("bioscan_5m_image_IndexFlatIP.index")
88
+ # image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
89
+ # image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
90
+ # image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
91
+ # image_index_LSH = faiss.read_index("big_image_index_LSH.index")
92
+
93
+ dna_index_IP = faiss.read_index("bioscan_5m_dna_IndexFlatIP.index")
94
+ # dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
95
+ # dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
96
+ # dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
97
+ # dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")
98
 
99
  with open("dataset_processid_list.pickle", "rb") as f:
100
  dataset_processid_list = pickle.load(f)
 
119
  mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
120
  mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")
121
 
122
+ indexType = gr.Radio(
123
+ choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
124
+ )
125
  process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
126
+ process_id_list = gr.Textbox(label="Closest 10 matches:")
127
+ search_btn = gr.Button("Search")
128
  id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
129
 
130
+ search_btn.click(fn=searchEmbeddings, inputs=[process_id, mod1, mod2, indexType], outputs=[process_id_list])
131
+
 
132
 
133
+ demo.launch()