Spaces:
Running
Running
File size: 5,912 Bytes
3ad8be1 82143cd 3ad8be1 b58cc31 3ad8be1 3713ecb 3ad8be1 3713ecb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import argparse
import os
import traceback
import warnings

import gradio as gr
import torch

from UltraFlow.models.sbap import *
from UltraFlow import commons
# Silence all Python warnings for the demo process.
warnings.filterwarnings(action="ignore")

# Directory that holds the config YAML, the checkpoint, and the bundled
# example structures used at startup.
model_dir = './workdir/gradio/'
# Checkpoint file name, resolved relative to ``model_dir``.
checkpoint = 'checkpointbest_valid_1.ckp'
def get_config(model_dir, device='cpu'):
    """Load the inference config from ``<model_dir>/affinity_default.yaml``.

    Args:
        model_dir: directory containing ``affinity_default.yaml``.
        device: device string stored on ``config.device``. Defaults to
            ``'cpu'`` (the previously hard-coded value); pass e.g. ``'cuda'``
            to run elsewhere.

    Returns:
        The loaded config object with ``device`` set. As a side effect,
        ``commons.set_seed(config.seed)`` is called to set the random seed.
    """
    config = commons.get_config_easydict(os.path.join(model_dir, 'affinity_default.yaml'))
    config.device = device
    commons.set_seed(config.seed)
    return config
def load_graph_dim(lig_graph, prot_graph, model_config):
    """Compute the feature widths the model has to be built with.

    Node widths come from ``ndata['h'].shape[1]`` and edge widths from
    ``edata['e'].shape[1]`` of the ligand and protein graphs.

    - If ``model_config.data.add_chemical_bond_feats`` is set, the width of
      the ligand ``edata['bond_type']`` is added to the ligand edge width.
    - If ``model_config.data.use_mean_node_features`` is set, 5 is added to
      both node widths. NOTE(review): presumably the width of the
      ``mu_r_norm`` feature appended later; confirm.
    - The interaction-edge width is the fixed constant 15.

    Returns:
        tuple ``(lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim,
        inter_edge_dim)``.
    """
    lig_node = lig_graph.ndata['h'].shape[1]
    lig_edge = lig_graph.edata['e'].shape[1]
    if model_config.data.add_chemical_bond_feats:
        lig_edge = lig_edge + lig_graph.edata['bond_type'].shape[1]

    pro_node = prot_graph.ndata['h'].shape[1]
    pro_edge = prot_graph.edata['e'].shape[1]

    # Extra node columns accounted for when mean node features are enabled.
    mean_extra = 5 if model_config.data.use_mean_node_features else 0
    inter_edge = 15
    return lig_node + mean_extra, lig_edge, pro_node + mean_extra, pro_edge, inter_edge
def trans_device(data, device):
    """Move every non-list item of ``data`` to ``device``.

    Python lists are passed through as the same object. Every other item must
    provide a ``.to(device)`` method.

    Returns:
        A new list holding the items in their original order.
    """
    moved = []
    for entry in data:
        if not isinstance(entry, list):
            entry = entry.to(device)
        moved.append(entry)
    return moved
def get_data(model_config, ligand_path, protein_path):
    """Featurize one ligand/protein pair into a model input.

    Builds the ligand graph, the protein graph and the ligand-protein
    interaction graph via ``commons`` helpers.

    Side effect: writes the graphs' feature widths into ``model_config.model``
    (``lig_node_dim``, ``lig_edge_dim``, ``pro_node_dim``, ``pro_edge_dim``,
    ``inter_edge_dim``).

    Args:
        model_config: config object as returned by ``get_config``.
        ligand_path: path to the ligand structure file.
        protein_path: path to the protein structure file.

    Returns:
        list ``[lig_graph, prot_graph, inter_graph, label, item, assay_des,
        IC50_f, K_f]`` with every non-list entry moved to
        ``model_config.device``.
    """
    data_cfg = model_config.data
    molecular_representation = commons.read_molecules_inference(
        ligand_path, protein_path, data_cfg.prot_graph_type, data_cfg.chaincut)
    (lig_coords, lig_features, lig_edges, lig_node_type,
     prot_coords, prot_features, prot_edges, prot_node_type,
     sec_features, alpha_c_coords, c_coords, n_coords,
     ca_res_number_valid, chain_index_valid) = molecular_representation

    lig_graph = commons.get_lig_graph_equibind(
        lig_coords, lig_features, lig_edges, lig_node_type,
        max_neighbors=data_cfg.lig_max_neighbors,
        cutoff=data_cfg.ligcut)
    prot_graph = commons.get_prot_alpha_c_graph_equibind(
        prot_coords, prot_features, prot_node_type,
        sec_features, alpha_c_coords, c_coords, n_coords,
        max_neighbor=data_cfg.prot_max_neighbors,
        cutoff=data_cfg.protcut)
    prot_graph.ndata['res_number'] = torch.tensor(ca_res_number_valid)
    prot_graph.chain_index = chain_index_valid
    inter_graph = commons.get_interact_graph_knn_v2(
        lig_coords, prot_coords,
        max_neighbor=data_cfg.inter_max_neighbors,
        min_neighbor=data_cfg.inter_min_neighbors,
        cutoff=data_cfg.intercut)

    # Record feature widths BEFORE the concatenations below: load_graph_dim
    # already adds the bond-type / mean-feature widths on top of the raw
    # 'e' / 'h' widths, so measuring after the concat would count them twice.
    model_cfg = model_config.model
    (model_cfg.lig_node_dim, model_cfg.lig_edge_dim,
     model_cfg.pro_node_dim, model_cfg.pro_edge_dim,
     model_cfg.inter_edge_dim) = load_graph_dim(lig_graph, prot_graph, model_config)

    if data_cfg.add_chemical_bond_feats:
        lig_graph.edata['e'] = torch.cat(
            [lig_graph.edata['e'], lig_graph.edata['bond_type']], dim=-1)
    if data_cfg.use_mean_node_features:
        for graph in (lig_graph, prot_graph):
            graph.ndata['h'] = torch.cat(
                [graph.ndata['h'], graph.ndata['mu_r_norm']], dim=-1)

    # Placeholder label (-100) since no ground truth exists at inference.
    # NOTE(review): presumably ignored by the model in this mode; confirm.
    label = torch.tensor(-100).unsqueeze(dim=-1)
    # Empty assay descriptor with a leading batch axis: shape (1, 0).
    assay_des = torch.zeros(0).unsqueeze(dim=0)
    # [True] flags presumably enable both the IC50 and the K heads; confirm.
    sample = (lig_graph, prot_graph, inter_graph, label, [0], assay_des, [True], [True])
    return trans_device(sample, model_config.device)
