File size: 5,912 Bytes
3ad8be1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82143cd
 
 
3ad8be1
 
b58cc31
3ad8be1
 
 
 
 
 
 
 
 
 
3713ecb
 
 
 
3ad8be1
3713ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import gradio as gr
import argparse
import os
from UltraFlow.models.sbap import *
from UltraFlow import commons
import warnings
warnings.filterwarnings("ignore")


# Directory that holds the model config file ('affinity_default.yaml') and the
# checkpoint file named below.
model_dir = './workdir/gradio/'
# Checkpoint file name; it is joined onto model_dir when the model is loaded.
checkpoint = 'checkpointbest_valid_1.ckp'

def get_config(model_dir):
    """Load the inference configuration stored in ``model_dir``.

    Reads ``affinity_default.yaml`` from ``model_dir`` through
    ``commons.get_config_easydict``, pins the device to CPU, and hands the
    config's ``seed`` to ``commons.set_seed``.

    Args:
        model_dir: Directory containing ``affinity_default.yaml``.

    Returns:
        The config object returned by ``commons.get_config_easydict`` with
        its ``device`` attribute set to ``'cpu'``.
    """
    yaml_path = os.path.join(model_dir, 'affinity_default.yaml')
    cfg = commons.get_config_easydict(yaml_path)

    # The demo always runs on CPU; no GPU selection is performed here.
    cfg.device = 'cpu'

    # Seed with the value stored in the config file.
    commons.set_seed(cfg.seed)
    return cfg

def load_graph_dim(lig_graph, prot_graph, model_config):
    """Work out the feature widths the model must be built with.

    Widths are taken from the second axis (``shape[1]``) of the graphs'
    ``'h'`` node features and ``'e'`` edge features.

    Args:
        lig_graph: Ligand graph exposing ``ndata['h']``, ``edata['e']`` and,
            when bond features are enabled, ``edata['bond_type']``.
        prot_graph: Protein graph exposing ``ndata['h']`` and ``edata['e']``.
        model_config: Config whose ``data.add_chemical_bond_feats`` and
            ``data.use_mean_node_features`` flags are consulted.

    Returns:
        Tuple ``(lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim,
        inter_edge_dim)``.
    """
    data_cfg = model_config.data

    def width(store, key):
        # Feature count = size of the second axis of the stored array.
        return store[key].shape[1]

    lig_node_dim = width(lig_graph.ndata, 'h')
    lig_edge_dim = width(lig_graph.edata, 'e')
    if data_cfg.add_chemical_bond_feats:
        # Bond-type features are appended to the ligand edge features.
        lig_edge_dim += width(lig_graph.edata, 'bond_type')

    pro_node_dim = width(prot_graph.ndata, 'h')
    pro_edge_dim = width(prot_graph.edata, 'e')

    # Mean-node features widen both node feature sets by 5 columns.
    # NOTE(review): presumably the width of 'mu_r_norm' -- confirm.
    extra_node_cols = 5 if data_cfg.use_mean_node_features else 0

    # The interaction-edge feature width is fixed at 15.
    return (
        lig_node_dim + extra_node_cols,
        lig_edge_dim,
        pro_node_dim + extra_node_cols,
        pro_edge_dim,
        15,
    )


def trans_device(data, device):
    """Move every non-list entry of ``data`` to ``device``.

    Args:
        data: Iterable of entries. Plain ``list`` entries are passed through
            unchanged; every other entry must provide ``.to(device)``.
        device: Target passed to each entry's ``.to`` method.

    Returns:
        A new list with the entries in their original order.
    """
    moved = []
    for entry in data:
        if isinstance(entry, list):
            # Lists are kept as-is (same object).
            moved.append(entry)
        else:
            moved.append(entry.to(device))
    return moved

