Spaces:
Running
Running
File size: 6,975 Bytes
3ad8be1 b412584 3ad8be1 b412584 ca33195 b412584 ca33195 b412584 52e74e1 3ad8be1 aad7754 72b4c06 aad7754 865e35e aad7754 18379ef aad7754 3ad8be1 aad7754 b412584 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import gradio as gr
import argparse
import os
from UltraFlow.models.sbap import *
from UltraFlow import commons
import warnings
# Silence every Python warning for the rest of the process so the Space logs stay readable.
warnings.filterwarnings("ignore")

# Directory holding affinity_default.yaml, the model checkpoint and the bundled 1a0q example.
model_dir: str = './workdir/gradio/'
# Checkpoint file name, joined onto model_dir when the model is loaded.
checkpoint: str = 'checkpointbest_valid_1.ckp'
# Running count of scoring requests; incremented by the `test` callback on every call.
total_num: int = 0
def get_config(model_dir):
    """Load the affinity config from *model_dir* and prepare it for CPU inference.

    Reads ``affinity_default.yaml`` inside ``model_dir`` through
    ``commons.get_config_easydict``, forces ``device`` to ``'cpu'`` and seeds
    the random generators with ``config.seed`` via ``commons.set_seed``.

    Args:
        model_dir: directory containing ``affinity_default.yaml``.

    Returns:
        The loaded config object with ``device`` set to ``'cpu'``.
    """
    yaml_path = os.path.join(model_dir, 'affinity_default.yaml')
    cfg = commons.get_config_easydict(yaml_path)
    # The demo always runs on CPU; automatic GPU selection through
    # commons.get_device(cfg.train.gpus, cfg.train.gpu_memory_need) is disabled.
    cfg.device = 'cpu'
    commons.set_seed(cfg.seed)
    return cfg
def load_graph_dim(lig_graph, prot_graph, model_config):
    """Compute the feature widths the model has to be built with.

    Base widths come from the second axis of the node (``ndata['h']``) and
    edge (``edata['e']``) feature tensors of each graph. They are then widened
    for the optional features that ``get_data`` concatenates afterwards:

    * ``data.add_chemical_bond_feats``: ligand edge width grows by the width of
      ``edata['bond_type']``.
    * ``data.use_mean_node_features``: ligand and protein node widths each grow
      by 5 (presumably the width of ``ndata['mu_r_norm']`` -- TODO confirm).

    Returns:
        ``(lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim, inter_edge_dim)``
        where ``inter_edge_dim`` is the hard-coded value 15.
    """
    data_cfg = model_config.data

    lig_node_dim = lig_graph.ndata['h'].shape[1]
    lig_edge_dim = lig_graph.edata['e'].shape[1] + (
        lig_graph.edata['bond_type'].shape[1] if data_cfg.add_chemical_bond_feats else 0
    )
    pro_node_dim = prot_graph.ndata['h'].shape[1]
    pro_edge_dim = prot_graph.edata['e'].shape[1]
    # Hard-coded; presumably must match the edge features produced by
    # commons.get_interact_graph_knn_v2 -- verify if that helper changes.
    inter_edge_dim = 15

    if data_cfg.use_mean_node_features:
        lig_node_dim, pro_node_dim = lig_node_dim + 5, pro_node_dim + 5

    return lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim, inter_edge_dim
def trans_device(data, device):
    """Return a new list with each element of *data* moved to *device*.

    Elements that are plain ``list`` objects are passed through unchanged
    (same object). Every other element must expose ``.to(device)``, and its
    return value is used.
    """
    moved = []
    for element in data:
        if not isinstance(element, list):
            element = element.to(device)
        moved.append(element)
    return moved
