LTEnjoy commited on
Commit
f88d1cb
1 Parent(s): db29cff

Update demo/modules/search.py

Browse files
Files changed (1) hide show
  1. demo/modules/search.py +15 -20
demo/modules/search.py CHANGED
@@ -33,13 +33,6 @@ samples = {
33
  ],
34
  }
35
 
36
- # Databases for different modalities
37
- now_db = {
38
- "sequence": list(all_index["sequence"].keys())[0],
39
- "structure": list(all_index["structure"].keys())[0],
40
- "text": list(all_index["text"].keys())[0]
41
- }
42
-
43
 
44
  def clear_results():
45
  return "", gr.update(visible=False), gr.update(visible=False)
@@ -75,6 +68,8 @@ def plot(scores) -> None:
75
 
76
  # Search from database
77
  def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str, db: str):
 
 
78
  input_modality = input_type.replace("sequence", "protein")
79
  with torch.no_grad():
80
  input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()
@@ -169,17 +164,18 @@ def load_example(example_id):
169
 
170
  # Change the visibility of subsection type
171
  def change_output_type(query_type: str, subsection_type: str):
172
- nprobe_visible = check_index_ivf(query_type, subsection_type)
 
173
  subsection_visible = True if query_type == "text" else False
174
-
175
  return (
176
- gr.update(visible=subsection_visible),
177
  gr.update(visible=nprobe_visible),
178
- gr.update(choices=list(all_index[query_type].keys()), value=now_db[query_type])
179
  )
180
 
181
 
182
- def check_index_ivf(index_type: str, subsection_type: str = None) -> bool:
183
  """
184
  Check if the index is of IVF type.
185
  Args:
@@ -189,7 +185,6 @@ def check_index_ivf(index_type: str, subsection_type: str = None) -> bool:
189
  Returns:
190
  Whether the index is of IVF type or not.
191
  """
192
- db = now_db[index_type]
193
  if index_type == "sequence":
194
  index = all_index["sequence"][db]["index"]
195
 
@@ -211,14 +206,12 @@ def change_db_type(query_type: str, subsection_type: str, db_type: str):
211
  query_type: The output type.
212
  db_type: The database to search.
213
  """
214
- now_db[query_type] = db_type
215
-
216
  if query_type == "text":
217
- subsection_update = gr.update(choices=list(valid_subsections[now_db["text"]]), value="Function")
218
  else:
219
  subsection_update = gr.update(visible=False)
220
 
221
- nprobe_visible = check_index_ivf(query_type, subsection_type)
222
  return subsection_update, gr.update(visible=nprobe_visible)
223
 
224
 
@@ -240,10 +233,12 @@ def build_search_module():
240
  )
241
 
242
  # If the output type is "text", provide an option to choose the subsection of text
243
- subsection_type = gr.Dropdown(valid_subsections[now_db["text"]], label="Subsection of text", value="Function",
 
 
244
  interactive=True, visible=False, scale=0)
245
 
246
- db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=now_db["sequence"],
247
  interactive=True, visible=True, scale=0)
248
 
249
  with gr.Row():
@@ -256,7 +251,7 @@ def build_search_module():
256
 
257
 
258
  # If the index is of IVF type, provide an option to choose the number of clusters.
259
- nprobe_visible = check_index_ivf(query_type.value)
260
  nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
261
  label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")
262
 
 
33
  ],
34
  }
35
 
 
 
 
 
 
 
 
36
 
37
  def clear_results():
38
  return "", gr.update(visible=False), gr.update(visible=False)
 
68
 
69
  # Search from database
70
  def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str, db: str):
71
+ print(f"Input type: {input_type}\n Output type: {query_type}\nDatabase: {db}\nSubsection: {subsection_type}")
72
+
73
  input_modality = input_type.replace("sequence", "protein")
74
  with torch.no_grad():
75
  input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()
 
164
 
165
  # Change the visibility of subsection type
166
  def change_output_type(query_type: str, subsection_type: str):
167
+ db_type = list(all_index[query_type].keys())[0]
168
+ nprobe_visible = check_index_ivf(query_type, db_type, subsection_type)
169
  subsection_visible = True if query_type == "text" else False
170
+
171
  return (
172
+ gr.update(visible=subsection_visible),
173
  gr.update(visible=nprobe_visible),
174
+ gr.update(choices=list(all_index[query_type].keys()), value=db_type)
175
  )
176
 
177
 
178
+ def check_index_ivf(index_type: str, db: str, subsection_type: str = None) -> bool:
179
  """
180
  Check if the index is of IVF type.
181
  Args:
 
185
  Returns:
186
  Whether the index is of IVF type or not.
187
  """
 
188
  if index_type == "sequence":
189
  index = all_index["sequence"][db]["index"]
190
 
 
206
  query_type: The output type.
207
  db_type: The database to search.
208
  """
 
 
209
  if query_type == "text":
210
+ subsection_update = gr.update(choices=list(valid_subsections[db_type]), value="Function")
211
  else:
212
  subsection_update = gr.update(visible=False)
213
 
214
+ nprobe_visible = check_index_ivf(query_type, db_type, subsection_type)
215
  return subsection_update, gr.update(visible=nprobe_visible)
216
 
217
 
 
233
  )
234
 
235
  # If the output type is "text", provide an option to choose the subsection of text
236
+ text_db = list(all_index["text"].keys())[0]
237
+ sequence_db = list(all_index["sequence"].keys())[0]
238
+ subsection_type = gr.Dropdown(valid_subsections[text_db], label="Subsection of text", value="Function",
239
  interactive=True, visible=False, scale=0)
240
 
241
+ db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=sequence_db,
242
  interactive=True, visible=True, scale=0)
243
 
244
  with gr.Row():
 
251
 
252
 
253
  # If the index is of IVF type, provide an option to choose the number of clusters.
254
+ nprobe_visible = check_index_ivf(query_type.value, db_type.value)
255
  nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
256
  label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")
257