LTEnjoy committed on
Commit
bb62e32
·
verified ·
1 Parent(s): cbba759

Update demo/modules/search.py

Browse files
Files changed (1) hide show
  1. demo/modules/search.py +303 -303
demo/modules/search.py CHANGED
@@ -1,304 +1,304 @@
1
- import gradio as gr
2
- import torch
3
- import pandas as pd
4
- import matplotlib.pyplot as plt
5
- import numpy as np
6
-
7
- from scipy.stats import norm
8
- from .init_model import model, all_index, valid_subsections
9
- from .blocks import upload_pdb_button, parse_pdb_file
10
-
11
-
12
- tmp_file_path = "/tmp/results.tsv"
13
- tmp_plot_path = "/tmp/histogram.svg"
14
-
15
- # Samples for input
16
- samples = [
17
- ["Proteins with zinc bindings."],
18
- ["Proteins locating at cell membrane."],
19
- ["Protein that serves as an enzyme."]
20
- ]
21
-
22
- # Databases for different modalities
23
- now_db = {
24
- "sequence": list(all_index["sequence"].keys())[0],
25
- "structure": list(all_index["structure"].keys())[0],
26
- "text": list(all_index["text"].keys())[0]
27
- }
28
-
29
-
30
- def clear_results():
31
- return "", gr.update(visible=False), gr.update(visible=False)
32
-
33
-
34
- def plot(scores) -> None:
35
- """
36
- Plot the distribution of scores and fit a normal distribution.
37
- Args:
38
- scores: List of scores
39
- """
40
- plt.hist(scores, bins=100, density=True, alpha=0.6)
41
- plt.title('Distribution of similarity scores in the database', fontsize=15)
42
- plt.xlabel('Similarity score', fontsize=15)
43
- plt.ylabel('Density', fontsize=15)
44
-
45
- mu, std = norm.fit(scores)
46
-
47
- # Plot the Gaussian
48
- xmin, xmax = plt.xlim()
49
- _, ymax = plt.ylim()
50
- x = np.linspace(xmin, xmax, 100)
51
- p = norm.pdf(x, mu, std)
52
- plt.plot(x, p)
53
-
54
- # Plot total number of scores
55
- plt.text(xmax, 0.9*ymax, f"Total number: {len(scores)}", ha='right', fontsize=12)
56
-
57
- # Convert the plot to svg format
58
- plt.savefig(tmp_plot_path)
59
- plt.cla()
60
-
61
-
62
- # Search from database
63
- def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str):
64
- input_modality = input_type.replace("sequence", "protein")
65
- with torch.no_grad():
66
- input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()
67
-
68
- db = now_db[query_type]
69
- if query_type == "text":
70
- index = all_index["text"][db][subsection_type]["index"]
71
- ids = all_index["text"][db][subsection_type]["ids"]
72
-
73
- else:
74
- index = all_index[query_type][db]["index"]
75
- ids = all_index[query_type][db]["ids"]
76
-
77
- if check_index_ivf(query_type, subsection_type):
78
- if index.nlist < nprobe:
79
- raise gr.Error(f"The number of clusters to search must be less than or equal to the number of clusters in the index ({index.nlist}).")
80
- else:
81
- index.nprobe = nprobe
82
-
83
- if topk > index.ntotal:
84
- raise gr.Error(f"You cannot retrieve more than the database size ({index.ntotal}).")
85
-
86
- # Retrieve all scores to plot the distribution
87
- scores, ranks = index.search(input_embedding, index.ntotal)
88
- scores, ranks = scores[0], ranks[0]
89
-
90
- # Remove inf values
91
- selector = scores > -1
92
- scores = scores[selector]
93
- ranks = ranks[selector]
94
- scores = scores / model.temperature.item()
95
- plot(scores)
96
-
97
- top_scores = scores[:topk]
98
- top_ranks = ranks[:topk]
99
-
100
- # ranks = [list(range(topk))]
101
- # ids = ["P12345"] * topk
102
- # scores = torch.randn(topk).tolist()
103
-
104
- # Write the results to a temporary file for downloading
105
- with open(tmp_file_path, "w") as w:
106
- w.write("Id\tMatching score\n")
107
- for i in range(topk):
108
- rank = top_ranks[i]
109
- w.write(f"{ids[rank]}\t{top_scores[i]}\n")
110
-
111
- # Get topk ids
112
- topk_ids = []
113
- for rank in top_ranks:
114
- now_id = ids[rank]
115
- if query_type == "text":
116
- topk_ids.append(now_id)
117
- else:
118
- if db != "PDB":
119
- # Provide link to uniprot website
120
- topk_ids.append(f"[{now_id}](https://www.uniprot.org/uniprotkb/{now_id})")
121
- else:
122
- # Provide link to pdb website
123
- pdb_id = now_id.split("-")[0]
124
- topk_ids.append(f"[{now_id}](https://www.rcsb.org/structure/{pdb_id})")
125
-
126
- limit = 1000
127
- df = pd.DataFrame({"Id": topk_ids[:limit], "Matching score": top_scores[:limit]})
128
- if len(topk_ids) > limit:
129
- info_df = pd.DataFrame({"Id": ["Download the file to check all results"], "Matching score": ["..."]},
130
- index=[1000])
131
- df = pd.concat([df, info_df], axis=0)
132
-
133
- output = df.to_markdown()
134
- return (output,
135
- gr.DownloadButton(label="Download results", value=tmp_file_path, visible=True, scale=0),
136
- gr.update(value=tmp_plot_path, visible=True))
137
-
138
-
139
- def change_input_type(choice: str):
140
- # Change examples if input type is changed
141
- global samples
142
- if choice == "text":
143
- samples = [
144
- ["Proteins with zinc bindings."],
145
- ["Proteins locating at cell membrane."],
146
- ["Protein that serves as an enzyme."]
147
- ]
148
-
149
- elif choice == "sequence":
150
- samples = [
151
- ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
152
- ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
153
- ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
154
- ]
155
-
156
- elif choice == "structure":
157
- samples = [
158
- ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
159
- ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
160
- ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
161
- ]
162
-
163
- # Set visibility of upload button
164
- if choice == "text":
165
- visible = False
166
- else:
167
- visible = True
168
-
169
- return gr.update(samples=samples), "", gr.update(visible=visible), gr.update(visible=visible)
170
-
171
-
172
- # Load example from dataset
173
- def load_example(example_id):
174
- return samples[example_id][0]
175
-
176
-
177
- # Change the visibility of subsection type
178
- def change_output_type(query_type: str, subsection_type: str):
179
- nprobe_visible = check_index_ivf(query_type, subsection_type)
180
- subsection_visible = True if query_type == "text" else False
181
-
182
- return (
183
- gr.update(visible=subsection_visible),
184
- gr.update(visible=nprobe_visible),
185
- gr.update(choices=list(all_index[query_type].keys()), value=now_db[query_type])
186
- )
187
-
188
-
189
- def check_index_ivf(index_type: str, subsection_type: str = None) -> bool:
190
- """
191
- Check if the index is of IVF type.
192
- Args:
193
- index_type: Type of index.
194
- subsection_type: If the "index_type" is "text", get the index based on the subsection type.
195
-
196
- Returns:
197
- Whether the index is of IVF type or not.
198
- """
199
- db = now_db[index_type]
200
- if index_type == "sequence":
201
- index = all_index["sequence"][db]["index"]
202
-
203
- elif index_type == "structure":
204
- index = all_index["structure"][db]["index"]
205
-
206
- elif index_type == "text":
207
- index = all_index["text"][db][subsection_type]["index"]
208
-
209
- nprobe_visible = True if hasattr(index, "nprobe") else False
210
- return nprobe_visible
211
-
212
-
213
- def change_db_type(query_type: str, subsection_type: str, db_type: str):
214
- """
215
- Change the database to search.
216
- Args:
217
- query_type: The output type.
218
- db_type: The database to search.
219
- """
220
- now_db[query_type] = db_type
221
-
222
- if query_type == "text":
223
- subsection_update = gr.update(choices=list(valid_subsections[now_db["text"]]), value="Function")
224
- else:
225
- subsection_update = gr.update(visible=False)
226
-
227
- nprobe_visible = check_index_ivf(query_type, subsection_type)
228
- return subsection_update, gr.update(visible=nprobe_visible)
229
-
230
-
231
- # Build the searching block
232
- def build_search_module():
233
- gr.Markdown(f"# Search from Swiss-Prot database (the whole UniProt database will be supported soon)")
234
- with gr.Row(equal_height=True):
235
- with gr.Column():
236
- # Set input type
237
- input_type = gr.Radio(["sequence", "structure", "text"], label="Input type (e.g. 'text' means searching based on text descriptions)", value="text")
238
-
239
- with gr.Row():
240
- # Set output type
241
- query_type = gr.Radio(
242
- ["sequence", "structure", "text"],
243
- label="Output type (e.g. 'sequence' means returning qualified sequences)",
244
- value="sequence",
245
- scale=2,
246
- )
247
-
248
- # If the output type is "text", provide an option to choose the subsection of text
249
- subsection_type = gr.Dropdown(valid_subsections[now_db["text"]], label="Subsection of text", value="Function",
250
- interactive=True, visible=False, scale=0)
251
-
252
- db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=now_db["sequence"],
253
- interactive=True, visible=True, scale=0)
254
-
255
- with gr.Row():
256
- # Input box
257
- input = gr.Text(label="Input")
258
-
259
- # Provide an upload button to upload a pdb file
260
- upload_btn, chain_box = upload_pdb_button(visible=False)
261
- upload_btn.upload(parse_pdb_file, inputs=[input_type, upload_btn, chain_box], outputs=[input])
262
-
263
-
264
- # If the index is of IVF type, provide an option to choose the number of clusters.
265
- nprobe_visible = check_index_ivf(query_type.value)
266
- nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
267
- label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")
268
-
269
- # Add event listener to output type
270
- query_type.change(fn=change_output_type, inputs=[query_type, subsection_type],
271
- outputs=[subsection_type, nprobe, db_type])
272
-
273
- # Add event listener to db type
274
- db_type.change(fn=change_db_type, inputs=[query_type, subsection_type, db_type],
275
- outputs=[subsection_type, nprobe])
276
-
277
- # Choose topk results
278
- topk = gr.Slider(1, 1000000, 5, step=1, label="Retrieve top k results")
279
-
280
- # Provide examples
281
- examples = gr.Dataset(samples=samples, components=[input], type="index", label="Input examples")
282
-
283
- # Add click event to examples
284
- examples.click(fn=load_example, inputs=[examples], outputs=input)
285
-
286
- # Change examples based on input type
287
- input_type.change(fn=change_input_type, inputs=[input_type], outputs=[examples, input, upload_btn, chain_box])
288
-
289
- with gr.Row():
290
- search_btn = gr.Button(value="Search")
291
- clear_btn = gr.Button(value="Clear")
292
-
293
- with gr.Row():
294
- with gr.Column():
295
- results = gr.Markdown(label="results", height=450)
296
- download_btn = gr.DownloadButton(label="Download results", visible=False)
297
-
298
- # Plot the distribution of scores
299
- histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)
300
-
301
- search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type],
302
- outputs=[results, download_btn, histogram])
303
-
304
  clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])
 
