File size: 12,053 Bytes
bb62e32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
993f1a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb62e32
 
 
 
 
 
 
 
 
 
 
 
 
bdb7241
bb62e32
 
bdb7241
 
bb62e32
bdb7241
 
 
 
 
 
bb62e32
 
 
 
 
 
 
 
 
 
bdb7241
bb62e32
 
 
 
 
 
 
993f1a4
f88d1cb
 
bb62e32
 
 
 
 
 
 
 
 
 
 
 
5c9db02
bb62e32
 
 
 
63c4ae7
bb62e32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
993f1a4
bb62e32
 
 
 
993f1a4
bb62e32
 
 
 
f88d1cb
 
bb62e32
f88d1cb
bb62e32
f88d1cb
bb62e32
f88d1cb
bb62e32
 
 
f88d1cb
bb62e32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43977ca
 
 
bb62e32
 
 
 
 
 
 
 
 
 
f88d1cb
bb62e32
 
 
f88d1cb
bb62e32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f88d1cb
 
 
bb62e32
 
f88d1cb
bb62e32
 
 
 
 
 
 
42e763d
bb62e32
 
 
 
f88d1cb
bb62e32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db29cff
bb62e32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
993f1a4
bb62e32
 
52da96f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
import gradio as gr
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from scipy.stats import norm
from .init_model import model, all_index, valid_subsections
from .blocks import upload_pdb_button, parse_pdb_file


# Fixed output locations used by this module: `search` writes the result table to
# `tmp_file_path` and `plot` saves the score histogram to `tmp_plot_path`.
# NOTE(review): both paths are constant, so every call overwrites the previous files.
tmp_file_path = "/tmp/results.tsv"
tmp_plot_path = "/tmp/histogram.svg"

# Example inputs offered in the UI, keyed by input type ("sequence", "structure", "text").
# Each example is a one-element list; `load_example` returns element 0 of the clicked example.
samples = {
    # Amino-acid sequences (upper-case one-letter codes).
    "sequence": [
            ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
            ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
            ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
        ],

    # Lower-case structure strings (presumably per-residue structure tokens such as
    # Foldseek 3Di -- TODO confirm the alphabet with the model code).
    "structure": [
            ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
            ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
            ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
        ],

    # Free-text descriptions.
    "text": [
        ["Proteins with zinc bindings."],
        ["Proteins locating at cell membrane."],
        ["Protein that serves as an enzyme."]
    ],
}


def clear_results():
    """
    Reset the result area.
    Returns:
        A 3-tuple: an empty string followed by two updates that each set ``visible=False``.
        The caller wires these to the results text, the download button and the histogram.
    """
    hide_download = gr.update(visible=False)
    hide_histogram = gr.update(visible=False)
    return "", hide_download, hide_histogram


def plot(scores) -> None:
    """
    Plot the distribution of scores, overlay a fitted normal distribution and
    save the figure to ``tmp_plot_path``.

    The drawing is done on a dedicated figure that is always closed afterwards, so a
    failure halfway through (e.g. while fitting) cannot leave stale artists behind on a
    shared pyplot figure and contaminate the next call.

    Args:
        scores: List of scores
    """
    fig, ax = plt.subplots()
    try:
        ax.hist(scores, bins=100, density=True, alpha=0.6)
        ax.set_title('Distribution of similarity scores in the database', fontsize=15)
        ax.set_xlabel('Similarity score', fontsize=15)
        ax.set_ylabel('Density', fontsize=15)

        # Extend the y-axis slightly below zero to make room for the note
        y_ed = ax.get_ylim()[-1]
        ax.set_ylim(-0.05, y_ed)

        # Add note
        x_st = ax.get_xlim()[0]
        text = ("Note: For the \"UniRef50\" and \"Uncharacterized\" databases, the figure illustrates\n "
                "only top-ranked clusters (identified using Faiss), whereas for other databases, it\n "
                "displays the distribution across all samples.")
        ax.text(x_st, -0.04, text, fontsize=8)
        mu, std = norm.fit(scores)

        # Plot the Gaussian over the x-range currently shown by the histogram
        xmin, xmax = ax.get_xlim()
        _, ymax = ax.get_ylim()
        x = np.linspace(xmin, xmax, 100)
        p = norm.pdf(x, mu, std)
        ax.plot(x, p)

        # Plot total number of scores
        ax.text(xmax, 0.9*ymax, f"Total number: {len(scores)}", ha='right', fontsize=12)

        # The output format follows the extension of tmp_plot_path
        fig.savefig(tmp_plot_path)
    finally:
        # Release the figure whether or not plotting succeeded
        plt.close(fig)


