Update demo/modules/search.py
Browse files- demo/modules/search.py +303 -303
demo/modules/search.py
CHANGED
@@ -1,304 +1,304 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import torch
|
3 |
-
import pandas as pd
|
4 |
-
import matplotlib.pyplot as plt
|
5 |
-
import numpy as np
|
6 |
-
|
7 |
-
from scipy.stats import norm
|
8 |
-
from .init_model import model, all_index, valid_subsections
|
9 |
-
from .blocks import upload_pdb_button, parse_pdb_file
|
10 |
-
|
11 |
-
|
12 |
-
tmp_file_path = "/tmp/results.tsv"
|
13 |
-
tmp_plot_path = "/tmp/histogram.svg"
|
14 |
-
|
15 |
-
# Samples for input
|
16 |
-
samples = [
|
17 |
-
["Proteins with zinc bindings."],
|
18 |
-
["Proteins locating at cell membrane."],
|
19 |
-
["Protein that serves as an enzyme."]
|
20 |
-
]
|
21 |
-
|
22 |
-
# Databases for different modalities
|
23 |
-
now_db = {
|
24 |
-
"sequence": list(all_index["sequence"].keys())[0],
|
25 |
-
"structure": list(all_index["structure"].keys())[0],
|
26 |
-
"text": list(all_index["text"].keys())[0]
|
27 |
-
}
|
28 |
-
|
29 |
-
|
30 |
-
def clear_results():
|
31 |
-
return "", gr.update(visible=False), gr.update(visible=False)
|
32 |
-
|
33 |
-
|
34 |
-
def plot(scores) -> None:
|
35 |
-
"""
|
36 |
-
Plot the distribution of scores and fit a normal distribution.
|
37 |
-
Args:
|
38 |
-
scores: List of scores
|
39 |
-
"""
|
40 |
-
plt.hist(scores, bins=100, density=True, alpha=0.6)
|
41 |
-
plt.title('Distribution of similarity scores in the database', fontsize=15)
|
42 |
-
plt.xlabel('Similarity score', fontsize=15)
|
43 |
-
plt.ylabel('Density', fontsize=15)
|
44 |
-
|
45 |
-
mu, std = norm.fit(scores)
|
46 |
-
|
47 |
-
# Plot the Gaussian
|
48 |
-
xmin, xmax = plt.xlim()
|
49 |
-
_, ymax = plt.ylim()
|
50 |
-
x = np.linspace(xmin, xmax, 100)
|
51 |
-
p = norm.pdf(x, mu, std)
|
52 |
-
plt.plot(x, p)
|
53 |
-
|
54 |
-
# Plot total number of scores
|
55 |
-
plt.text(xmax, 0.9*ymax, f"Total number: {len(scores)}", ha='right', fontsize=12)
|
56 |
-
|
57 |
-
# Convert the plot to svg format
|
58 |
-
plt.savefig(tmp_plot_path)
|
59 |
-
plt.cla()
|
60 |
-
|
61 |
-
|
62 |
-
# Search from database
|
63 |
-
def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str):
|
64 |
-
input_modality = input_type.replace("sequence", "protein")
|
65 |
-
with torch.no_grad():
|
66 |
-
input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()
|
67 |
-
|
68 |
-
db = now_db[query_type]
|
69 |
-
if query_type == "text":
|
70 |
-
index = all_index["text"][db][subsection_type]["index"]
|
71 |
-
ids = all_index["text"][db][subsection_type]["ids"]
|
72 |
-
|
73 |
-
else:
|
74 |
-
index = all_index[query_type][db]["index"]
|
75 |
-
ids = all_index[query_type][db]["ids"]
|
76 |
-
|
77 |
-
if check_index_ivf(query_type, subsection_type):
|
78 |
-
if index.nlist < nprobe:
|
79 |
-
raise gr.Error(f"The number of clusters to search must be less than or equal to the number of clusters in the index ({index.nlist}).")
|
80 |
-
else:
|
81 |
-
index.nprobe = nprobe
|
82 |
-
|
83 |
-
if topk > index.ntotal:
|
84 |
-
raise gr.Error(f"You cannot retrieve more than the database size ({index.ntotal}).")
|
85 |
-
|
86 |
-
# Retrieve all scores to plot the distribution
|
87 |
-
scores, ranks = index.search(input_embedding, index.ntotal)
|
88 |
-
scores, ranks = scores[0], ranks[0]
|
89 |
-
|
90 |
-
# Remove inf values
|
91 |
-
selector = scores > -1
|
92 |
-
scores = scores[selector]
|
93 |
-
ranks = ranks[selector]
|
94 |
-
scores = scores / model.temperature.item()
|
95 |
-
plot(scores)
|
96 |
-
|
97 |
-
top_scores = scores[:topk]
|
98 |
-
top_ranks = ranks[:topk]
|
99 |
-
|
100 |
-
# ranks = [list(range(topk))]
|
101 |
-
# ids = ["P12345"] * topk
|
102 |
-
# scores = torch.randn(topk).tolist()
|
103 |
-
|
104 |
-
# Write the results to a temporary file for downloading
|
105 |
-
with open(tmp_file_path, "w") as w:
|
106 |
-
w.write("Id\tMatching score\n")
|
107 |
-
for i in range(topk):
|
108 |
-
rank = top_ranks[i]
|
109 |
-
w.write(f"{ids[rank]}\t{top_scores[i]}\n")
|
110 |
-
|
111 |
-
# Get topk ids
|
112 |
-
topk_ids = []
|
113 |
-
for rank in top_ranks:
|
114 |
-
now_id = ids[rank]
|
115 |
-
if query_type == "text":
|
116 |
-
topk_ids.append(now_id)
|
117 |
-
else:
|
118 |
-
if db != "PDB":
|
119 |
-
# Provide link to uniprot website
|
120 |
-
topk_ids.append(f"[{now_id}](https://www.uniprot.org/uniprotkb/{now_id})")
|
121 |
-
else:
|
122 |
-
# Provide link to pdb website
|
123 |
-
pdb_id = now_id.split("-")[0]
|
124 |
-
topk_ids.append(f"[{now_id}](https://www.rcsb.org/structure/{pdb_id})")
|
125 |
-
|
126 |
-
limit = 1000
|
127 |
-
df = pd.DataFrame({"Id": topk_ids[:limit], "Matching score": top_scores[:limit]})
|
128 |
-
if len(topk_ids) > limit:
|
129 |
-
info_df = pd.DataFrame({"Id": ["Download the file to check all results"], "Matching score": ["..."]},
|
130 |
-
index=[1000])
|
131 |
-
df = pd.concat([df, info_df], axis=0)
|
132 |
-
|
133 |
-
output = df.to_markdown()
|
134 |
-
return (output,
|
135 |
-
gr.DownloadButton(label="Download results", value=tmp_file_path, visible=True, scale=0),
|
136 |
-
gr.update(value=tmp_plot_path, visible=True))
|
137 |
-
|
138 |
-
|
139 |
-
def change_input_type(choice: str):
|
140 |
-
# Change examples if input type is changed
|
141 |
-
global samples
|
142 |
-
if choice == "text":
|
143 |
-
samples = [
|
144 |
-
["Proteins with zinc bindings."],
|
145 |
-
["Proteins locating at cell membrane."],
|
146 |
-
["Protein that serves as an enzyme."]
|
147 |
-
]
|
148 |
-
|
149 |
-
elif choice == "sequence":
|
150 |
-
samples = [
|
151 |
-
["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
|
152 |
-
["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
|
153 |
-
["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
|
154 |
-
]
|
155 |
-
|
156 |
-
elif choice == "structure":
|
157 |
-
samples = [
|
158 |
-
["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
|
159 |
-
["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
|
160 |
-
["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
|
161 |
-
]
|
162 |
-
|
163 |
-
# Set visibility of upload button
|
164 |
-
if choice == "text":
|
165 |
-
visible = False
|
166 |
-
else:
|
167 |
-
visible = True
|
168 |
-
|
169 |
-
return gr.update(samples=samples), "", gr.update(visible=visible), gr.update(visible=visible)
|
170 |
-
|
171 |
-
|
172 |
-
# Load example from dataset
|
173 |
-
def load_example(example_id):
|
174 |
-
return samples[example_id][0]
|
175 |
-
|
176 |
-
|
177 |
-
# Change the visibility of subsection type
|
178 |
-
def change_output_type(query_type: str, subsection_type: str):
|
179 |
-
nprobe_visible = check_index_ivf(query_type, subsection_type)
|
180 |
-
subsection_visible = True if query_type == "text" else False
|
181 |
-
|
182 |
-
return (
|
183 |
-
gr.update(visible=subsection_visible),
|
184 |
-
gr.update(visible=nprobe_visible),
|
185 |
-
gr.update(choices=list(all_index[query_type].keys()), value=now_db[query_type])
|
186 |
-
)
|
187 |
-
|
188 |
-
|
189 |
-
def check_index_ivf(index_type: str, subsection_type: str = None) -> bool:
|
190 |
-
"""
|
191 |
-
Check if the index is of IVF type.
|
192 |
-
Args:
|
193 |
-
index_type: Type of index.
|
194 |
-
subsection_type: If the "index_type" is "text", get the index based on the subsection type.
|
195 |
-
|
196 |
-
Returns:
|
197 |
-
Whether the index is of IVF type or not.
|
198 |
-
"""
|
199 |
-
db = now_db[index_type]
|
200 |
-
if index_type == "sequence":
|
201 |
-
index = all_index["sequence"][db]["index"]
|
202 |
-
|
203 |
-
elif index_type == "structure":
|
204 |
-
index = all_index["structure"][db]["index"]
|
205 |
-
|
206 |
-
elif index_type == "text":
|
207 |
-
index = all_index["text"][db][subsection_type]["index"]
|
208 |
-
|
209 |
-
nprobe_visible = True if hasattr(index, "nprobe") else False
|
210 |
-
return nprobe_visible
|
211 |
-
|
212 |
-
|
213 |
-
def change_db_type(query_type: str, subsection_type: str, db_type: str):
|
214 |
-
"""
|
215 |
-
Change the database to search.
|
216 |
-
Args:
|
217 |
-
query_type: The output type.
|
218 |
-
db_type: The database to search.
|
219 |
-
"""
|
220 |
-
now_db[query_type] = db_type
|
221 |
-
|
222 |
-
if query_type == "text":
|
223 |
-
subsection_update = gr.update(choices=list(valid_subsections[now_db["text"]]), value="Function")
|
224 |
-
else:
|
225 |
-
subsection_update = gr.update(visible=False)
|
226 |
-
|
227 |
-
nprobe_visible = check_index_ivf(query_type, subsection_type)
|
228 |
-
return subsection_update, gr.update(visible=nprobe_visible)
|
229 |
-
|
230 |
-
|
231 |
-
# Build the searching block
|
232 |
-
def build_search_module():
|
233 |
-
gr.Markdown(f"# Search from
|
234 |
-
with gr.Row(equal_height=True):
|
235 |
-
with gr.Column():
|
236 |
-
# Set input type
|
237 |
-
input_type = gr.Radio(["sequence", "structure", "text"], label="Input type (e.g. 'text' means searching based on text descriptions)", value="text")
|
238 |
-
|
239 |
-
with gr.Row():
|
240 |
-
# Set output type
|
241 |
-
query_type = gr.Radio(
|
242 |
-
["sequence", "structure", "text"],
|
243 |
-
label="Output type (e.g. 'sequence' means returning qualified sequences)",
|
244 |
-
value="sequence",
|
245 |
-
scale=2,
|
246 |
-
)
|
247 |
-
|
248 |
-
# If the output type is "text", provide an option to choose the subsection of text
|
249 |
-
subsection_type = gr.Dropdown(valid_subsections[now_db["text"]], label="Subsection of text", value="Function",
|
250 |
-
interactive=True, visible=False, scale=0)
|
251 |
-
|
252 |
-
db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=now_db["sequence"],
|
253 |
-
interactive=True, visible=True, scale=0)
|
254 |
-
|
255 |
-
with gr.Row():
|
256 |
-
# Input box
|
257 |
-
input = gr.Text(label="Input")
|
258 |
-
|
259 |
-
# Provide an upload button to upload a pdb file
|
260 |
-
upload_btn, chain_box = upload_pdb_button(visible=False)
|
261 |
-
upload_btn.upload(parse_pdb_file, inputs=[input_type, upload_btn, chain_box], outputs=[input])
|
262 |
-
|
263 |
-
|
264 |
-
# If the index is of IVF type, provide an option to choose the number of clusters.
|
265 |
-
nprobe_visible = check_index_ivf(query_type.value)
|
266 |
-
nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
|
267 |
-
label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")
|
268 |
-
|
269 |
-
# Add event listener to output type
|
270 |
-
query_type.change(fn=change_output_type, inputs=[query_type, subsection_type],
|
271 |
-
outputs=[subsection_type, nprobe, db_type])
|
272 |
-
|
273 |
-
# Add event listener to db type
|
274 |
-
db_type.change(fn=change_db_type, inputs=[query_type, subsection_type, db_type],
|
275 |
-
outputs=[subsection_type, nprobe])
|
276 |
-
|
277 |
-
# Choose topk results
|
278 |
-
topk = gr.Slider(1, 1000000, 5, step=1, label="Retrieve top k results")
|
279 |
-
|
280 |
-
# Provide examples
|
281 |
-
examples = gr.Dataset(samples=samples, components=[input], type="index", label="Input examples")
|
282 |
-
|
283 |
-
# Add click event to examples
|
284 |
-
examples.click(fn=load_example, inputs=[examples], outputs=input)
|
285 |
-
|
286 |
-
# Change examples based on input type
|
287 |
-
input_type.change(fn=change_input_type, inputs=[input_type], outputs=[examples, input, upload_btn, chain_box])
|
288 |
-
|
289 |
-
with gr.Row():
|
290 |
-
search_btn = gr.Button(value="Search")
|
291 |
-
clear_btn = gr.Button(value="Clear")
|
292 |
-
|
293 |
-
with gr.Row():
|
294 |
-
with gr.Column():
|
295 |
-
results = gr.Markdown(label="results", height=450)
|
296 |
-
download_btn = gr.DownloadButton(label="Download results", visible=False)
|
297 |
-
|
298 |
-
# Plot the distribution of scores
|
299 |
-
histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)
|
300 |
-
|
301 |
-
search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type],
|
302 |
-
outputs=[results, download_btn, histogram])
|
303 |
-
|
304 |
clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import pandas as pd
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from scipy.stats import norm
|
8 |
+
from .init_model import model, all_index, valid_subsections
|
9 |
+
from .blocks import upload_pdb_button, parse_pdb_file
|
10 |
+
|
11 |
+
|
12 |
+
# Fixed output locations: the downloadable result table and the score histogram.
tmp_file_path = "/tmp/results.tsv"
tmp_plot_path = "/tmp/histogram.svg"

# Example inputs offered in the UI; each row holds a single input string.
samples = [
    [sentence]
    for sentence in (
        "Proteins with zinc bindings.",
        "Proteins locating at cell membrane.",
        "Protein that serves as an enzyme.",
    )
]

# Selected database per modality, initialised to the first database listed in `all_index`.
now_db = {
    modality: list(all_index[modality].keys())[0]
    for modality in ("sequence", "structure", "text")
}
|
28 |
+
|
29 |
+
|
30 |
+
def clear_results():
    """
    Reset the result area.

    Returns:
        An empty string followed by two updates that each set ``visible=False``.
    """
    hidden = {"visible": False}
    return "", gr.update(**hidden), gr.update(**hidden)
|
32 |
+
|
33 |
+
|
34 |
+
def plot(scores) -> None:
    """
    Plot the distribution of scores and fit a normal distribution.

    The histogram, the fitted Gaussian and the number of scores are drawn on a private
    figure which is then saved to ``tmp_plot_path``.

    Args:
        scores: List of scores
    """
    # Local import: only this function needs the object-oriented Figure API.
    from matplotlib.figure import Figure

    # Draw on a private Figure rather than pyplot's process-wide "current figure",
    # so overlapping calls cannot draw into (or clear) each other's axes.
    fig = Figure()
    ax = fig.subplots()

    ax.hist(scores, bins=100, density=True, alpha=0.6)
    ax.set_title('Distribution of similarity scores in the database', fontsize=15)
    ax.set_xlabel('Similarity score', fontsize=15)
    ax.set_ylabel('Density', fontsize=15)

    # Read the limits before adding the curve so the label position depends on the histogram only.
    xmin, xmax = ax.get_xlim()
    _, ymax = ax.get_ylim()

    # Plot the Gaussian; skipped when the fit is degenerate (no scores, or all scores equal).
    if len(scores) > 0:
        mu, std = norm.fit(scores)
        if np.isfinite(std) and std > 0:
            x = np.linspace(xmin, xmax, 100)
            ax.plot(x, norm.pdf(x, mu, std))

    # Plot total number of scores
    ax.text(xmax, 0.9 * ymax, f"Total number: {len(scores)}", ha='right', fontsize=12)

    # The output format follows the extension of tmp_plot_path.
    fig.savefig(tmp_plot_path)
|
60 |
+
|
61 |
+
|
62 |
+
# Search from database
|
63 |
+
def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str):
    """
    Embed the query, search the selected index and build the outputs for the UI.

    Args:
        input: The query content. The name shadows the builtin but is kept for compatibility.
        nprobe: Number of clusters to search; only used when the selected index exposes ``nprobe``.
        topk: Number of results to return.
        input_type: Modality of the query ("sequence", "structure" or "text").
        query_type: Modality to retrieve; also selects the database through ``now_db``.
        subsection_type: Text subsection to search; only used when ``query_type`` is "text".

    Returns:
        A tuple of (markdown table of the results, a visible download button pointing at
        ``tmp_file_path``, an update that shows the histogram stored at ``tmp_plot_path``).

    Raises:
        gr.Error: If ``nprobe`` exceeds ``index.nlist`` or ``topk`` exceeds ``index.ntotal``.
    """
    # Defensive: the slices below require a real int.
    topk = int(topk)

    input_modality = input_type.replace("sequence", "protein")
    with torch.no_grad():
        input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()

    db = now_db[query_type]
    if query_type == "text":
        entry = all_index["text"][db][subsection_type]
    else:
        entry = all_index[query_type][db]
    index, ids = entry["index"], entry["ids"]

    if check_index_ivf(query_type, subsection_type):
        nprobe = int(nprobe)
        if index.nlist < nprobe:
            raise gr.Error(f"The number of clusters to search must be less than or equal to the number of clusters in the index ({index.nlist}).")
        index.nprobe = nprobe

    if topk > index.ntotal:
        raise gr.Error(f"You cannot retrieve more than the database size ({index.ntotal}).")

    # Retrieve all scores to plot the distribution
    scores, ranks = index.search(input_embedding, index.ntotal)
    scores, ranks = scores[0], ranks[0]

    # Drop entries scoring -1 or below (the original comment calls these "inf values";
    # presumably placeholders for result slots the index could not fill - TODO confirm).
    selector = scores > -1
    scores = scores[selector]
    ranks = ranks[selector]
    scores = scores / model.temperature.item()
    plot(scores)

    # After filtering, fewer than topk entries may remain, so everything below iterates
    # over what is actually available instead of range(topk).
    top_scores = scores[:topk]
    top_ranks = ranks[:topk]
    top_ids = [ids[rank] for rank in top_ranks]

    # Write the results to a temporary file for downloading
    # NOTE(review): tmp_file_path is one fixed path, so overlapping searches would
    # overwrite each other's file; consider a per-request temporary file.
    with open(tmp_file_path, "w", encoding="utf-8") as w:
        w.write("Id\tMatching score\n")
        w.writelines(f"{now_id}\t{score}\n" for now_id, score in zip(top_ids, top_scores))

    # Get topk ids
    topk_ids = []
    for now_id in top_ids:
        if query_type == "text":
            topk_ids.append(now_id)
        elif db != "PDB":
            # Provide link to uniprot website
            topk_ids.append(f"[{now_id}](https://www.uniprot.org/uniprotkb/{now_id})")
        else:
            # Provide link to pdb website
            pdb_id = now_id.split("-")[0]
            topk_ids.append(f"[{now_id}](https://www.rcsb.org/structure/{pdb_id})")

    # Only the first `limit` rows are rendered; the downloadable file holds all of them.
    limit = 1000
    df = pd.DataFrame({"Id": topk_ids[:limit], "Matching score": top_scores[:limit]})
    if len(topk_ids) > limit:
        info_df = pd.DataFrame({"Id": ["Download the file to check all results"], "Matching score": ["..."]},
                               index=[limit])
        df = pd.concat([df, info_df], axis=0)

    output = df.to_markdown()
    return (output,
            gr.DownloadButton(label="Download results", value=tmp_file_path, visible=True, scale=0),
            gr.update(value=tmp_plot_path, visible=True))
|
137 |
+
|
138 |
+
|
139 |
+
def change_input_type(choice: str):
    """
    Swap the example inputs and toggle the upload widgets when the input type changes.

    Args:
        choice: The newly selected input type.

    Returns:
        A tuple of (update carrying the new samples, an empty string, and two updates that
        set ``visible`` to False for "text" and True otherwise).
    """
    global samples

    examples_by_type = {
        "text": [
            ["Proteins with zinc bindings."],
            ["Proteins locating at cell membrane."],
            ["Protein that serves as an enzyme."],
        ],
        "sequence": [
            ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
            ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
            ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"],
        ],
        "structure": [
            ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
            ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
            ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"],
        ],
    }

    # An unrecognised choice leaves the current examples untouched.
    if choice in examples_by_type:
        samples = examples_by_type[choice]

    # The upload widgets are hidden only for text input.
    show_upload = choice != "text"

    return (
        gr.update(samples=samples),
        "",
        gr.update(visible=show_upload),
        gr.update(visible=show_upload),
    )
|
170 |
+
|
171 |
+
|
172 |
+
# Load example from dataset
|
173 |
+
def load_example(example_id):
    """
    Return the input string of the selected example.

    Args:
        example_id: Position of the example within the module-level ``samples`` list.

    Returns:
        The first element of the selected sample row.
    """
    selected_row = samples[example_id]
    return selected_row[0]
|
175 |
+
|
176 |
+
|
177 |
+
# Change the visibility of subsection type
|
178 |
+
def change_output_type(query_type: str, subsection_type: str):
    """
    Refresh the widgets that depend on the output type.

    Args:
        query_type: The newly selected output type.
        subsection_type: The currently selected text subsection.

    Returns:
        A tuple of updates: subsection visibility (shown only for "text"), ``nprobe``
        visibility, and the database choices with the currently selected database.
    """
    show_nprobe = check_index_ivf(query_type, subsection_type)
    show_subsection = query_type == "text"
    db_choices = list(all_index[query_type].keys())

    subsection_update = gr.update(visible=show_subsection)
    nprobe_update = gr.update(visible=show_nprobe)
    db_update = gr.update(choices=db_choices, value=now_db[query_type])
    return (subsection_update, nprobe_update, db_update)
|
187 |
+
|
188 |
+
|
189 |
+
def check_index_ivf(index_type: str, subsection_type: str = None) -> bool:
    """
    Tell whether the currently selected index for a modality is of IVF type.

    The index is treated as IVF when it has an ``nprobe`` attribute.

    Args:
        index_type: Type of index.
        subsection_type: If the "index_type" is "text", get the index based on the subsection type.

    Returns:
        Whether the index is of IVF type or not.
    """
    db = now_db[index_type]

    # Text indexes are stored one level deeper, keyed by subsection.
    if index_type == "text":
        index = all_index["text"][db][subsection_type]["index"]
    elif index_type in ("sequence", "structure"):
        index = all_index[index_type][db]["index"]

    return hasattr(index, "nprobe")
|
211 |
+
|
212 |
+
|
213 |
+
def change_db_type(query_type: str, subsection_type: str, db_type: str):
    """
    Change the database to search.

    Args:
        query_type: The output type.
        subsection_type: The text subsection selected before the change.
        db_type: The database to search.

    Returns:
        A tuple of (update for the subsection dropdown, update for the ``nprobe`` slider visibility).
    """
    now_db[query_type] = db_type

    if query_type == "text":
        choices = list(valid_subsections[now_db["text"]])
        # The dropdown is reset to "Function"; fall back to the first offered subsection
        # when the new database does not provide it.
        if "Function" in choices or not choices:
            subsection_type = "Function"
        else:
            subsection_type = choices[0]
        subsection_update = gr.update(choices=choices, value=subsection_type)
    else:
        subsection_update = gr.update(visible=False)

    # For text, the check uses the subsection the dropdown was just reset to: the previous
    # selection may not exist in the new database.
    nprobe_visible = check_index_ivf(query_type, subsection_type)
    return subsection_update, gr.update(visible=nprobe_visible)
|
229 |
+
|
230 |
+
|
231 |
+
# Build the searching block
|
232 |
+
def build_search_module():
|
233 |
+
gr.Markdown(f"# Search from database")
|
234 |
+
with gr.Row(equal_height=True):
|
235 |
+
with gr.Column():
|
236 |
+
# Set input type
|
237 |
+
input_type = gr.Radio(["sequence", "structure", "text"], label="Input type (e.g. 'text' means searching based on text descriptions)", value="text")
|
238 |
+
|
239 |
+
with gr.Row():
|
240 |
+
# Set output type
|
241 |
+
query_type = gr.Radio(
|
242 |
+
["sequence", "structure", "text"],
|
243 |
+
label="Output type (e.g. 'sequence' means returning qualified sequences)",
|
244 |
+
value="sequence",
|
245 |
+
scale=2,
|
246 |
+
)
|
247 |
+
|
248 |
+
# If the output type is "text", provide an option to choose the subsection of text
|
249 |
+
subsection_type = gr.Dropdown(valid_subsections[now_db["text"]], label="Subsection of text", value="Function",
|
250 |
+
interactive=True, visible=False, scale=0)
|
251 |
+
|
252 |
+
db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=now_db["sequence"],
|
253 |
+
interactive=True, visible=True, scale=0)
|
254 |
+
|
255 |
+
with gr.Row():
|
256 |
+
# Input box
|
257 |
+
input = gr.Text(label="Input")
|
258 |
+
|
259 |
+
# Provide an upload button to upload a pdb file
|
260 |
+
upload_btn, chain_box = upload_pdb_button(visible=False)
|
261 |
+
upload_btn.upload(parse_pdb_file, inputs=[input_type, upload_btn, chain_box], outputs=[input])
|
262 |
+
|
263 |
+
|
264 |
+
# If the index is of IVF type, provide an option to choose the number of clusters.
|
265 |
+
nprobe_visible = check_index_ivf(query_type.value)
|
266 |
+
nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
|
267 |
+
label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")
|
268 |
+
|
269 |
+
# Add event listener to output type
|
270 |
+
query_type.change(fn=change_output_type, inputs=[query_type, subsection_type],
|
271 |
+
outputs=[subsection_type, nprobe, db_type])
|
272 |
+
|
273 |
+
# Add event listener to db type
|
274 |
+
db_type.change(fn=change_db_type, inputs=[query_type, subsection_type, db_type],
|
275 |
+
outputs=[subsection_type, nprobe])
|
276 |
+
|
277 |
+
# Choose topk results
|
278 |
+
topk = gr.Slider(1, 1000000, 5, step=1, label="Retrieve top k results")
|
279 |
+
|
280 |
+
# Provide examples
|
281 |
+
examples = gr.Dataset(samples=samples, components=[input], type="index", label="Input examples")
|
282 |
+
|
283 |
+
# Add click event to examples
|
284 |
+
examples.click(fn=load_example, inputs=[examples], outputs=input)
|
285 |
+
|
286 |
+
# Change examples based on input type
|
287 |
+
input_type.change(fn=change_input_type, inputs=[input_type], outputs=[examples, input, upload_btn, chain_box])
|
288 |
+
|
289 |
+
with gr.Row():
|
290 |
+
search_btn = gr.Button(value="Search")
|
291 |
+
clear_btn = gr.Button(value="Clear")
|
292 |
+
|
293 |
+
with gr.Row():
|
294 |
+
with gr.Column():
|
295 |
+
results = gr.Markdown(label="results", height=450)
|
296 |
+
download_btn = gr.DownloadButton(label="Download results", visible=False)
|
297 |
+
|
298 |
+
# Plot the distribution of scores
|
299 |
+
histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)
|
300 |
+
|
301 |
+
search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type],
|
302 |
+
outputs=[results, download_btn, histogram])
|
303 |
+
|
304 |
clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])
|