def get_data(model_config, ligand_path, protein_path):
    """Build the single-sample model input for one ligand/protein file pair.

    Besides returning the sample, this function writes the feature
    dimensions derived from the graphs into ``model_config.model``
    (``lig_node_dim``, ``lig_edge_dim``, ``pro_node_dim``, ``pro_edge_dim``,
    ``inter_edge_dim``).

    Args:
        model_config: Config object; its ``data`` section supplies the graph
            construction settings and its ``device`` the target device.
        ligand_path: Path of the ligand structure file.
        protein_path: Path of the protein structure file.

    Returns:
        List ``[lig_graph, prot_graph, inter_graph, label, item, assay_des,
        IC50_f, K_f]`` after passing through ``trans_device``.
    """
    data_cfg = model_config.data

    # Parse both files into raw coordinates, features and residue metadata.
    (lig_coords, lig_features, lig_edges, lig_node_type,
     prot_coords, prot_features, prot_edges, prot_node_type,
     sec_features, alpha_c_coords, c_coords, n_coords,
     ca_res_number_valid, chain_index_valid) = commons.read_molecules_inference(
        ligand_path, protein_path, data_cfg.prot_graph_type, data_cfg.chaincut)

    # Ligand graph.
    lig_graph = commons.get_lig_graph_equibind(
        lig_coords, lig_features, lig_edges, lig_node_type,
        max_neighbors=data_cfg.lig_max_neighbors,
        cutoff=data_cfg.ligcut,
    )

    # Protein graph, annotated with residue numbers and chain indices.
    prot_graph = commons.get_prot_alpha_c_graph_equibind(
        prot_coords, prot_features, prot_node_type,
        sec_features, alpha_c_coords, c_coords, n_coords,
        max_neighbor=data_cfg.prot_max_neighbors,
        cutoff=data_cfg.protcut,
    )
    prot_graph.ndata['res_number'] = torch.tensor(ca_res_number_valid)
    prot_graph.chain_index = chain_index_valid

    # Ligand-protein interaction graph.
    inter_graph = commons.get_interact_graph_knn_v2(
        lig_coords, prot_coords,
        max_neighbor=data_cfg.inter_max_neighbors,
        min_neighbor=data_cfg.inter_min_neighbors,
        cutoff=data_cfg.intercut,
    )

    # Record the feature widths on the config. This must happen before the
    # concatenations below, because load_graph_dim adds the extra widths itself.
    model_cfg = model_config.model
    (model_cfg.lig_node_dim, model_cfg.lig_edge_dim,
     model_cfg.pro_node_dim, model_cfg.pro_edge_dim,
     model_cfg.inter_edge_dim) = load_graph_dim(lig_graph, prot_graph, model_config)

    if data_cfg.add_chemical_bond_feats:
        # Append bond-type features to the ligand edge features.
        lig_graph.edata['e'] = torch.cat(
            [lig_graph.edata['e'], lig_graph.edata['bond_type']], dim=-1)

    if data_cfg.use_mean_node_features:
        # Append 'mu_r_norm' to the node features of both graphs.
        lig_graph.ndata['h'] = torch.cat(
            [lig_graph.ndata['h'], lig_graph.ndata['mu_r_norm']], dim=-1)
        prot_graph.ndata['h'] = torch.cat(
            [prot_graph.ndata['h'], prot_graph.ndata['mu_r_norm']], dim=-1)

    # Remaining slots of the sample tuple are fixed filler values.
    # NOTE(review): presumably placeholders because no ground truth exists at
    # inference time -- confirm against the model's forward().
    sample = (
        lig_graph,
        prot_graph,
        inter_graph,
        torch.tensor(-100).unsqueeze(dim=-1),  # label
        [0],                                   # item
        torch.zeros(0).unsqueeze(dim=0),       # assay_des
        [True],                                # IC50_f
        [True],                                # K_f
    )
    return trans_device(sample, model_config.device)

def get_models(model_config, model_dir, checkpoint):
    """Instantiate the configured model and load its checkpoint weights.

    The model class is looked up by name in this module's globals: the name
    is ``model_config.model.model_type``, with ``'_MTL'`` appended when
    ``model_config.train.multi_task`` is truthy.

    Args:
        model_config: Config passed to the model constructor; also supplies
            ``device``.
        model_dir: Directory containing the checkpoint file.
        checkpoint: Checkpoint file name inside ``model_dir``.

    Returns:
        The model after ``load_state_dict`` and ``eval()``.
    """
    suffix = '_MTL' if model_config.train.multi_task else ''
    class_name = model_config.model.model_type + suffix
    model_cls = globals()[class_name]
    net = model_cls(model_config).to(model_config.device)

    checkpoint_path = os.path.join(model_dir, checkpoint)
    print("Load checkpoint from %s" % checkpoint_path)

    # The weights live under the "model" key of the saved state.
    saved = torch.load(checkpoint_path, map_location=model_config.device)
    net.load_state_dict(saved["model"])

    return net.eval()