# Search from database
def search(input: str, nprobe: int, topk: int, input_type: str, query_type: str, subsection_type: str, db: str):
    """
    Embed the query, search the selected index and build the outputs for the UI.

    Args:
        input: Query string; passed as a single-item batch to the model encoder.
        nprobe: Number of clusters to search. Only applied when the index has an ``nprobe`` attribute.
        topk: Number of top results to return.
        input_type: Type of the query. "sequence" is mapped to "protein" to pick the
            ``get_<modality>_repr`` method of the model.
        query_type: Type of the items to retrieve; selects the branch of ``all_index``.
        subsection_type: Text subsection; only used when ``query_type`` is "text".
        db: Name of the database to search.

    Returns:
        A tuple of (markdown table of the results, download button pointing at
        ``tmp_file_path``, update showing the histogram saved at ``tmp_plot_path``).

    Raises:
        gr.Error: If ``nprobe`` exceeds the number of clusters of the index, or ``topk``
            exceeds the number of entries in the index.
    """
    print(f"Input type: {input_type}\n Output type: {query_type}\nDatabase: {db}\nSubsection: {subsection_type}")

    # Defensive: make sure the slider values are plain ints before slicing / configuring the index
    nprobe = int(nprobe)
    topk = int(topk)

    input_modality = input_type.replace("sequence", "protein")
    with torch.no_grad():
        input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()

    # Text indices are additionally keyed by subsection
    if query_type == "text":
        entry = all_index["text"][db][subsection_type]
    else:
        entry = all_index[query_type][db]
    index = entry["index"]
    ids = entry["ids"]

    if hasattr(index, "nprobe"):
        if index.nlist < nprobe:
            raise gr.Error(f"The number of clusters to search must be less than or equal to the number of clusters in the index ({index.nlist}).")
        index.nprobe = nprobe

    if topk > index.ntotal:
        raise gr.Error(f"You cannot retrieve more than the database size ({index.ntotal}).")

    # Retrieve all scores to plot the distribution
    scores, ranks = index.search(input_embedding, index.ntotal)
    scores, ranks = scores[0], ranks[0]

    # Remove inf values (entries whose raw score is not above -1 are dropped)
    selector = scores > -1
    scores = scores[selector]
    ranks = ranks[selector]
    scores = scores / model.temperature.item()
    plot(scores)

    # The filter above may leave fewer than `topk` hits, so everything below works on
    # the hits that actually exist instead of assuming `topk` entries.
    top_scores = scores[:topk]
    top_ranks = ranks[:topk]
    top_ids = [ids[rank] for rank in top_ranks]

    # Write the results to a temporary file for downloading
    with open(tmp_file_path, "w") as w:
        w.write("Id\tMatching score\n")
        for now_id, score in zip(top_ids, top_scores):
            w.write(f"{now_id}\t{score}\n")

    # Get topk ids as they should be displayed
    topk_ids = []
    for now_id in top_ids:
        if query_type == "text":
            topk_ids.append(now_id)
        elif db != "PDB":
            # Provide link to uniprot website
            topk_ids.append(f"[{now_id}](https://www.uniprot.org/uniprotkb/{now_id})")
        else:
            # Provide link to pdb website; the part before the first "-" is used as the entry id
            pdb_id = now_id.split("-")[0]
            topk_ids.append(f"[{now_id}](https://www.rcsb.org/structure/{pdb_id})")

    # Only the first `limit` rows are rendered; the full list is in the downloadable file
    limit = 1000
    df = pd.DataFrame({"Id": topk_ids[:limit], "Matching score": top_scores[:limit]})
    if len(topk_ids) > limit:
        info_df = pd.DataFrame({"Id": ["Download the file to check all results"], "Matching score": ["..."]},
                               index=[limit])
        df = pd.concat([df, info_df], axis=0)

    output = df.to_markdown()
    return (output,
            gr.DownloadButton(label="Download results", value=tmp_file_path, visible=True, scale=0),
            gr.update(value=tmp_plot_path, visible=True))