1
+ import gradio as gr
2
+ import torch
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+
7
+ from scipy.stats import norm
8
+ from .init_model import model, all_index, valid_subsections
9
+ from .blocks import upload_pdb_button, parse_pdb_file
10
+
11
+
12
# Fixed locations of the downloadable result table and the score histogram
tmp_file_path = "/tmp/results.tsv"
tmp_plot_path = "/tmp/histogram.svg"

# Example inputs shown at start-up (text queries)
samples = [
    ["Proteins with zinc bindings."],
    ["Proteins locating at cell membrane."],
    ["Protein that serves as an enzyme."],
]

# Currently selected database for each modality; starts at the first database listed in all_index
now_db = {
    modality: list(all_index[modality].keys())[0]
    for modality in ("sequence", "structure", "text")
}
28
+
29
+
30
def clear_results():
    """
    Reset the result area.
    Returns:
        An empty string followed by two updates with ``visible=False``, one per output
        component the caller wires this function to.
    """
    hide_download = gr.update(visible=False)
    hide_histogram = gr.update(visible=False)
    return "", hide_download, hide_histogram
32
+
33
+
34
def plot(scores) -> None:
    """
    Plot the distribution of scores, overlay a fitted normal distribution and save the
    figure to ``tmp_plot_path``.

    The drawing is done on a dedicated figure rather than on pyplot's implicit "current"
    figure, so that calls cannot draw onto each other's axes, and the figure is always
    closed afterwards, even if plotting fails.

    Args:
        scores: List of scores
    """
    fig, ax = plt.subplots()
    try:
        ax.hist(scores, bins=100, density=True, alpha=0.6)
        ax.set_title('Distribution of similarity scores in the database', fontsize=15)
        ax.set_xlabel('Similarity score', fontsize=15)
        ax.set_ylabel('Density', fontsize=15)

        mu, std = norm.fit(scores)

        # Plot the Gaussian over the x-range chosen by the histogram
        xmin, xmax = ax.get_xlim()
        _, ymax = ax.get_ylim()
        x = np.linspace(xmin, xmax, 100)
        ax.plot(x, norm.pdf(x, mu, std))

        # Plot total number of scores
        ax.text(xmax, 0.9 * ymax, f"Total number: {len(scores)}", ha='right', fontsize=12)

        # The output format follows the extension of tmp_plot_path
        fig.savefig(tmp_plot_path)
    finally:
        plt.close(fig)
