LTEnjoy committed on
Commit
993f1a4
·
verified ·
1 Parent(s): a1d996c

Update demo/modules/search.py

Browse files
Files changed (1) hide show
  1. demo/modules/search.py +23 -30
demo/modules/search.py CHANGED
@@ -13,11 +13,25 @@ tmp_file_path = "/tmp/results.tsv"
13
  tmp_plot_path = "/tmp/histogram.svg"
14
 
15
  # Samples for input
16
- samples = [
17
- ["Proteins with zinc bindings."],
18
- ["Proteins locating at cell membrane."],
19
- ["Protein that serves as an enzyme."]
20
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # Databases for different modalities
23
  now_db = {
@@ -60,12 +74,11 @@ def plot(scores) -> None:
60
 
61
 
62
  # Search from database
63
- def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str):
64
  input_modality = input_type.replace("sequence", "protein")
65
  with torch.no_grad():
66
  input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()
67
 
68
- db = now_db[query_type]
69
  if query_type == "text":
70
  index = all_index["text"][db][subsection_type]["index"]
71
  ids = all_index["text"][db][subsection_type]["ids"]
@@ -139,26 +152,6 @@ def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str,
139
  def change_input_type(choice: str):
140
  # Change examples if input type is changed
141
  global samples
142
- if choice == "text":
143
- samples = [
144
- ["Proteins with zinc bindings."],
145
- ["Proteins locating at cell membrane."],
146
- ["Protein that serves as an enzyme."]
147
- ]
148
-
149
- elif choice == "sequence":
150
- samples = [
151
- ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
152
- ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
153
- ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
154
- ]
155
-
156
- elif choice == "structure":
157
- samples = [
158
- ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
159
- ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
160
- ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
161
- ]
162
 
163
  # Set visibility of upload button
164
  if choice == "text":
@@ -166,12 +159,12 @@ def change_input_type(choice: str):
166
  else:
167
  visible = True
168
 
169
- return gr.update(samples=samples), "", gr.update(visible=visible), gr.update(visible=visible)
170
 
171
 
172
  # Load example from dataset
173
  def load_example(example_id):
174
- return samples[example_id][0]
175
 
176
 
177
  # Change the visibility of subsection type
@@ -299,7 +292,7 @@ def build_search_module():
299
  # Plot the distribution of scores
300
  histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)
301
 
302
- search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type],
303
  outputs=[results, download_btn, histogram])
304
 
305
  clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])
 
13
  tmp_plot_path = "/tmp/histogram.svg"
14
 
15
  # Samples for input
16
+ samples = {
17
+ "sequence": [
18
+ ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
19
+ ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
20
+ ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
21
+ ],
22
+
23
+ "structure": [
24
+ ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
25
+ ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
26
+ ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
27
+ ],
28
+
29
+ "text": [
30
+ ["Proteins with zinc bindings."],
31
+ ["Proteins locating at cell membrane."],
32
+ ["Protein that serves as an enzyme."]
33
+ ],
34
+ }
35
 
36
  # Databases for different modalities
37
  now_db = {
 
74
 
75
 
76
  # Search from database
77
+ def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str, db: str):
78
  input_modality = input_type.replace("sequence", "protein")
79
  with torch.no_grad():
80
  input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()
81
 
 
82
  if query_type == "text":
83
  index = all_index["text"][db][subsection_type]["index"]
84
  ids = all_index["text"][db][subsection_type]["ids"]
 
152
  def change_input_type(choice: str):
153
  # Change examples if input type is changed
154
  global samples
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  # Set visibility of upload button
157
  if choice == "text":
 
159
  else:
160
  visible = True
161
 
162
+ return gr.update(samples=samples[choice]), "", gr.update(visible=visible), gr.update(visible=visible)
163
 
164
 
165
  # Load example from dataset
166
  def load_example(example_id):
167
+ return example_id[0]
168
 
169
 
170
  # Change the visibility of subsection type
 
292
  # Plot the distribution of scores
293
  histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)
294
 
295
+ search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type, db_type],
296
  outputs=[results, download_btn, histogram])
297
 
298
  clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])