def change_input_type(choice: str):
    """
    React to a change of the input type.
    Args:
        choice: The newly selected input type; must be a key of ``samples``.

    Returns:
        A 4-tuple: an update carrying the examples for ``choice``, an empty string,
        and two visibility updates that are hidden only when ``choice`` is "text".
    """
    # Both visibility updates share the same flag
    show_upload = choice != "text"

    examples_update = gr.update(samples=samples[choice])
    upload_update = gr.update(visible=show_upload)
    chain_update = gr.update(visible=show_upload)
    return examples_update, "", upload_update, chain_update


# Load example from dataset
def load_example(example_id):
    """
    Return the first element of the selected example.
    Args:
        example_id: The selected example (an indexable object).
    """
    first_item = example_id[0]
    return first_item
 
 
# Change the visibility of subsection type
def change_output_type(query_type: str, subsection_type: str):
    """
    React to a change of the output type.
    Args:
        query_type: The newly selected output type; must be a key of ``all_index``.
        subsection_type: The currently selected text subsection.

    Returns:
        A 3-tuple of updates: visibility of the subsection selector (shown only for
        "text"), visibility of the nprobe control, and the database choices for
        ``query_type`` with the first database selected.
    """
    db_choices = list(all_index[query_type].keys())
    default_db = db_choices[0]

    show_nprobe = check_index_ivf(query_type, default_db, subsection_type)
    show_subsection = query_type == "text"

    subsection_update = gr.update(visible=show_subsection)
    nprobe_update = gr.update(visible=show_nprobe)
    db_update = gr.update(choices=db_choices, value=default_db)
    return subsection_update, nprobe_update, db_update


def check_index_ivf(index_type: str, db: str, subsection_type: str = None) -> bool:
    """
    Report whether the selected index is of IVF type.

    The detection is currently switched off: the index is looked up (so unknown
    keys still raise) but the function always returns False.

    Args:
        index_type: Type of index ("sequence", "structure" or "text").
        db: Name of the database the index belongs to.
        subsection_type: If the "index_type" is "text", get the index based on the subsection type.

    Returns:
        Always False at the moment.
    """
    if index_type in ("sequence", "structure"):
        index = all_index[index_type][db]["index"]
    elif index_type == "text":
        index = all_index["text"][db][subsection_type]["index"]

    # Disabled check, kept for reference: hasattr(index, "nprobe")
    return False


def change_db_type(query_type: str, subsection_type: str, db_type: str):
    """
    React to a change of the database to search.
    Args:
        query_type: The output type.
        subsection_type: The currently selected text subsection.
        db_type: The database to search.

    Returns:
        A 2-tuple of updates: one for the subsection selector (new choices with
        "Function" selected for text output, hidden otherwise) and one for the
        visibility of the nprobe control.
    """
    if query_type != "text":
        subsection_update = gr.update(visible=False)
    else:
        subsection_choices = list(valid_subsections[db_type])
        subsection_update = gr.update(choices=subsection_choices, value="Function")

    show_nprobe = check_index_ivf(query_type, db_type, subsection_type)
    nprobe_update = gr.update(visible=show_nprobe)
    return subsection_update, nprobe_update


