import gradio as gr import torch import pandas as pd import matplotlib.pyplot as plt import numpy as np from scipy.stats import norm from .init_model import model, all_index, valid_subsections from .blocks import upload_pdb_button, parse_pdb_file tmp_file_path = "/tmp/results.tsv" tmp_plot_path = "/tmp/histogram.svg" # Samples for input samples = [ ["Proteins with zinc bindings."], ["Proteins locating at cell membrane."], ["Protein that serves as an enzyme."] ] # Databases for different modalities now_db = { "sequence": list(all_index["sequence"].keys())[0], "structure": list(all_index["structure"].keys())[0], "text": list(all_index["text"].keys())[0] } def clear_results(): return "", gr.update(visible=False), gr.update(visible=False) def plot(scores) -> None: """ Plot the distribution of scores and fit a normal distribution. Args: scores: List of scores """ plt.hist(scores, bins=100, density=True, alpha=0.6) plt.title('Distribution of similarity scores in the database', fontsize=15) plt.xlabel('Similarity score', fontsize=15) plt.ylabel('Density', fontsize=15) mu, std = norm.fit(scores) # Plot the Gaussian xmin, xmax = plt.xlim() _, ymax = plt.ylim() x = np.linspace(xmin, xmax, 100) p = norm.pdf(x, mu, std) plt.plot(x, p) # Plot total number of scores plt.text(xmax, 0.9*ymax, f"Total number: {len(scores)}", ha='right', fontsize=12) # Convert the plot to svg format plt.savefig(tmp_plot_path) plt.cla() # Search from database def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str): input_modality = input_type.replace("sequence", "protein") with torch.no_grad(): input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy() db = now_db[query_type] if query_type == "text": index = all_index["text"][db][subsection_type]["index"] ids = all_index["text"][db][subsection_type]["ids"] else: index = all_index[query_type][db]["index"] ids = all_index[query_type][db]["ids"] if check_index_ivf(query_type, subsection_type): if index.nlist < nprobe: raise gr.Error(f"The number of clusters to search must be less than or equal to the number of clusters in the index ({index.nlist}).") else: index.nprobe = nprobe if topk > index.ntotal: raise gr.Error(f"You cannot retrieve more than the database size ({index.ntotal}).") # Retrieve all scores to plot the distribution scores, ranks = index.search(input_embedding, index.ntotal) scores, ranks = scores[0], ranks[0] # Remove inf values selector = scores > -1 scores = scores[selector] ranks = ranks[selector] scores = scores / model.temperature.item() plot(scores) top_scores = scores[:topk] top_ranks = ranks[:topk] # ranks = [list(range(topk))] # ids = ["P12345"] * topk # scores = torch.randn(topk).tolist() # Write the results to a temporary file for downloading with open(tmp_file_path, "w") as w: w.write("Id\tMatching score\n") for i in range(topk): rank = top_ranks[i] w.write(f"{ids[rank]}\t{top_scores[i]}\n") # Get topk ids topk_ids = [] for rank in top_ranks: now_id = ids[rank] if query_type == "text": topk_ids.append(now_id) else: if db != "PDB": # Provide link to uniprot website topk_ids.append(f"[{now_id}](https://www.uniprot.org/uniprotkb/{now_id})") else: # Provide link to pdb website pdb_id = now_id.split("-")[0] topk_ids.append(f"[{now_id}](https://www.rcsb.org/structure/{pdb_id})") limit = 1000 df = pd.DataFrame({"Id": topk_ids[:limit], "Matching score": top_scores[:limit]}) if len(topk_ids) > limit: info_df = pd.DataFrame({"Id": ["Download the file to check all results"], "Matching score": ["..."]}, index=[1000]) df = pd.concat([df, info_df], axis=0) output = df.to_markdown() return (output, gr.DownloadButton(label="Download results", value=tmp_file_path, visible=True, scale=0), gr.update(value=tmp_plot_path, visible=True)) def change_input_type(choice: str): # Change examples if input type is changed global samples if choice == "text": samples = [ ["Proteins with zinc bindings."], ["Proteins locating at cell membrane."], ["Protein that serves as an enzyme."] ] elif choice == "sequence": samples = [ ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"], ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"], ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"] ] elif choice == "structure": samples = [ ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"], ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"], ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"] ] # Set visibility of upload button if choice == "text": visible = False else: visible = True return gr.update(samples=samples), "", gr.update(visible=visible), gr.update(visible=visible) # Load example from dataset def load_example(example_id): return samples[example_id][0] # Change the visibility of subsection type def change_output_type(query_type: str, subsection_type: str): nprobe_visible = check_index_ivf(query_type, subsection_type) subsection_visible = True if query_type == "text" else False return ( gr.update(visible=subsection_visible), gr.update(visible=nprobe_visible), gr.update(choices=list(all_index[query_type].keys()), value=now_db[query_type]) ) def check_index_ivf(index_type: str, subsection_type: str = None) -> bool: """ Check if the index is of IVF type. Args: index_type: Type of index. subsection_type: If the "index_type" is "text", get the index based on the subsection type. Returns: Whether the index is of IVF type or not. """ db = now_db[index_type] if index_type == "sequence": index = all_index["sequence"][db]["index"] elif index_type == "structure": index = all_index["structure"][db]["index"] elif index_type == "text": index = all_index["text"][db][subsection_type]["index"] nprobe_visible = True if hasattr(index, "nprobe") else False return nprobe_visible def change_db_type(query_type: str, subsection_type: str, db_type: str): """ Change the database to search. Args: query_type: The output type. db_type: The database to search. """ now_db[query_type] = db_type if query_type == "text": subsection_update = gr.update(choices=list(valid_subsections[now_db["text"]]), value="Function") else: subsection_update = gr.update(visible=False) nprobe_visible = check_index_ivf(query_type, subsection_type) return subsection_update, gr.update(visible=nprobe_visible) # Build the searching block def build_search_module(): gr.Markdown(f"# Search from Swiss-Prot database (the whole UniProt database will be supported soon)") with gr.Row(equal_height=True): with gr.Column(): # Set input type input_type = gr.Radio(["sequence", "structure", "text"], label="Input type (e.g. 'text' means searching based on text descriptions)", value="text") with gr.Row(): # Set output type query_type = gr.Radio( ["sequence", "structure", "text"], label="Output type (e.g. 'sequence' means returning qualified sequences)", value="sequence", scale=2, ) # If the output type is "text", provide an option to choose the subsection of text subsection_type = gr.Dropdown(valid_subsections[now_db["text"]], label="Subsection of text", value="Function", interactive=True, visible=False, scale=0) db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=now_db["sequence"], interactive=True, visible=True, scale=0) with gr.Row(): # Input box input = gr.Text(label="Input") # Provide an upload button to upload a pdb file upload_btn, chain_box = upload_pdb_button(visible=False) upload_btn.upload(parse_pdb_file, inputs=[input_type, upload_btn, chain_box], outputs=[input]) # If the index is of IVF type, provide an option to choose the number of clusters. nprobe_visible = check_index_ivf(query_type.value) nprobe = gr.Slider(1, 1000000, 1000, step=1, visible=nprobe_visible, label="Number of clusters to search (lower value for faster search and higher value for more accurate search)") # Add event listener to output type query_type.change(fn=change_output_type, inputs=[query_type, subsection_type], outputs=[subsection_type, nprobe, db_type]) # Add event listener to db type db_type.change(fn=change_db_type, inputs=[query_type, subsection_type, db_type], outputs=[subsection_type, nprobe]) # Choose topk results topk = gr.Slider(1, 1000000, 5, step=1, label="Retrieve top k results") # Provide examples examples = gr.Dataset(samples=samples, components=[input], type="index", label="Input examples") # Add click event to examples examples.click(fn=load_example, inputs=[examples], outputs=input) # Change examples based on input type input_type.change(fn=change_input_type, inputs=[input_type], outputs=[examples, input, upload_btn, chain_box]) with gr.Row(): search_btn = gr.Button(value="Search") clear_btn = gr.Button(value="Clear") with gr.Row(): with gr.Column(): results = gr.Markdown(label="results", height=450) download_btn = gr.DownloadButton(label="Download results", visible=False) # Plot the distribution of scores histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False) search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type], outputs=[results, download_btn, histogram]) clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])