File size: 6,975 Bytes
3ad8be1
 
 
 
 
 
 
 
 
 
 
b412584
3ad8be1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b412584
 
 
ca33195
 
b412584
 
 
ca33195
 
b412584
52e74e1
3ad8be1
 
aad7754
 
 
72b4c06
aad7754
 
865e35e
aad7754
 
 
18379ef
 
aad7754
3ad8be1
 
 
 
 
 
aad7754
 
 
 
 
 
 
 
 
b412584
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import gradio as gr
import argparse
import os
from UltraFlow.models.sbap import *
from UltraFlow import commons
import warnings
warnings.filterwarnings("ignore")


# Directory that holds the demo's config file (read by get_config) and the
# checkpoint file (read by get_models).
model_dir = './workdir/gradio/'
# File name of the checkpoint, joined onto model_dir by get_models().
checkpoint = 'checkpointbest_valid_1.ckp'
# Running count of calls to test() since the process started; test() updates it
# through a `global` statement.
total_num = 0

def get_config(model_dir):
    """Load the demo configuration stored in ``model_dir``.

    Reads ``affinity_default.yaml`` from ``model_dir`` via
    ``commons.get_config_easydict``, pins the device to CPU and passes the
    configured seed to ``commons.set_seed``.

    Args:
        model_dir: Directory that contains ``affinity_default.yaml``.

    Returns:
        The config object produced by ``commons.get_config_easydict`` with
        its ``device`` attribute set to ``'cpu'``.
    """
    yaml_path = os.path.join(model_dir, 'affinity_default.yaml')
    config = commons.get_config_easydict(yaml_path)

    # The demo always runs on CPU; no GPU selection is attempted here.
    config.device = 'cpu'

    commons.set_seed(config.seed)
    return config

def load_graph_dim(lig_graph, prot_graph, model_config):
    """Work out the feature widths of one ligand/protein graph pair.

    Widths are read from the second dimension (``shape[1]``) of the ``'h'``
    node features and ``'e'`` edge features stored on each graph.

    Args:
        lig_graph: Ligand graph exposing ``ndata['h']``, ``edata['e']`` and,
            when bond features are enabled, ``edata['bond_type']``.
        prot_graph: Protein graph exposing ``ndata['h']`` and ``edata['e']``.
        model_config: Config whose ``data.add_chemical_bond_feats`` and
            ``data.use_mean_node_features`` flags select the optional extras.

    Returns:
        Tuple ``(lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim,
        inter_edge_dim)``.
    """
    data_cfg = model_config.data

    # Ligand edges optionally gain the bond-type columns.
    lig_edge_dim = lig_graph.edata['e'].shape[1]
    if data_cfg.add_chemical_bond_feats:
        lig_edge_dim += lig_graph.edata['bond_type'].shape[1]

    # Both node feature sets grow by 5 columns when mean node features are on
    # (presumably the width of the 'mu_r_norm' feature appended in get_data
    # -- TODO confirm).
    node_extra = 5 if data_cfg.use_mean_node_features else 0
    lig_node_dim = lig_graph.ndata['h'].shape[1] + node_extra
    pro_node_dim = prot_graph.ndata['h'].shape[1] + node_extra

    pro_edge_dim = prot_graph.edata['e'].shape[1]

    # The interaction-edge width is a fixed constant, not read from any graph.
    inter_edge_dim = 15

    return lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim, inter_edge_dim


def trans_device(data, device):
    """Return a new list with every non-list entry of ``data`` moved to ``device``.

    Entries that are ``list`` instances are passed through unchanged (same
    object); every other entry is replaced by the result of its
    ``.to(device)`` call.

    Args:
        data: Iterable of items; non-list items must provide ``.to(device)``.
        device: Target device handed to each ``.to`` call.

    Returns:
        A list with one entry per input item, in the same order.
    """
    moved = []
    for entry in data:
        if not isinstance(entry, list):
            entry = entry.to(device)
        moved.append(entry)
    return moved

