Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -8,183 +8,172 @@ import io
|
|
8 |
import pickle
|
9 |
import random
|
10 |
|
11 |
-
def
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
)
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
#
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
#
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
#
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
# if (mod1 == "Image"):
|
47 |
-
# query = id_to_image_emb_dict[id]
|
48 |
-
# elif(mod1 == "DNA"):
|
49 |
-
# query = id_to_dna_emb_dict[id]
|
50 |
-
# query = query.astype(np.float32)
|
51 |
-
# D, I = index.search(query, num_neighbors)
|
52 |
-
|
53 |
-
# id_list = []
|
54 |
-
# for indx in I[0]:
|
55 |
-
# id = indx_to_id_dict[indx]
|
56 |
-
# id_list.append(id)
|
57 |
|
58 |
-
#
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
#
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
#
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
#
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
#
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
#
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
|
|
8 |
import pickle
|
9 |
import random
|
10 |
|
11 |
+
def get_image(image1, image2, dataset_image_mask, processid_to_index, idx):
|
12 |
+
if (idx < 162834):
|
13 |
+
image_enc_padded = image1[idx].astype(np.uint8)
|
14 |
+
elif(idx >= 162834):
|
15 |
+
image_enc_padded = image2[idx-162834].astype(np.uint8)
|
16 |
+
enc_length = dataset_image_mask[idx]
|
17 |
+
image_enc = image_enc_padded[:enc_length]
|
18 |
+
image = Image.open(io.BytesIO(image_enc))
|
19 |
+
return image
|
20 |
+
|
21 |
+
def searchEmbeddings(id, mod1, mod2):
|
22 |
+
# variable and index initialization
|
23 |
+
original_indx = processid_to_index[id]
|
24 |
+
dim = 768
|
25 |
+
num_neighbors = 10
|
26 |
+
|
27 |
+
# get index
|
28 |
+
index = faiss.IndexFlatIP(dim)
|
29 |
+
if (mod2 == "Image"):
|
30 |
+
index = faiss.read_index("image_index.index")
|
31 |
+
elif (mod2 == "DNA"):
|
32 |
+
index = faiss.read_index("dna_index.index")
|
33 |
+
|
34 |
+
# search index
|
35 |
+
if (mod1 == "Image"):
|
36 |
+
query = id_to_image_emb_dict[id]
|
37 |
+
elif(mod1 == "DNA"):
|
38 |
+
query = id_to_dna_emb_dict[id]
|
39 |
+
query = query.astype(np.float32)
|
40 |
+
D, I = index.search(query, num_neighbors)
|
41 |
+
|
42 |
+
id_list = []
|
43 |
+
for indx in I[0]:
|
44 |
+
id = indx_to_id_dict[indx]
|
45 |
+
id_list.append(id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
+
# get images
|
48 |
+
image0 = get_image(dataset_image1, dataset_image2, dataset_image_mask, processid_to_index, original_indx)
|
49 |
+
image1 = get_image(dataset_image1, dataset_image2, dataset_image_mask, processid_to_index, I[0][0])
|
50 |
+
image2 = get_image(dataset_image1, dataset_image2, dataset_image_mask, processid_to_index, I[0][1])
|
51 |
+
image3 = get_image(dataset_image1, dataset_image2, dataset_image_mask, processid_to_index, I[0][2])
|
52 |
+
image4 = get_image(dataset_image1, dataset_image2, dataset_image_mask, processid_to_index, I[0][3])
|
53 |
+
image5 = get_image(dataset_image1, dataset_image2, dataset_image_mask, processid_to_index, I[0][4])
|
54 |
+
image6 = get_image(dataset_image1, dataset_image2, dataset_image_mask, processid_to_index, I[0][5])
|
55 |
+
image7 = get_image(dataset_image1, dataset_image2, dataset_image_mask, processid_to_index, I[0][6])
|
56 |
+
image8 = get_image(dataset_image1, dataset_image2, dataset_image_mask, processid_to_index, I[0][7])
|
57 |
+
image9 = get_image(dataset_image1, dataset_image2, dataset_image_mask, processid_to_index, I[0][8])
|
58 |
+
image10 = get_image(dataset_image1, dataset_image2, dataset_image_mask, processid_to_index, I[0][9])
|
59 |
+
|
60 |
+
# get taxonomic information
|
61 |
+
s0 = getTax(original_indx)
|
62 |
+
s1 = getTax(I[0][0])
|
63 |
+
s2 = getTax(I[0][1])
|
64 |
+
s3 = getTax(I[0][2])
|
65 |
+
s4 = getTax(I[0][3])
|
66 |
+
s5 = getTax(I[0][4])
|
67 |
+
s6 = getTax(I[0][5])
|
68 |
+
s7 = getTax(I[0][6])
|
69 |
+
s8 = getTax(I[0][7])
|
70 |
+
s9 = getTax(I[0][8])
|
71 |
+
s10 = getTax(I[0][9])
|
72 |
|
73 |
+
return id_list, image0, image1, image2, image3, image4, image5, image6, image7, image8, image9, image10, s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10
|
74 |
+
|
75 |
+
def getRandID():
|
76 |
+
indx = random.randrange(0, 325667)
|
77 |
+
return indx_to_id_dict[indx], indx
|
78 |
+
|
79 |
+
def getTax(indx):
|
80 |
+
s = species[indx]
|
81 |
+
g = genus[indx]
|
82 |
+
f = family[indx]
|
83 |
+
str = "Species: " + s + "\nGenus: " + g + "\nFamily: " + f
|
84 |
+
return str
|
85 |
+
|
86 |
+
with gr.Blocks(title="Bioscan-Clip") as demo:
|
87 |
+
# open general files
|
88 |
+
with open("dataset_image1.pickle", "rb") as f:
|
89 |
+
dataset_image1 = pickle.load(f)
|
90 |
+
with open("dataset_image2.pickle", "rb") as f:
|
91 |
+
dataset_image2 = pickle.load(f)
|
92 |
+
with open("dataset_processid_list.pickle", "rb") as f:
|
93 |
+
dataset_processid_list = pickle.load(f)
|
94 |
+
with open("dataset_image_mask.pickle", "rb") as f:
|
95 |
+
dataset_image_mask = pickle.load(f)
|
96 |
+
with open("processid_to_index.pickle", "rb") as f:
|
97 |
+
processid_to_index = pickle.load(f)
|
98 |
+
with open("indx_to_id_dict.pickle", "rb") as f:
|
99 |
+
indx_to_id_dict = pickle.load(f)
|
100 |
+
|
101 |
+
# open image files
|
102 |
+
with open("id_to_image_emb_dict.pickle", "rb") as f:
|
103 |
+
id_to_image_emb_dict = pickle.load(f)
|
104 |
+
|
105 |
+
# open dna files
|
106 |
+
with open("id_to_dna_emb_dict.pickle", "rb") as f:
|
107 |
+
id_to_dna_emb_dict = pickle.load(f)
|
108 |
+
|
109 |
+
# open taxonomy files
|
110 |
+
with open("family.pickle", "rb") as f:
|
111 |
+
family = [item.decode("utf-8") for item in pickle.load(f)]
|
112 |
+
with open("genus.pickle", "rb") as f:
|
113 |
+
genus= [item.decode("utf-8") for item in pickle.load(f)]
|
114 |
+
with open("species.pickle", "rb") as f:
|
115 |
+
species = [item.decode("utf-8") for item in pickle.load(f)]
|
116 |
+
with gr.Column():
|
117 |
+
process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
|
118 |
+
process_id_list = gr.Textbox(label="Closest 10 matches:" )
|
119 |
+
mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
|
120 |
+
mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")
|
121 |
+
search_btn = gr.Button("Search")
|
122 |
+
|
123 |
+
with gr.Row():
|
124 |
+
with gr.Column():
|
125 |
+
image0 = gr.Image(label="Original", height=550)
|
126 |
+
tax0 = gr.Textbox(label="Taxonomy")
|
127 |
+
with gr.Column():
|
128 |
+
rand_id = gr.Textbox(label="Random ID:")
|
129 |
+
rand_id_indx = gr.Textbox(label="Index:")
|
130 |
+
id_btn = gr.Button("Get Random ID")
|
131 |
+
|
132 |
+
with gr.Row():
|
133 |
+
with gr.Column():
|
134 |
+
image1 = gr.Image(label=1)
|
135 |
+
tax1 = gr.Textbox(label="Taxonomy")
|
136 |
+
with gr.Column():
|
137 |
+
image2 = gr.Image(label=2)
|
138 |
+
tax2 = gr.Textbox(label="Taxonomy")
|
139 |
+
with gr.Column():
|
140 |
+
image3 = gr.Image(label=3)
|
141 |
+
tax3 = gr.Textbox(label="Taxonomy")
|
142 |
+
|
143 |
+
with gr.Row():
|
144 |
+
with gr.Column():
|
145 |
+
image4 = gr.Image(label=4)
|
146 |
+
tax4 = gr.Textbox(label="Taxonomy")
|
147 |
+
with gr.Column():
|
148 |
+
image5 = gr.Image(label=5)
|
149 |
+
tax5 = gr.Textbox(label="Taxonomy")
|
150 |
+
with gr.Column():
|
151 |
+
image6 = gr.Image(label=6)
|
152 |
+
tax6 = gr.Textbox(label="Taxonomy")
|
153 |
+
|
154 |
+
with gr.Row():
|
155 |
+
with gr.Column():
|
156 |
+
image7 = gr.Image(label=7)
|
157 |
+
tax7 = gr.Textbox(label="Taxonomy")
|
158 |
+
with gr.Column():
|
159 |
+
image8 = gr.Image(label=8)
|
160 |
+
tax8 = gr.Textbox(label="Taxonomy")
|
161 |
+
with gr.Column():
|
162 |
+
image9 = gr.Image(label=9)
|
163 |
+
tax9 = gr.Textbox(label="Taxonomy")
|
164 |
+
with gr.Column():
|
165 |
+
image10 = gr.Image(label=10)
|
166 |
+
tax10 = gr.Textbox(label="Taxonomy")
|
167 |
+
|
168 |
+
id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
|
169 |
+
search_btn.click(fn=searchEmbeddings, inputs=[process_id, mod1, mod2],
|
170 |
+
outputs=[process_id_list, image0, image1, image2, image3, image4, image5, image6, image7, image8, image9, image10,
|
171 |
+
tax0, tax1, tax2, tax3, tax4, tax5, tax6, tax7, tax8, tax9, tax10])
|
172 |
+
examples = gr.Examples(
|
173 |
+
examples=[["ABOTH966-22", "DNA", "DNA"],
|
174 |
+
["CRTOB8472-22", "DNA", "Image"],
|
175 |
+
["PLOAD050-20", "Image", "DNA"],
|
176 |
+
["HELAC26711-21", "Image", "Image"]],
|
177 |
+
inputs=[process_id, mod1, mod2],)
|
178 |
+
|
179 |
+
demo.launch()
|