60
+
61
+
62
+ # Search from database
63
def _format_result_id(now_id: str, query_type: str, db: str) -> str:
    """Return a hit id for display: unchanged for text hits, a Markdown link otherwise."""
    if query_type == "text":
        return now_id

    if db != "PDB":
        # Provide link to uniprot website
        return f"[{now_id}](https://www.uniprot.org/uniprotkb/{now_id})"

    # Provide link to pdb website
    pdb_id = now_id.split("-")[0]
    return f"[{now_id}](https://www.rcsb.org/structure/{pdb_id})"


def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str):
    """
    Embed the input, search the currently selected database and report the best hits.

    Args:
        input: The query (text description, protein sequence or structure sequence).
        nprobe: Number of clusters to search; only used when the index is of IVF type.
        topk: Number of results to retrieve.
        input_type: Modality of the input ("sequence", "structure" or "text").
        query_type: Modality of the database to search ("sequence", "structure" or "text").
        subsection_type: Subsection of the text database; only used when "query_type" is "text".

    Returns:
        A tuple of (Markdown table of the hits, download button pointing at the result
        file, update that shows the histogram image).

    Raises:
        gr.Error: If the input is empty, "nprobe" exceeds the number of clusters in the
            index, or "topk" exceeds the database size.
    """
    if not input or not input.strip():
        raise gr.Error("The input is empty.")

    # NOTE(review): slider values may arrive as floats - confirm; the slicing below needs ints
    nprobe = int(nprobe)
    topk = int(topk)

    # The model names its sequence encoder "protein"
    input_modality = input_type.replace("sequence", "protein")
    with torch.no_grad():
        input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()

    db = now_db[query_type]
    if query_type == "text":
        entry = all_index["text"][db][subsection_type]
    else:
        entry = all_index[query_type][db]
    index = entry["index"]
    ids = entry["ids"]

    if check_index_ivf(query_type, subsection_type):
        if index.nlist < nprobe:
            raise gr.Error(f"The number of clusters to search must be less than or equal to the number of clusters in the index ({index.nlist}).")
        index.nprobe = nprobe

    if topk > index.ntotal:
        raise gr.Error(f"You cannot retrieve more than the database size ({index.ntotal}).")

    # Retrieve all scores to plot the distribution
    scores, ranks = index.search(input_embedding, index.ntotal)
    scores, ranks = scores[0], ranks[0]

    # Remove placeholder entries (presumably slots the index could not fill - verify)
    selector = scores > -1
    scores = scores[selector]
    ranks = ranks[selector]
    scores = scores / model.temperature.item()
    plot(scores)

    # After filtering there may be fewer than topk hits, so always iterate over what is
    # actually left instead of over range(topk).
    top_scores = scores[:topk]
    top_ranks = ranks[:topk]

    # Write the results to a temporary file for downloading
    with open(tmp_file_path, "w") as w:
        w.write("Id\tMatching score\n")
        for rank, score in zip(top_ranks, top_scores):
            w.write(f"{ids[rank]}\t{score}\n")

    # Get topk ids
    topk_ids = [_format_result_id(ids[rank], query_type, db) for rank in top_ranks]

    # Only the first rows are rendered; the complete list is in the downloadable file
    limit = 1000
    df = pd.DataFrame({"Id": topk_ids[:limit], "Matching score": top_scores[:limit]})
    if len(topk_ids) > limit:
        info_df = pd.DataFrame({"Id": ["Download the file to check all results"], "Matching score": ["..."]},
                               index=[limit])
        df = pd.concat([df, info_df], axis=0)

    output = df.to_markdown()
    return (output,
            gr.DownloadButton(label="Download results", value=tmp_file_path, visible=True, scale=0),
            gr.update(value=tmp_plot_path, visible=True))