def get_data(model_config, ligand_path, protein_path):
    """Build the model input for one ligand/protein file pair.

    Reads both structures with ``commons.read_molecules_inference``, builds
    the ligand, protein and interaction graphs, and packs them together with
    placeholder label/assay entries.

    Side effect: the feature widths returned by ``load_graph_dim`` are written
    to ``model_config.model`` (``lig_node_dim``, ``lig_edge_dim``,
    ``pro_node_dim``, ``pro_edge_dim``, ``inter_edge_dim``).

    Args:
        model_config: Config providing the ``data`` graph-building options,
            the ``model`` section that receives the dims, and ``device``.
        ligand_path: Path of the ligand structure file.
        protein_path: Path of the protein structure file.

    Returns:
        The list produced by ``trans_device`` for the 8-item sample
        ``(lig_graph, prot_graph, inter_graph, label, item, assay_des,
        IC50_f, K_f)``: list entries are kept as is, everything else is moved
        to ``model_config.device``.
    """
    data_cfg = model_config.data

    # The reader returns one flat 14-item tuple; unpack it in a single step.
    (lig_coords, lig_features, lig_edges, lig_node_type,
     prot_coords, prot_features, prot_edges, prot_node_type,
     sec_features, alpha_c_coords, c_coords, n_coords,
     ca_res_number_valid, chain_index_valid) = commons.read_molecules_inference(
        ligand_path, protein_path,
        data_cfg.prot_graph_type,
        data_cfg.chaincut)

    lig_graph = commons.get_lig_graph_equibind(
        lig_coords, lig_features, lig_edges, lig_node_type,
        max_neighbors=data_cfg.lig_max_neighbors,
        cutoff=data_cfg.ligcut)

    prot_graph = commons.get_prot_alpha_c_graph_equibind(
        prot_coords, prot_features, prot_node_type,
        sec_features, alpha_c_coords, c_coords, n_coords,
        max_neighbor=data_cfg.prot_max_neighbors,
        cutoff=data_cfg.protcut)
    # Attach residue numbers and chain indices to the protein graph.
    prot_graph.ndata['res_number'] = torch.tensor(ca_res_number_valid)
    prot_graph.chain_index = chain_index_valid

    inter_graph = commons.get_interact_graph_knn_v2(
        lig_coords, prot_coords,
        max_neighbor=data_cfg.inter_max_neighbors,
        min_neighbor=data_cfg.inter_min_neighbors,
        cutoff=data_cfg.intercut)

    # Record the feature widths on the config. This has to happen before the
    # concatenations below, because load_graph_dim adds the optional extras
    # itself from the still-separate feature tensors.
    model_cfg = model_config.model
    (model_cfg.lig_node_dim, model_cfg.lig_edge_dim,
     model_cfg.pro_node_dim, model_cfg.pro_edge_dim,
     model_cfg.inter_edge_dim) = load_graph_dim(lig_graph, prot_graph, model_config)

    if data_cfg.add_chemical_bond_feats:
        lig_graph.edata['e'] = torch.cat(
            [lig_graph.edata['e'], lig_graph.edata['bond_type']], dim=-1)

    if data_cfg.use_mean_node_features:
        for graph in (lig_graph, prot_graph):
            graph.ndata['h'] = torch.cat(
                [graph.ndata['h'], graph.ndata['mu_r_norm']], dim=-1)

    # Fixed filler entries for the remaining sample slots (presumably dummy
    # values, since no ground truth exists at inference -- TODO confirm).
    label = torch.tensor(-100).unsqueeze(dim=-1)
    assay_des = torch.zeros(0).unsqueeze(dim=0)
    sample = (lig_graph, prot_graph, inter_graph, label, [0], assay_des, [True], [True])

    return trans_device(sample, model_config.device)

def get_models(model_config, model_dir, checkpoint):
    """Instantiate the configured model and load its checkpoint weights.

    The model class is looked up by name in this module's globals:
    ``model_config.model.model_type``, with ``'_MTL'`` appended when
    ``model_config.train.multi_task`` is truthy.

    Args:
        model_config: Config passed to the model constructor; also supplies
            ``device`` for placement and for ``map_location``.
        model_dir: Directory containing the checkpoint file.
        checkpoint: Checkpoint file name inside ``model_dir``.

    Returns:
        The model on ``model_config.device`` with ``state["model"]`` loaded,
        switched to eval mode.

    Raises:
        KeyError: If no global with the resolved class name exists.
    """
    suffix = '_MTL' if model_config.train.multi_task else ''
    model_cls = globals()[model_config.model.model_type + suffix]
    model = model_cls(model_config).to(model_config.device)

    checkpoint_path = os.path.join(model_dir, checkpoint)
    print("Load checkpoint from %s" % checkpoint_path)

    # map_location lets the checkpoint load onto the configured device.
    state = torch.load(checkpoint_path, map_location=model_config.device)
    model.load_state_dict(state["model"])

    return model.eval()

def mbp_scoring(ligand_path, protein_path):
    """Predict the IC50 and K affinity values for one ligand/protein pair.

    Uses the module-level ``model_config`` and ``model``.

    Args:
        ligand_path: Path of the ligand structure file.
        protein_path: Path of the protein structure file.

    Returns:
        Tuple ``(IC50, K)`` of Python scalars taken from the model output
        with ``.item()``.
    """
    data_example = get_data(model_config, ligand_path, protein_path)

    # Pure inference: disable autograd so no graph is recorded for the forward
    # pass. The predicted values are unchanged; only the bookkeeping and its
    # memory cost go away.
    with torch.no_grad():
        _, (affinity_pred_IC50, affinity_pred_K), _ = model(data_example, ASRP=False)

    return affinity_pred_IC50.item(), affinity_pred_K.item()