def mbp_scoring(ligand_path, protein_path):
    """Predict the IC50 and K affinity values for one ligand/protein pair.

    Uses the module-level ``model_config`` and ``model`` created at start-up.

    Args:
        ligand_path: Path of the ligand structure file.
        protein_path: Path of the protein structure file.

    Returns:
        Tuple ``(IC50, K)`` of Python numbers taken from the model's two
        affinity outputs via ``.item()``.
    """
    data_example = get_data(model_config, ligand_path, protein_path)
    # Inference only: the outputs are immediately converted with .item(), so
    # there is no reason to record an autograd graph for the forward pass.
    with torch.no_grad():
        _, (affinity_pred_IC50, affinity_pred_K), _ = model(data_example, ASRP=False)
    return affinity_pred_IC50.item(), affinity_pred_K.item()


def test(ligand, protein):
    """Gradio click handler: score the uploaded ligand/protein pair.

    Args:
        ligand: Uploaded ligand file as delivered by ``gr.File`` -- either an
            object exposing the file path as ``.name`` or a plain path string.
        protein: Uploaded protein file, same form as ``ligand``.

    Returns:
        Tuple of two strings for the IC50 and K text boxes: the predictions
        formatted to two decimals, or the same message twice when an upload
        is missing or scoring fails.
    """
    if ligand is None or protein is None:
        message = 'Please upload both a ligand file and a protein file.'
        return message, message

    try:
        # Resolve the on-disk paths once. os.path.basename() needs a path,
        # not the upload object itself (passing the object raises TypeError).
        ligand_path = getattr(ligand, 'name', ligand)
        protein_path = getattr(protein, 'name', protein)

        IC50, K = mbp_scoring(ligand_path, protein_path)
        print(f'ligand file name: {os.path.basename(ligand_path)},'
              f' protein file name: {os.path.basename(protein_path)},'
              f' IC50: {IC50}, K: {K}')
        return '{:.2f}'.format(IC50), '{:.2f}'.format(K)
    except Exception as e:
        # Top-level UI boundary: keep the app alive, log the failure on the
        # server, and show the error text in both output boxes.
        print(f'MBP scoring failed: {e!r}')
        return str(e), str(e)

# Gradio UI definition: two file uploads as inputs, two text boxes as outputs.
with gr.Blocks() as demo:
    # Input components: the ligand file and the protein file.
    ligand = gr.File(label="ligand")
    protein = gr.File(label="Protein, please use .pdb files")
    # Output components: one text box per predicted value.
    IC50 = gr.Textbox(label="Predicted IC50 Value")
    K = gr.Textbox(label="Predicted K Value")

    # Clicking "Submit" calls test(ligand, protein) and writes its two return
    # values into the IC50 and K text boxes; the handler is registered under
    # the API name "MBP_Scoring".
    submit_btn = gr.Button("Submit")
    submit_btn.click(fn=test, inputs=[ligand, protein], outputs=[IC50, K], api_name="MBP_Scoring")

# Start-up: create the module-level objects that mbp_scoring() reads
# (model_config and model).
model_config = get_config(model_dir)
# Run the bundled 1a0q example through get_data() before building the model:
# get_data() writes the feature dimensions into model_config.model.
# NOTE(review): presumably the model constructor needs those dimensions --
# confirm before reordering or removing this call.
data_example = get_data(model_config, './workdir/gradio/1a0q_ligand.sdf', './workdir/gradio/1a0q_protein.pdb')
model = get_models(model_config, model_dir, checkpoint)

# Start the Gradio server for the `demo` UI.
demo.launch()