def get_models(model_config, model_dir, checkpoint):
    """Instantiate the configured model and load its checkpoint weights.

    The model class is looked up by name in this module's globals
    (presumably supplied by the ``UltraFlow.models.sbap`` star import). The
    name is ``model_config.model.model_type``, suffixed with ``'_MTL'`` when
    ``model_config.train.multi_task`` is true.

    Args:
        model_config: config object; ``device`` selects where the model and
            the loaded tensors are placed.
        model_dir: directory containing the checkpoint.
        checkpoint: checkpoint file name inside ``model_dir``.

    Returns:
        The model in eval mode, with ``state["model"]`` loaded as its
        state dict.
    """
    class_name = model_config.model.model_type
    if model_config.train.multi_task:
        class_name = class_name + '_MTL'
    model_cls = globals()[class_name]
    model = model_cls(model_config).to(model_config.device)

    ckpt_file = os.path.join(model_dir, checkpoint)
    print("Load checkpoint from %s" % ckpt_file)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    weights = torch.load(ckpt_file, map_location=model_config.device)["model"]
    model.load_state_dict(weights)
    return model.eval()
def mbp_scoring(ligand_path, protein_path):
    """Predict IC50 and K affinity values for one ligand/protein pair.

    Reads the module-level ``model_config`` and ``model`` created at startup.

    Args:
        ligand_path: path to the ligand structure file.
        protein_path: path to the protein structure file.

    Returns:
        tuple ``(IC50, K)`` of Python scalars (via ``Tensor.item()``).
    """
    data_example = get_data(model_config, ligand_path, protein_path)
    # Pure inference: disable autograd so no backward graph is recorded,
    # which saves memory and time on every request.
    with torch.no_grad():
        _, (affinity_pred_IC50, affinity_pred_K), _ = model(data_example, ASRP=False)
    return affinity_pred_IC50.item(), affinity_pred_K.item()
def _upload_path(value):
    """Return the filesystem path carried by a Gradio ``File`` value.

    Depending on the Gradio version, uploads arrive either as a path string
    (possibly a str subclass) or as a file-like object whose ``.name`` is
    the path. Both forms are accepted.
    """
    if isinstance(value, (str, os.PathLike)):
        return value
    return value.name


def test(ligand, protein):
    """Gradio callback: score an uploaded ligand/protein pair.

    Args:
        ligand: value of the ligand ``gr.File`` component.
        protein: value of the protein ``gr.File`` component.

    Returns:
        Two strings for the IC50 and K textboxes. On success these are the
        predictions formatted to two decimals. On failure, both contain the
        same error message.
    """
    if ligand is None or protein is None:
        msg = 'Please upload both a ligand file and a protein file.'
        return msg, msg
    # Resolve the paths once so scoring and logging use the same value.
    # Previously os.path.basename() was applied to the raw upload object,
    # which fails for file-object uploads and masked successful predictions.
    ligand_path = _upload_path(ligand)
    protein_path = _upload_path(protein)
    try:
        IC50, K = mbp_scoring(ligand_path, protein_path)
    except Exception as e:
        # UI boundary: keep the server traceback for debugging and show the
        # message to the user instead of crashing the callback.
        traceback.print_exc()
        return str(e), str(e)
    print(f'ligand file name: {os.path.basename(ligand_path)},'
          f' protein file name: {os.path.basename(protein_path)},'
          f' IC50: {IC50}, K: {K}')
    return '{:.2f}'.format(IC50), '{:.2f}'.format(K)
# Gradio UI: ligand and protein file uploads, two output textboxes and a
# Submit button wired to `test`. The callback is also exposed under the API
# name "MBP_Scoring".
with gr.Blocks() as demo:
    ligand = gr.File(label="ligand")
    protein = gr.File(label="Protein, please use .pdb files")
    IC50 = gr.Textbox(label="Predicted IC50 Value")
    K = gr.Textbox(label="Predicted K Value")
    submit_btn = gr.Button("Submit")
    submit_btn.click(fn=test, inputs=[ligand, protein], outputs=[IC50, K], api_name="MBP_Scoring")

# Module-level globals read by `mbp_scoring` (`model_config` and `model`).
model_config = get_config(model_dir)
# Warm-up featurization of a bundled example. Its result is not used here.
# The call matters for its side effect: get_data writes the graph feature
# dimensions (lig_node_dim ... inter_edge_dim) into model_config.model before
# get_models builds the model from that config (presumably the model
# constructor reads them; confirm). Do not remove or reorder.
data_example = get_data(model_config, './workdir/gradio/1a0q_ligand.sdf', './workdir/gradio/1a0q_protein.pdb')
model = get_models(model_config, model_dir, checkpoint)
demo.launch()
|