def test(ligand, protein):
    """Gradio callback: score one uploaded ligand/protein pair.

    Increments the module-level request counter ``total_num`` on every call.

    Args:
        ligand: Uploaded ligand file from the ``gr.File`` input: an object
            exposing the file path as ``.name`` or a plain path string.
            ``None`` when nothing was uploaded.
        protein: Uploaded protein file, same conventions as ``ligand``.

    Returns:
        A pair of strings for the two output textboxes. On success these are
        the predicted IC50 and K values formatted with two decimals. On
        failure both entries hold the same human-readable error message.
    """
    global total_num
    total_num = total_num + 1
    print(f'total num: {total_num}')

    # Report a missing upload clearly instead of surfacing the opaque
    # "'NoneType' object has no attribute 'name'" error.
    if ligand is None or protein is None:
        message = 'Please upload both a ligand file and a protein file.'
        return message, message

    # Accept file wrappers exposing `.name` as well as plain path strings.
    ligand_path = getattr(ligand, 'name', ligand)
    protein_path = getattr(protein, 'name', protein)

    try:
        IC50, K = mbp_scoring(ligand_path, protein_path)
    except Exception as e:
        # UI boundary: never let a scoring failure crash the request. Log it
        # and hand the textboxes a string rather than the exception object.
        message = f'{type(e).__name__}: {e}'
        print(f'scoring failed: {message}')
        return message, message

    print(f'ligand file name: {os.path.basename(ligand_path)},'
          f' protein file name: {os.path.basename(protein_path)},'
          f' IC50: {IC50}, K: {K}')
    return '{:.2f}'.format(IC50), '{:.2f}'.format(K)

# Build the demo page: an intro text, two file uploads, two result textboxes,
# a submit button wired to test(), and one bundled example input pair.
with gr.Blocks() as demo:
    # Introductory text shown on the page.
    gr.Markdown(
    """
    # Multi-task Bioassay Pre-training for Protein-Ligand Binding Affinity Prediction
    ## Welcome to the MBP demo !
    - Feel free to upload your own examples. Please upload an individual ligand 3D file and an individual protein 3D file each time.
    - If you encounter any issues, please reach out to [email protected].
    - All codes and data are available on the online platform https://github.com/jiaxianyan/MBP. 
    """)

    # Both upload widgets are created inside one row.
    with gr.Row():
        ligand = gr.File(label="Ligand 3D file. MBP utilizes openbabel to process ligand files and supports all file types that openbabel can read.")
        protein = gr.File(label="Protein 3D file. Currently, MBP only supports the pdb file type for protein files.")

    # Textboxes that receive the two values returned by test().
    IC50 = gr.Textbox(label="Predicted IC50 Value")
    K = gr.Textbox(label="Predicted K Value")

    # Clicking the button calls test(ligand, protein) and writes its two return
    # values to the textboxes; the event is registered under the API name
    # "MBP_Scoring".
    submit_btn = gr.Button("Submit")
    submit_btn.click(fn=test, inputs=[ligand, protein], outputs=[IC50, K], api_name="MBP_Scoring")

    # One bundled ligand/protein pair offered as a ready-made input.
    gr.Markdown("## Input Examples")
    gr.Examples(
        examples=[['./workdir/gradio/1a0q_ligand.sdf','./workdir/gradio/1a0q_protein.pdb']],
        inputs=[ligand, protein],
        # Outputs are deliberately not wired here and caching is off.
        # NOTE(review): presumably selecting the example then only fills the
        # two inputs and the user still has to press Submit -- confirm against
        # the Gradio Examples documentation.
        # outputs=[IC50, K],
        fn=test,
        cache_examples=False,
    )

# Module-level state read by mbp_scoring(): the config and the loaded model.
model_config = get_config(model_dir)
# Build the bundled example once before creating the model: get_data() writes
# the feature widths (lig/pro node and edge dims, inter_edge_dim) onto
# model_config.model as a side effect, and get_models() below receives that
# same config object.
# NOTE(review): presumably the model constructor reads those dims, which would
# make this call order mandatory -- confirm against the model class.
data_example = get_data(model_config, './workdir/gradio/1a0q_ligand.sdf', './workdir/gradio/1a0q_protein.pdb')
model = get_models(model_config, model_dir, checkpoint)

# Start the Gradio server; share=False requests no public share link.
demo.launch(share=False)