137
+
138
+
139
def change_input_type(choice: str):
    """
    Switch the example inputs to match the selected input type.
    Args:
        choice: The selected input type ("text", "sequence" or "structure").

    Returns:
        Updates for the examples dataset, an empty string for the input box and two
        visibility updates (hidden for "text", shown otherwise).
    """
    global samples

    examples_by_type = {
        "text": [
            ["Proteins with zinc bindings."],
            ["Proteins locating at cell membrane."],
            ["Protein that serves as an enzyme."]
        ],
        "sequence": [
            ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
            ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
            ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
        ],
        "structure": [
            ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
            ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
            ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
        ],
    }

    # An unrecognised choice leaves the current examples untouched
    if choice in examples_by_type:
        samples = examples_by_type[choice]

    # Uploading a file only makes sense for non-text input
    visible = choice != "text"

    return gr.update(samples=samples), "", gr.update(visible=visible), gr.update(visible=visible)
170
+
171
+
172
+ # Load example from dataset
173
def load_example(example_id):
    """
    Return the text of the selected example.
    Args:
        example_id: Position of the example in the module-level "samples" list.
    """
    selected_row = samples[example_id]
    return selected_row[0]
175
+
176
+
177
+ # Change the visibility of subsection type
178
def change_output_type(query_type: str, subsection_type: str):
    """
    Refresh the controls that depend on the output type.
    Args:
        query_type: The selected output type.
        subsection_type: The currently selected subsection of text.

    Returns:
        Updates for the subsection dropdown (shown only for "text"), the nprobe slider
        (shown only for IVF indexes) and the database dropdown (choices and current
        database of the selected output type).
    """
    show_nprobe = check_index_ivf(query_type, subsection_type)
    show_subsection = query_type == "text"
    db_choices = list(all_index[query_type].keys())

    subsection_update = gr.update(visible=show_subsection)
    nprobe_update = gr.update(visible=show_nprobe)
    db_update = gr.update(choices=db_choices, value=now_db[query_type])
    return subsection_update, nprobe_update, db_update