def get_data(model_config, ligand_path, protein_path):
    """Featurize one ligand/protein pair into the batch the model consumes.

    Builds the ligand graph, the protein alpha-carbon graph and the
    ligand-protein interaction graph with the ``commons`` helpers, using the
    settings in ``model_config.data``.

    Side effect: writes the widths returned by ``load_graph_dim`` into
    ``model_config.model`` (``lig_node_dim``, ``lig_edge_dim``,
    ``pro_node_dim``, ``pro_edge_dim``, ``inter_edge_dim``).

    Args:
        model_config: config object; ``.data`` holds graph-construction
            settings and ``.device`` the target device.
        ligand_path: path to the ligand 3D file.
        protein_path: path to the protein 3D file.

    Returns:
        A list ``[lig_graph, prot_graph, inter_graph, label, item, assay_des,
        IC50_f, K_f]`` whose non-list entries have been moved to
        ``model_config.device``. ``label`` is a placeholder tensor holding -100,
        ``item`` is ``[0]``, ``assay_des`` is an empty tensor of shape (1, 0),
        and both flags are ``[True]``.
    """
    # NOTE(review): `torch` is not imported explicitly in this file; it
    # presumably arrives via `from UltraFlow.models.sbap import *`.
    data_cfg = model_config.data

    (lig_coords, lig_features, lig_edges, lig_node_type,
     prot_coords, prot_features, prot_edges, prot_node_type,
     sec_features, alpha_c_coords, c_coords, n_coords,
     ca_res_number_valid, chain_index_valid) = commons.read_molecules_inference(
        ligand_path, protein_path, data_cfg.prot_graph_type, data_cfg.chaincut)

    lig_graph = commons.get_lig_graph_equibind(
        lig_coords, lig_features, lig_edges, lig_node_type,
        max_neighbors=data_cfg.lig_max_neighbors,
        cutoff=data_cfg.ligcut,
    )
    prot_graph = commons.get_prot_alpha_c_graph_equibind(
        prot_coords, prot_features, prot_node_type,
        sec_features, alpha_c_coords, c_coords, n_coords,
        max_neighbor=data_cfg.prot_max_neighbors,
        cutoff=data_cfg.protcut,
    )
    prot_graph.ndata['res_number'] = torch.tensor(ca_res_number_valid)
    prot_graph.chain_index = chain_index_valid
    inter_graph = commons.get_interact_graph_knn_v2(
        lig_coords, prot_coords,
        max_neighbor=data_cfg.inter_max_neighbors,
        min_neighbor=data_cfg.inter_min_neighbors,
        cutoff=data_cfg.intercut,
    )

    # Widths must be read BEFORE the optional features are concatenated
    # below: load_graph_dim already adds their widths on its own.
    net_cfg = model_config.model
    (net_cfg.lig_node_dim, net_cfg.lig_edge_dim,
     net_cfg.pro_node_dim, net_cfg.pro_edge_dim,
     net_cfg.inter_edge_dim) = load_graph_dim(lig_graph, prot_graph, model_config)

    if data_cfg.add_chemical_bond_feats:
        lig_graph.edata['e'] = torch.cat(
            [lig_graph.edata['e'], lig_graph.edata['bond_type']], dim=-1)
    if data_cfg.use_mean_node_features:
        for graph in (lig_graph, prot_graph):
            graph.ndata['h'] = torch.cat(
                [graph.ndata['h'], graph.ndata['mu_r_norm']], dim=-1)

    # Placeholder values for the label / metadata slots of a training batch.
    label = torch.tensor(-100).unsqueeze(dim=-1)
    batch = (
        lig_graph, prot_graph, inter_graph, label,
        [0], torch.zeros(0).unsqueeze(dim=0), [True], [True],
    )
    return trans_device(batch, model_config.device)
def get_models(model_config, model_dir, checkpoint):
    """Instantiate the affinity model and restore its weights from a checkpoint.

    The model class is looked up by name in this module's globals: the name is
    ``model_config.model.model_type``, with ``'_MTL'`` appended when
    ``model_config.train.multi_task`` is truthy. The class is constructed
    with ``model_config`` and moved to ``model_config.device``. Its weights
    are then loaded from ``state["model"]`` of the checkpoint at
    ``os.path.join(model_dir, checkpoint)``.

    Returns:
        The model switched to eval mode.
    """
    suffix = '_MTL' if model_config.train.multi_task else ''
    model_cls = globals()[model_config.model.model_type + suffix]
    net = model_cls(model_config).to(model_config.device)

    ckpt_path = os.path.join(model_dir, checkpoint)
    print(f"Load checkpoint from {ckpt_path}")
    # torch.load unpickles the file -- only point this at trusted checkpoints.
    saved = torch.load(ckpt_path, map_location=model_config.device)
    net.load_state_dict(saved["model"])
    return net.eval()
def mbp_scoring(ligand_path, protein_path):
    """Predict IC50 and K affinities for one ligand/protein file pair.

    Uses the module-level ``model_config`` and ``model`` created at start-up.
    The pair is featurized with ``get_data`` and the model is called with
    ``ASRP=False``. Its second output is unpacked as the ``(IC50, K)``
    prediction pair.

    Args:
        ligand_path: path to the ligand 3D file.
        protein_path: path to the protein 3D file.

    Returns:
        ``(ic50, k)`` as Python scalars (via ``Tensor.item()``).
    """
    data_example = get_data(model_config, ligand_path, protein_path)
    # Pure inference: disable autograd so no computation graph (and none of
    # its saved activations) is built for a backward pass that never happens.
    with torch.no_grad():
        _, (affinity_pred_IC50, affinity_pred_K), _ = model(data_example, ASRP=False)
    return affinity_pred_IC50.item(), affinity_pred_K.item()
