Update demo/modules/search.py
Browse files- demo/modules/search.py +23 -30
demo/modules/search.py
CHANGED
@@ -13,11 +13,25 @@ tmp_file_path = "/tmp/results.tsv"
|
|
13 |
tmp_plot_path = "/tmp/histogram.svg"
|
14 |
|
15 |
# Samples for input
|
16 |
-
samples =
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# Databases for different modalities
|
23 |
now_db = {
|
@@ -60,12 +74,11 @@ def plot(scores) -> None:
|
|
60 |
|
61 |
|
62 |
# Search from database
|
63 |
-
def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str):
|
64 |
input_modality = input_type.replace("sequence", "protein")
|
65 |
with torch.no_grad():
|
66 |
input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()
|
67 |
|
68 |
-
db = now_db[query_type]
|
69 |
if query_type == "text":
|
70 |
index = all_index["text"][db][subsection_type]["index"]
|
71 |
ids = all_index["text"][db][subsection_type]["ids"]
|
@@ -139,26 +152,6 @@ def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str,
|
|
139 |
def change_input_type(choice: str):
|
140 |
# Change examples if input type is changed
|
141 |
global samples
|
142 |
-
if choice == "text":
|
143 |
-
samples = [
|
144 |
-
["Proteins with zinc bindings."],
|
145 |
-
["Proteins locating at cell membrane."],
|
146 |
-
["Protein that serves as an enzyme."]
|
147 |
-
]
|
148 |
-
|
149 |
-
elif choice == "sequence":
|
150 |
-
samples = [
|
151 |
-
["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
|
152 |
-
["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
|
153 |
-
["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
|
154 |
-
]
|
155 |
-
|
156 |
-
elif choice == "structure":
|
157 |
-
samples = [
|
158 |
-
["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
|
159 |
-
["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
|
160 |
-
["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
|
161 |
-
]
|
162 |
|
163 |
# Set visibility of upload button
|
164 |
if choice == "text":
|
@@ -166,12 +159,12 @@ def change_input_type(choice: str):
|
|
166 |
else:
|
167 |
visible = True
|
168 |
|
169 |
-
return gr.update(samples=samples), "", gr.update(visible=visible), gr.update(visible=visible)
|
170 |
|
171 |
|
172 |
# Load example from dataset
|
173 |
def load_example(example_id):
|
174 |
-
return
|
175 |
|
176 |
|
177 |
# Change the visibility of subsection type
|
@@ -299,7 +292,7 @@ def build_search_module():
|
|
299 |
# Plot the distribution of scores
|
300 |
histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)
|
301 |
|
302 |
-
search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type],
|
303 |
outputs=[results, download_btn, histogram])
|
304 |
|
305 |
clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])
|
|
|
13 |
tmp_plot_path = "/tmp/histogram.svg"
|
14 |
|
15 |
# Samples for input
|
16 |
+
samples = {
|
17 |
+
"sequence": [
|
18 |
+
["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
|
19 |
+
["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
|
20 |
+
["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
|
21 |
+
],
|
22 |
+
|
23 |
+
"structure": [
|
24 |
+
["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
|
25 |
+
["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
|
26 |
+
["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
|
27 |
+
],
|
28 |
+
|
29 |
+
"text": [
|
30 |
+
["Proteins with zinc bindings."],
|
31 |
+
["Proteins locating at cell membrane."],
|
32 |
+
["Protein that serves as an enzyme."]
|
33 |
+
],
|
34 |
+
}
|
35 |
|
36 |
# Databases for different modalities
|
37 |
now_db = {
|
|
|
74 |
|
75 |
|
76 |
# Search from database
|
77 |
+
def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str, db: str):
|
78 |
input_modality = input_type.replace("sequence", "protein")
|
79 |
with torch.no_grad():
|
80 |
input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()
|
81 |
|
|
|
82 |
if query_type == "text":
|
83 |
index = all_index["text"][db][subsection_type]["index"]
|
84 |
ids = all_index["text"][db][subsection_type]["ids"]
|
|
|
152 |
def change_input_type(choice: str):
|
153 |
# Change examples if input type is changed
|
154 |
global samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
# Set visibility of upload button
|
157 |
if choice == "text":
|
|
|
159 |
else:
|
160 |
visible = True
|
161 |
|
162 |
+
return gr.update(samples=samples[choice]), "", gr.update(visible=visible), gr.update(visible=visible)
|
163 |
|
164 |
|
165 |
# Load example from dataset
|
166 |
def load_example(example_id):
|
167 |
+
return example_id[0]
|
168 |
|
169 |
|
170 |
# Change the visibility of subsection type
|
|
|
292 |
# Plot the distribution of scores
|
293 |
histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)
|
294 |
|
295 |
+
search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type, db_type],
|
296 |
outputs=[results, download_btn, histogram])
|
297 |
|
298 |
clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])
|