# Build the searching block
def build_search_module() -> None:
    """
    Create the Gradio components of the search panel and wire their events.

    Left column: input/output type selectors, database and subsection dropdowns,
    the input box with a PDB upload button, the nprobe and top-k sliders, input
    examples and the Search/Clear buttons. Right side: the results table, a
    download button and the score histogram.

    NOTE(review): presumably called inside an enclosing ``gr.Blocks`` context --
    confirm at the call site. The order of the statements defines the layout.
    """
    gr.Markdown(f"# Search from database")
    with gr.Row(equal_height=True):
        with gr.Column():
            # Set input type
            input_type = gr.Radio(["sequence", "structure", "text"], label="Input type (e.g. 'text' means searching based on text descriptions)", value="text")

            with gr.Row():
                # Set output type
                query_type = gr.Radio(
                    ["sequence", "structure", "text"],
                    label="Output type (e.g. 'sequence' means returning qualified sequences)",
                    value="sequence",
                    scale=2,
                )
            
                # If the output type is "text", provide an option to choose the subsection of text
                text_db = list(all_index["text"].keys())[0]
                sequence_db = list(all_index["sequence"].keys())[0]
                subsection_type = gr.Dropdown(valid_subsections[text_db], label="Subsection of text", value="Function",
                                              interactive=True, visible=False, scale=0)
                
                # The default output type is "sequence", so start with the sequence databases
                db_type = gr.Dropdown(all_index["sequence"].keys(), label="Database", value=sequence_db,
                                              interactive=True, visible=True, scale=0)

            with gr.Row():
                # Input box
                # NOTE: this local name shadows the builtin input() inside this function
                input = gr.Text(label="Input")
                
                # Provide an upload button to upload a pdb file
                # (hidden initially because the default input type is "text")
                upload_btn, chain_box = upload_pdb_button(visible=False, chain_visible=False)
                upload_btn.upload(parse_pdb_file, inputs=[input_type, upload_btn, chain_box], outputs=[input])
            
            
            # If the index is of IVF type, provide an option to choose the number of clusters.
            nprobe_visible = check_index_ivf(query_type.value, db_type.value)
            nprobe = gr.Slider(1, 1000000, 1000,  step=1, visible=nprobe_visible,
                               label="Number of clusters to search (lower value for faster search and higher value for more accurate search)")
            
            # Add event listener to output type
            query_type.change(fn=change_output_type, inputs=[query_type, subsection_type],
                              outputs=[subsection_type, nprobe, db_type])
            
            # Add event listener to db type
            db_type.change(fn=change_db_type, inputs=[query_type, subsection_type, db_type],
                           outputs=[subsection_type, nprobe])
            
            # Choose topk results
            topk = gr.Slider(1, 1000000, 5,  step=1, label="Retrieve top k results")

            # Provide examples (text examples first, matching the default input type)
            examples = gr.Dataset(samples=samples["text"], components=[input], label="Input examples")
            
            # Add click event to examples
            examples.click(fn=load_example, inputs=[examples], outputs=input)
            
            # Change examples based on input type
            input_type.change(fn=change_input_type, inputs=[input_type], outputs=[examples, input, upload_btn, chain_box])
            
            with gr.Row():
                search_btn = gr.Button(value="Search")
                clear_btn = gr.Button(value="Clear")
        
        with gr.Row():
            with gr.Column():
                results = gr.Markdown(label="results", height=450)
                # Hidden until a search has produced a file to download
                download_btn = gr.DownloadButton(label="Download results", visible=False)
            
                # Plot the distribution of scores
                histogram = gr.Image(label="Histogram of matching scores", type="filepath", scale=1, visible=False)
            
        # Run the search and fill the results table, download button and histogram
        search_btn.click(fn=search, inputs=[input, nprobe, topk, input_type, query_type, subsection_type, db_type],
                      outputs=[results, download_btn, histogram])
        
        # Reset the three output components
        clear_btn.click(fn=clear_results, outputs=[results, download_btn, histogram])