187
+
188
+
189
def check_index_ivf(index_type: str, subsection_type: str = None) -> bool:
    """
    Tell whether the index of the currently selected database is of IVF type.
    Args:
        index_type: Type of index ("sequence", "structure" or "text").
        subsection_type: Only used when "index_type" is "text"; selects the index of that subsection.

    Returns:
        True if the index exposes an "nprobe" attribute, False otherwise.
    """
    db = now_db[index_type]
    entry = all_index[index_type][db]

    # Text databases keep one index per subsection
    if index_type == "text":
        entry = entry[subsection_type]

    return hasattr(entry["index"], "nprobe")
211
+
212
+
213
def change_db_type(query_type: str, subsection_type: str, db_type: str):
    """
    Change the database to search.
    Args:
        query_type: The output type.
        subsection_type: The subsection of text selected before the change. It is only
            used for non-text output types; for "text" the subsection is reset (see below).
        db_type: The database to search.

    Returns:
        Updates for the subsection dropdown and the nprobe slider.
    """
    now_db[query_type] = db_type

    if query_type == "text":
        # The dropdown is reset to the default subsection of the new database, so the
        # IVF check below must look at that subsection too. The previously selected one
        # may not exist in the new database.
        subsection_type = "Function"
        subsection_update = gr.update(choices=list(valid_subsections[now_db["text"]]), value=subsection_type)
    else:
        subsection_update = gr.update(visible=False)

    nprobe_visible = check_index_ivf(query_type, subsection_type)
    return subsection_update, gr.update(visible=nprobe_visible)
