Update demo/modules/search.py
Files changed: demo/modules/search.py (+15 -20)
demo/modules/search.py
CHANGED
@@ -33,13 +33,6 @@ samples = {
|
|
33 |
],
|
34 |
}
|
35 |
|
36 |
-
# Databases for different modalities
|
37 |
-
now_db = {
|
38 |
-
"sequence": list(all_index["sequence"].keys())[0],
|
39 |
-
"structure": list(all_index["structure"].keys())[0],
|
40 |
-
"text": list(all_index["text"].keys())[0]
|
41 |
-
}
|
42 |
-
|
43 |
|
44 |
def clear_results():
    """
    Build the reset values: an empty string followed by two hide updates.

    Returns:
        A 3-tuple ``("", hide, hide)`` where each ``hide`` is ``gr.update(visible=False)``.
    """
    # NOTE(review): which components receive these values is decided by the caller's
    # event wiring, which is not visible here.
    first_hidden = gr.update(visible=False)
    second_hidden = gr.update(visible=False)
    return "", first_hidden, second_hidden
|
@@ -75,6 +68,8 @@ def plot(scores) -> None:
|
|
75 |
|
76 |
# Search from database
|
77 |
def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str, db: str):
|
|
|
|
|
78 |
input_modality = input_type.replace("sequence", "protein")
|
79 |
with torch.no_grad():
|
80 |
input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()
|
@@ -169,17 +164,18 @@ def load_example(example_id):
|
|
169 |
|
170 |
# Change the visibility of subsection type
|
171 |
def change_output_type(query_type: str, subsection_type: str):
|
172 |
-
|
|
|
173 |
subsection_visible = True if query_type == "text" else False
|
174 |
-
|
175 |
return (
|
176 |
-
gr.update(visible=subsection_visible),
|
177 |
gr.update(visible=nprobe_visible),
|
178 |
-
gr.update(choices=list(all_index[query_type].keys()), value=
|
179 |
)
|
180 |
|
181 |
|
182 |
-
def check_index_ivf(index_type: str, subsection_type: str = None) -> bool:
|
183 |
"""
|
184 |
Check if the index is of IVF type.
|
185 |
Args:
|
@@ -189,7 +185,6 @@ def check_index_ivf(index_type: str, subsection_type: str = None) -> bool:
|
|
189 |
Returns:
|
190 |
Whether the index is of IVF type or not.
|
191 |
"""
|
192 |
-
db = now_db[index_type]
|
193 |
if index_type == "sequence":
|
194 |
index = all_index["sequence"][db]["index"]
|
195 |
|
@@ -211,14 +206,12 @@ def change_db_type(query_type: str, subsection_type: str, db_type: str):
|
|
211 |
query_type: The output type.
|
212 |
db_type: The database to search.
|
213 |
"""
|
214 |
-
now_db[query_type] = db_type
|
215 |
-
|
216 |
if query_type == "text":
|
217 |
-
subsection_update = gr.update(choices=list(valid_subsections[
|
218 |
else:
|
219 |
subsection_update = gr.update(visible=False)
|
220 |
|
221 |
-
nprobe_visible = check_index_ivf(query_type, subsection_type)
|
222 |
return subsection_update, gr.update(visible=nprobe_visible)
|
223 |
|
224 |
|
@@ -240,10 +233,12 @@ def build_search_module():
|
|
240 |
)
|
241 |
|
242 |
# If the output type is "text", provide an option to choose the subsection of text
|
243 |
-
|
|
|
|
|
244 |
interactive=True, visible=False, scale=0)
|
245 |
|
246 |
-
db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=
|
247 |
interactive=True, visible=True, scale=0)
|
248 |
|
249 |
with gr.Row():
|
@@ -256,7 +251,7 @@ def build_search_module():
|
|
256 |
|
257 |
|
258 |
# If the index is of IVF type, provide an option to choose the number of clusters.
|
259 |
-
nprobe_visible = check_index_ivf(query_type.value)
|
260 |
nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
|
261 |
label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")
|
262 |
|
|
|
33 |
],
|
34 |
}
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
def clear_results():
    """
    Build the reset values: an empty string followed by two hide updates.

    Returns:
        A 3-tuple ``("", hide, hide)`` where each ``hide`` is ``gr.update(visible=False)``.
    """
    # NOTE(review): which components receive these values is decided by the caller's
    # event wiring, which is not visible here.
    first_hidden = gr.update(visible=False)
    second_hidden = gr.update(visible=False)
    return "", first_hidden, second_hidden
|
|
|
68 |
|
69 |
# Search from database
|
70 |
def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str, db: str):
|
71 |
+
print(f"Input type: {input_type}\n Output type: {query_type}\nDatabase: {db}\nSubsection: {subsection_type}")
|
72 |
+
|
73 |
input_modality = input_type.replace("sequence", "protein")
|
74 |
with torch.no_grad():
|
75 |
input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()
|
|
|
164 |
|
165 |
# Change the visibility of subsection type
|
166 |
def change_output_type(query_type: str, subsection_type: str):
    """
    Build the component updates for a newly selected output type.

    The first key of ``all_index[query_type]`` is taken as the selected database.

    Args:
        query_type: The output type; used as a key into ``all_index``.
        subsection_type: Forwarded unchanged to ``check_index_ivf``.

    Returns:
        A tuple of three ``gr.update`` objects, in order:
        a visibility update that is True only when ``query_type`` is "text",
        a visibility update driven by the result of ``check_index_ivf``,
        and an update carrying the database names for ``query_type`` as choices
        with the first one as the value.
    """
    available_dbs = list(all_index[query_type].keys())
    # The first listed database becomes the selection for the new output type.
    default_db = available_dbs[0]

    show_nprobe = check_index_ivf(query_type, default_db, subsection_type)
    show_subsection = query_type == "text"

    subsection_update = gr.update(visible=show_subsection)
    nprobe_update = gr.update(visible=show_nprobe)
    db_update = gr.update(choices=available_dbs, value=default_db)
    return (subsection_update, nprobe_update, db_update)
|
176 |
|
177 |
|
178 |
+
def check_index_ivf(index_type: str, db: str, subsection_type: str = None) -> bool:
|
179 |
"""
|
180 |
Check if the index is of IVF type.
|
181 |
Args:
|
|
|
185 |
Returns:
|
186 |
Whether the index is of IVF type or not.
|
187 |
"""
|
|
|
188 |
if index_type == "sequence":
|
189 |
index = all_index["sequence"][db]["index"]
|
190 |
|
|
|
206 |
query_type: The output type.
|
207 |
db_type: The database to search.
|
208 |
"""
|
|
|
|
|
209 |
if query_type == "text":
|
210 |
+
subsection_update = gr.update(choices=list(valid_subsections[db_type]), value="Function")
|
211 |
else:
|
212 |
subsection_update = gr.update(visible=False)
|
213 |
|
214 |
+
nprobe_visible = check_index_ivf(query_type, db_type, subsection_type)
|
215 |
return subsection_update, gr.update(visible=nprobe_visible)
|
216 |
|
217 |
|
|
|
233 |
)
|
234 |
|
235 |
# If the output type is "text", provide an option to choose the subsection of text
|
236 |
+
text_db = list(all_index["text"].keys())[0]
|
237 |
+
sequence_db = list(all_index["sequence"].keys())[0]
|
238 |
+
subsection_type = gr.Dropdown(valid_subsections[text_db], label="Subsection of text", value="Function",
|
239 |
interactive=True, visible=False, scale=0)
|
240 |
|
241 |
+
db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=sequence_db,
|
242 |
interactive=True, visible=True, scale=0)
|
243 |
|
244 |
with gr.Row():
|
|
|
251 |
|
252 |
|
253 |
# If the index is of IVF type, provide an option to choose the number of clusters.
|
254 |
+
nprobe_visible = check_index_ivf(query_type.value, db_type.value)
|
255 |
nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
|
256 |
label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")
|
257 |
|