def test(ligand, protein):
    """Gradio callback: score an uploaded ligand/protein pair.

    Increments the module-level ``total_num`` request counter and runs
    ``mbp_scoring``. On success, returns the IC50 and K predictions formatted
    with two decimals for the two output textboxes.

    Args:
        ligand: value of the ligand ``gr.File`` component. Depending on the
            Gradio version this is a file-like object exposing ``.name`` or
            a plain path string; ``None`` when nothing was uploaded.
        protein: value of the protein ``gr.File`` component (same forms).

    Returns:
        ``(ic50_text, k_text)``. If an input is missing, both entries are a
        prompt asking for both files. If scoring fails, both entries are
        ``"<ExceptionType>: <message>"`` and the traceback goes to the
        server log.
    """
    import traceback

    global total_num
    total_num = total_num + 1
    print(f'total num: {total_num}')

    if ligand is None or protein is None:
        msg = 'Please upload both a ligand file and a protein file.'
        return msg, msg

    # Older Gradio passes a tempfile wrapper (path in `.name`); newer versions
    # pass the path itself as a string.
    ligand_path = getattr(ligand, 'name', ligand)
    protein_path = getattr(protein, 'name', protein)

    try:
        IC50, K = mbp_scoring(ligand_path, protein_path)
    except Exception as e:
        # UI boundary: report the failure to the user instead of failing the
        # request, but keep the full traceback in the server log.
        traceback.print_exc()
        msg = f'{type(e).__name__}: {e}'
        return msg, msg

    print(f'ligand file name: {os.path.basename(ligand_path)},'
          f' protein file name: {os.path.basename(protein_path)},'
          f' IC50: {IC50}, K: {K}')
    return '{:.2f}'.format(IC50), '{:.2f}'.format(K)
# Gradio UI. Inside Blocks/Row contexts, components are laid out in the order
# they are created, so the statement order below defines the page layout.
# NOTE(review): indentation was lost in this copy of the file; the Row is
# assumed to contain only the two file inputs -- confirm against the original.
with gr.Blocks() as demo:
    # Page header and usage notes, rendered as Markdown.
    # NOTE(review): the contact line reads "[email protected]", which looks
    # like an e-mail-obfuscation placeholder left by a web scrape rather than a
    # real address -- restore the intended contact e-mail.
    gr.Markdown(
        """
# Multi-task Bioassay Pre-training for Protein-Ligand Binding Affinity Prediction
## Welcome to the MBP demo !
- Feel free to upload your own examples. Please upload an individual ligand 3D file and an individual protein 3D file each time.
- If you encounter any issues, please reach out to [email protected].
- All codes and data are available on the online platform https://github.com/jiaxianyan/MBP.
""")
    # Input row: ligand file and protein file side by side.
    with gr.Row():
        ligand = gr.File(label="Ligand 3D file. MBP utilizes openbabel to process ligand files and supports all file types that openbabel can read.")
        protein = gr.File(label="Protein 3D file. Currently, MBP only supports the pdb file type for protein files.")
    # Output textboxes, filled in this order by the `test` callback's (IC50, K) return.
    IC50 = gr.Textbox(label="Predicted IC50 Value")
    K = gr.Textbox(label="Predicted K Value")
    submit_btn = gr.Button("Submit")
    # Run `test` on click; api_name also exposes it as the "MBP_Scoring" API endpoint.
    submit_btn.click(fn=test, inputs=[ligand, protein], outputs=[IC50, K], api_name="MBP_Scoring")
    gr.Markdown("## Input Examples")
    # Bundled 1a0q example pair. No outputs are wired and cache_examples=False,
    # so clicking it presumably only fills the two file inputs; the user then
    # presses Submit.
    gr.Examples(
        examples=[['./workdir/gradio/1a0q_ligand.sdf','./workdir/gradio/1a0q_protein.pdb']],
        inputs=[ligand, protein],
        # outputs=[IC50, K],
        fn=test,
        cache_examples=False,
    )
# Start-up sequence. `model_config` and `model` become module globals that
# mbp_scoring reads on every request.
model_config = get_config(model_dir)
# Not just a warm-up: get_data writes the feature widths (lig_node_dim,
# lig_edge_dim, pro_node_dim, pro_edge_dim, inter_edge_dim) into
# model_config.model. The model class built in get_models is constructed
# from model_config and presumably reads them, so this call must come first.
data_example = get_data(model_config, './workdir/gradio/1a0q_ligand.sdf', './workdir/gradio/1a0q_protein.pdb')
model = get_models(model_config, model_dir, checkpoint)
# Serve the UI locally only (no public share link).
demo.launch(share=False)
|