229
+
230
+
231
+ # Build the searching block
232
def build_search_module():
    """
    Build the search block: input/output type selectors, database and subsection
    dropdowns, the query box with PDB upload, the search parameters, the examples and
    the result area, and connect their events to the handlers of this module.

    NOTE(review): presumably called inside an enclosing ``gr.Blocks`` context - verify
    against the caller.
    """
    # NOTE(review): the container nesting below (controls in a column next to the result
    # row) should be confirmed against the rendered layout.
    gr.Markdown("# Search from database")
    with gr.Row(equal_height=True):
        with gr.Column():
            # Set input type
            input_type = gr.Radio(["sequence", "structure", "text"], label="Input type (e.g. 'text' means searching based on text descriptions)", value="text")

            with gr.Row():
                # Set output type. The database dropdown and the nprobe slider below are
                # initialised from this same default so the three cannot drift apart.
                default_query_type = "sequence"
                query_type = gr.Radio(
                    ["sequence", "structure", "text"],
                    label="Output type (e.g. 'sequence' means returning qualified sequences)",
                    value=default_query_type,
                    scale=2,
                )

                # If the output type is "text", provide an option to choose the subsection of text
                subsection_type = gr.Dropdown(valid_subsections[now_db["text"]], label="Subsection of text", value="Function",
                                              interactive=True, visible=False, scale=0)

                db_type = gr.Dropdown(list(all_index[default_query_type].keys()), label="Database",
                                      value=now_db[default_query_type], interactive=True, visible=True, scale=0)

            with gr.Row():
                # Input box
                input_box = gr.Text(label="Input")

                # Provide an upload button to upload a pdb file
                upload_btn, chain_box = upload_pdb_button(visible=False)
                upload_btn.upload(parse_pdb_file, inputs=[input_type, upload_btn, chain_box], outputs=[input_box])

            # If the index is of IVF type, provide an option to choose the number of clusters.
            nprobe_visible = check_index_ivf(query_type.value, subsection_type.value)
            nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
                               label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")

            # Add event listener to output type
            query_type.change(fn=change_output_type, inputs=[query_type, subsection_type],
                              outputs=[subsection_type, nprobe, db_type])

            # Add event listener to db type
            db_type.change(fn=change_db_type, inputs=[query_type, subsection_type, db_type],
                           outputs=[subsection_type, nprobe])

            # Choose topk results
            topk = gr.Slider(1, 1000000, 5, step=1, label="Retrieve top k results")

            # Provide examples
            examples = gr.Dataset(samples=samples, components=[input_box], type="index", label="Input examples")

            # Add click event to examples
            examples.click(fn=load_example, inputs=[examples], outputs=input_box)

            # Change examples based on input type
            input_type.change(fn=change_input_type, inputs=[input_type], outputs=[examples, input_box, upload_btn, chain_box])

            with gr.Row():
                search_btn = gr.Button(value="Search")
                clear_btn = gr.Button(value="Clear")

        with gr.Row():
            with gr.Column():
                results = gr.Markdown(label="results", height=450)
                download_btn = gr.DownloadButton(label="Download results", visible=False)

            # Plot the distribution of scores
            histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)

        search_btn.click(fn=search, inputs=[input_box, nprobe, topk, input_type, query_type, subsection_type],
                         outputs=[results, download_btn, histogram])

        clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])