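"""
Search module of the Gradio demo: embed the user's query (sequence, structure or text) with the
model and retrieve the most similar entries from the pre-built indexes in `all_index`.
"""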
import gradio as gr
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from scipy.stats import norm
from .init_model import model, all_index, valid_subsections
from .blocks import upload_pdb_button, parse_pdb_file


# Temporary files for the downloadable results and the score histogram
tmp_file_path = "/tmp/results.tsv"
tmp_plot_path = "/tmp/histogram.svg"


samples = [
    ["Proteins that bind zinc."],
    ["Proteins located at the cell membrane."],
    ["Proteins that serve as enzymes."]
]


# Currently selected database for each output type
now_db = {
    "sequence": list(all_index["sequence"].keys())[0],
    "structure": list(all_index["structure"].keys())[0],
    "text": list(all_index["text"].keys())[0]
}


def clear_results():
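    """Clear the results panel and hide the download button and the histogram."""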
    return "", gr.update(visible=False), gr.update(visible=False)


def plot(scores) -> None:
    """
    Plot the distribution of similarity scores, fit a normal distribution and save the figure.

    Args:
        scores: List of similarity scores
    """
    plt.hist(scores, bins=100, density=True, alpha=0.6)
    plt.title('Distribution of similarity scores in the database', fontsize=15)
    plt.xlabel('Similarity score', fontsize=15)
    plt.ylabel('Density', fontsize=15)

    # Fit a normal distribution and overlay its probability density
    mu, std = norm.fit(scores)

    xmin, xmax = plt.xlim()
    _, ymax = plt.ylim()
    x = np.linspace(xmin, xmax, 100)
    p = norm.pdf(x, mu, std)
    plt.plot(x, p)

    plt.text(xmax, 0.9*ymax, f"Total number: {len(scores)}", ha='right', fontsize=12)

    plt.savefig(tmp_plot_path)
    plt.cla()


def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str):
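    """
    Embed the query and search the selected database.

    Args:
        input: Query (a protein sequence, a structure string or a text description).
        nprobe: Number of clusters to search (only used for IVF indexes).
        topk: Number of results to return.
        input_type: Modality of the query ("sequence", "structure" or "text").
        query_type: Modality of the returned results ("sequence", "structure" or "text").
        subsection_type: Text subsection to search (only used when query_type is "text").

    Returns:
        A markdown table of the top results, an update for the download button and an update for the histogram.
    """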
    # Embed the query with the encoder that matches its modality
    input_modality = input_type.replace("sequence", "protein")
    with torch.no_grad():
        input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()

    # Select the index and the id list of the target database
    db = now_db[query_type]
    if query_type == "text":
        index = all_index["text"][db][subsection_type]["index"]
        ids = all_index["text"][db][subsection_type]["ids"]

    else:
        index = all_index[query_type][db]["index"]
        ids = all_index[query_type][db]["ids"]

    if check_index_ivf(query_type, subsection_type):
        if index.nlist < nprobe:
            raise gr.Error(f"The number of clusters to search must be less than or equal to the number of clusters in the index ({index.nlist}).")
        else:
            index.nprobe = nprobe

    if topk > index.ntotal:
        raise gr.Error(f"You cannot retrieve more results than the database size ({index.ntotal}).")

    # Retrieve all entries so that the score distribution can be plotted
    scores, ranks = index.search(input_embedding, index.ntotal)
    scores, ranks = scores[0], ranks[0]

    # Filter out placeholder entries returned by the index (score <= -1)
    selector = scores > -1
    scores = scores[selector]
    ranks = ranks[selector]

    # Rescale similarities by the model temperature and plot their distribution
    scores = scores / model.temperature.item()
    plot(scores)

    top_scores = scores[:topk]
    top_ranks = ranks[:topk]

    # Write the top k results to a TSV file for download
    with open(tmp_file_path, "w") as w:
        w.write("Id\tMatching score\n")
        for i in range(topk):
            rank = top_ranks[i]
            w.write(f"{ids[rank]}\t{top_scores[i]}\n")

    # Link each hit to UniProt or RCSB PDB; text results are kept as plain ids
    topk_ids = []
    for rank in top_ranks:
        now_id = ids[rank]
        if query_type == "text":
            topk_ids.append(now_id)
        else:
            if db != "PDB":
                topk_ids.append(f"[{now_id}](https://www.uniprot.org/uniprotkb/{now_id})")
            else:
                pdb_id = now_id.split("-")[0]
                topk_ids.append(f"[{now_id}](https://www.rcsb.org/structure/{pdb_id})")

    # Only show the first 1000 results in the markdown table
    limit = 1000
    df = pd.DataFrame({"Id": topk_ids[:limit], "Matching score": top_scores[:limit]})
    if len(topk_ids) > limit:
        info_df = pd.DataFrame({"Id": ["Download the file to check all results"], "Matching score": ["..."]},
                               index=[limit])
        df = pd.concat([df, info_df], axis=0)

    output = df.to_markdown()
    return (output,
            gr.DownloadButton(label="Download results", value=tmp_file_path, visible=True, scale=0),
            gr.update(value=tmp_plot_path, visible=True))


def change_input_type(choice: str):
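    """
    Update the examples, clear the input box and toggle the PDB upload widgets.

    Args:
        choice: The selected input type ("sequence", "structure" or "text").
    """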
    global samples
    if choice == "text":
        samples = [
            ["Proteins that bind zinc."],
            ["Proteins located at the cell membrane."],
            ["Proteins that serve as enzymes."]
        ]

    elif choice == "sequence":
        samples = [
            ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
            ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
            ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
        ]

    elif choice == "structure":
        samples = [
            ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
            ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
            ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
        ]

    # The PDB upload button and chain box are only shown for non-text input
    visible = choice != "text"

    return gr.update(samples=samples), "", gr.update(visible=visible), gr.update(visible=visible)


def load_example(example_id):
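    """Fill the input box with the selected example."""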
    return samples[example_id][0]


def change_output_type(query_type: str, subsection_type: str):
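    """Toggle the subsection and nprobe controls and refresh the database choices when the output type changes."""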
    nprobe_visible = check_index_ivf(query_type, subsection_type)
    subsection_visible = query_type == "text"

    return (
        gr.update(visible=subsection_visible),
        gr.update(visible=nprobe_visible),
        gr.update(choices=list(all_index[query_type].keys()), value=now_db[query_type])
    )


def check_index_ivf(index_type: str, subsection_type: str = None) -> bool:
    """
    Check whether the index is an IVF index (i.e. has the "nprobe" attribute).

    Args:
        index_type: Type of index ("sequence", "structure" or "text").
        subsection_type: If index_type is "text", the index is selected by this subsection type.

    Returns:
        Whether the index is an IVF index or not.
    """
    db = now_db[index_type]
    if index_type == "sequence":
        index = all_index["sequence"][db]["index"]

    elif index_type == "structure":
        index = all_index["structure"][db]["index"]

    elif index_type == "text":
        index = all_index["text"][db][subsection_type]["index"]

    return hasattr(index, "nprobe")


def change_db_type(query_type: str, subsection_type: str, db_type: str):
    """
    Change the database to search.

    Args:
        query_type: The output type.
        subsection_type: The currently selected text subsection.
        db_type: The database to search.
    """
    now_db[query_type] = db_type

    if query_type == "text":
        subsection_update = gr.update(choices=list(valid_subsections[now_db["text"]]), value="Function")
    else:
        subsection_update = gr.update(visible=False)

    nprobe_visible = check_index_ivf(query_type, subsection_type)
    return subsection_update, gr.update(visible=nprobe_visible)


def build_search_module():
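    """Build the Gradio interface of the search tab."""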
    gr.Markdown("# Search the Swiss-Prot database (the whole UniProt database will be supported soon)")
    with gr.Row(equal_height=True):
        with gr.Column():
            # Input type, e.g. searching based on text descriptions
            input_type = gr.Radio(["sequence", "structure", "text"],
                                  label="Input type (e.g. 'text' means searching based on text descriptions)",
                                  value="text")

            with gr.Row():
                # Output type, e.g. returning qualified sequences
                query_type = gr.Radio(
                    ["sequence", "structure", "text"],
                    label="Output type (e.g. 'sequence' means returning qualified sequences)",
                    value="sequence",
                    scale=2,
                )

                subsection_type = gr.Dropdown(valid_subsections[now_db["text"]], label="Subsection of text",
                                              value="Function", interactive=True, visible=False, scale=0)

                db_type = gr.Dropdown(list(all_index["sequence"].keys()), label="Database",
                                      value=now_db["sequence"], interactive=True, visible=True, scale=0)

            with gr.Row():
                input = gr.Text(label="Input")

                # Upload a PDB file to fill the input box (only for sequence/structure input)
                upload_btn, chain_box = upload_pdb_button(visible=False)
                upload_btn.upload(parse_pdb_file, inputs=[input_type, upload_btn, chain_box], outputs=[input])

            # The nprobe slider is only shown for IVF indexes
            nprobe_visible = check_index_ivf(query_type.value)
            nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible,
                               label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")

            query_type.change(fn=change_output_type, inputs=[query_type, subsection_type],
                              outputs=[subsection_type, nprobe, db_type])

            db_type.change(fn=change_db_type, inputs=[query_type, subsection_type, db_type],
                           outputs=[subsection_type, nprobe])

            topk = gr.Slider(1, 1000000, 5, step=1, label="Retrieve top k results")

            # Clickable input examples
            examples = gr.Dataset(samples=samples, components=[input], type="index", label="Input examples")
            examples.click(fn=load_example, inputs=[examples], outputs=input)

            input_type.change(fn=change_input_type, inputs=[input_type],
                              outputs=[examples, input, upload_btn, chain_box])

            with gr.Row():
                search_btn = gr.Button(value="Search")
                clear_btn = gr.Button(value="Clear")

        with gr.Row():
            with gr.Column():
                results = gr.Markdown(label="results", height=450)
                download_btn = gr.DownloadButton(label="Download results", visible=False)

            histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)

    search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type],
                     outputs=[results, download_btn, histogram])

    clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])