Spaces:
Running
Running
File size: 5,692 Bytes
3ad8be1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import gradio as gr
import argparse
import os
from UltraFlow.models.sbap import *
from UltraFlow import commons
import warnings
# Suppress every Python warning process-wide. This keeps the demo log quiet,
# but it also hides deprecation warnings coming from dependencies.
warnings.filterwarnings("ignore")

# Location of the trained model. `model_dir` must hold `affinity_default.yaml`
# (read by get_config) and the checkpoint file named by `checkpoint`
# (read by get_models).
model_dir, checkpoint = './workdir/gradio/', 'checkpointbest_valid_1.ckp'
def get_config(model_dir, *, device='cpu'):
    """Load the affinity model configuration for inference.

    Reads ``affinity_default.yaml`` from *model_dir*, stores the compute
    device on the config and calls ``commons.set_seed`` with ``config.seed``
    (presumably seeding the RNGs so repeated predictions match -- confirm).

    Args:
        model_dir: Directory containing ``affinity_default.yaml``.
        device: Device string stored on ``config.device`` and later used for
            ``.to(...)``/``map_location``. Defaults to ``'cpu'``, which is what
            this demo has always used; pass e.g. ``'cuda'`` to override.

    Returns:
        The config object produced by ``commons.get_config_easydict`` with
        ``device`` set.
    """
    config = commons.get_config_easydict(os.path.join(model_dir, 'affinity_default.yaml'))
    # GPU auto-selection (commons.get_device with config.train.gpus) is
    # deliberately not used here; the caller chooses the device explicitly.
    config.device = device
    commons.set_seed(config.seed)
    return config
def load_graph_dim(lig_graph, prot_graph, model_config):
    """Return the feature sizes the affinity model must be constructed with.

    Widths are read from the graphs' base features (``ndata['h']`` and
    ``edata['e']``). When the optional switches in ``model_config.data`` are
    on, the width of the features that get concatenated later is added:

    * ``add_chemical_bond_feats``: ligand edges grow by the width of
      ``edata['bond_type']``;
    * ``use_mean_node_features``: ligand and protein nodes each grow by 5
      (assumed to be the width of ``ndata['mu_r_norm']`` -- TODO confirm).

    Returns:
        ``(lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim,
        inter_edge_dim)``; ``inter_edge_dim`` is fixed at 15.
    """
    flags = model_config.data
    lig_node_dim = lig_graph.ndata['h'].shape[1]
    lig_edge_dim = lig_graph.edata['e'].shape[1] + (
        lig_graph.edata['bond_type'].shape[1] if flags.add_chemical_bond_feats else 0
    )
    pro_node_dim, pro_edge_dim = prot_graph.ndata['h'].shape[1], prot_graph.edata['e'].shape[1]

    mean_feat_width = 5 if flags.use_mean_node_features else 0
    # Hard-coded width of the ligand-protein interaction edge features
    # (presumably matches commons.get_interact_graph_knn_v2 -- confirm).
    return (lig_node_dim + mean_feat_width, lig_edge_dim,
            pro_node_dim + mean_feat_width, pro_edge_dim, 15)
def trans_device(data, device):
    """Move every non-list item of *data* to *device*.

    Plain Python lists are passed through unchanged (same object). Every
    other item must provide a ``.to(device)`` method.

    Returns:
        A new list with the items in their original order.
    """
    moved = []
    for entry in data:
        if not isinstance(entry, list):
            entry = entry.to(device)
        moved.append(entry)
    return moved
def get_data(model_config, ligand_path, protein_path):
    """Featurize one ligand/protein pair into a single-sample model input.

    Side effect: writes the feature dimensions returned by ``load_graph_dim``
    into ``model_config.model`` (``lig_node_dim``, ``lig_edge_dim``,
    ``pro_node_dim``, ``pro_edge_dim``, ``inter_edge_dim``). These are
    presumably read when the model is built from ``model_config``, so call
    this before ``get_models``.

    Args:
        model_config: Config from ``get_config``. Its ``data`` section drives
            featurization and ``device`` is where the result is moved.
        ligand_path: Path to the ligand file (accepted formats are whatever
            ``commons.read_molecules_inference`` supports -- TODO confirm).
        protein_path: Path to the protein structure file (the UI asks for .pdb).

    Returns:
        List ``[lig_graph, prot_graph, inter_graph, label, item, assay_des,
        IC50_f, K_f]`` with non-list items moved to ``model_config.device``.
    """
    # NOTE(review): `torch` is not imported explicitly in this file; it
    # presumably arrives via `from UltraFlow.models.sbap import *` -- confirm.
    # Parse both input files; the call returns a flat 14-element tuple.
    molecular_representation = commons.read_molecules_inference(ligand_path, protein_path,
                                                                model_config.data.prot_graph_type,
                                                                model_config.data.chaincut)
    # NOTE: prot_edges is unpacked but not used below.
    lig_coords, lig_features, lig_edges, lig_node_type, \
    prot_coords, prot_features, prot_edges, prot_node_type, \
    sec_features, alpha_c_coords, c_coords, n_coords, ca_res_number_valid, chain_index_valid = molecular_representation
    # Ligand graph, with neighbours limited by count and distance cutoff.
    lig_graph = commons.get_lig_graph_equibind(lig_coords, lig_features, lig_edges, lig_node_type,
                                               max_neighbors=model_config.data.lig_max_neighbors,
                                               cutoff=model_config.data.ligcut)
    # Protein graph built from the alpha/C/N backbone coordinates.
    prot_graph = commons.get_prot_alpha_c_graph_equibind(prot_coords, prot_features, prot_node_type,
                                                         sec_features, alpha_c_coords, c_coords, n_coords,
                                                         max_neighbor=model_config.data.prot_max_neighbors,
                                                         cutoff=model_config.data.protcut)
    # Attach residue numbers (as node data) and chain indices (as a plain attribute).
    prot_graph.ndata['res_number'] = torch.tensor(ca_res_number_valid)
    prot_graph.chain_index = chain_index_valid
    # Ligand-protein interaction graph from the two coordinate sets.
    inter_graph = commons.get_interact_graph_knn_v2(lig_coords, prot_coords,
                                                    max_neighbor=model_config.data.inter_max_neighbors,
                                                    min_neighbor=model_config.data.inter_min_neighbors,
                                                    cutoff=model_config.data.intercut)
    # set feats dim
    # Order matters: dims are measured on the base features BEFORE the optional
    # concatenations below. load_graph_dim adds the extra widths itself, so
    # moving this after the torch.cat calls would double-count them.
    lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim, inter_edge_dim = load_graph_dim(lig_graph, prot_graph, model_config)
    model_config.model.lig_node_dim, model_config.model.lig_edge_dim = lig_node_dim, lig_edge_dim
    model_config.model.pro_node_dim, model_config.model.pro_edge_dim = pro_node_dim, pro_edge_dim
    model_config.model.inter_edge_dim = inter_edge_dim
    # Optionally append bond-type features to the ligand edge features.
    if model_config.data.add_chemical_bond_feats:
        lig_graph.edata['e'] = torch.cat([lig_graph.edata['e'], lig_graph.edata['bond_type']], dim=-1)
    # Optionally append the mu_r_norm node features to both graphs.
    if model_config.data.use_mean_node_features:
        lig_graph.ndata['h'] = torch.cat([lig_graph.ndata['h'], lig_graph.ndata['mu_r_norm']], dim=-1)
        prot_graph.ndata['h'] = torch.cat([prot_graph.ndata['h'], prot_graph.ndata['mu_r_norm']], dim=-1)
    # Batch fields that only matter during training:
    # - label: shape (1,) tensor holding -100, presumably a "no label" sentinel -- confirm.
    # - item: sample index list.
    # - assay_des: empty assay descriptor, shape (1, 0) after unsqueeze.
    # - IC50_f / K_f: per-sample flags, presumably selecting which affinity
    #   heads apply -- confirm against the model.
    label = torch.tensor(-100).unsqueeze(dim=-1)
    item = [0]
    assay_des = torch.zeros(0)
    IC50_f, K_f = [True], [True]
    data = (lig_graph, prot_graph, inter_graph, label, item, assay_des.unsqueeze(dim=0), IC50_f, K_f)
    return trans_device(data, model_config.device)
def get_models(model_config, model_dir, checkpoint):
    """Build the affinity model and load trained weights for inference.

    The model class is looked up by name in this module's globals, using
    ``model_config.model.model_type`` (plus ``'_MTL'`` when
    ``model_config.train.multi_task`` is set). The class presumably comes in
    through the ``UltraFlow.models.sbap`` star import -- confirm.

    Args:
        model_config: Config whose dims were already filled in by ``get_data``.
        model_dir: Directory holding the checkpoint.
        checkpoint: Checkpoint file name inside *model_dir*.

    Returns:
        The model on ``model_config.device``, switched to eval mode.

    Note:
        ``torch.load`` unpickles the checkpoint, so only load trusted files.
    """
    suffix = '_MTL' if model_config.train.multi_task else ''
    model_cls = globals()[model_config.model.model_type + suffix]
    net = model_cls(model_config).to(model_config.device)

    ckpt_file = os.path.join(model_dir, checkpoint)
    print("Load checkpoint from %s" % ckpt_file)
    saved = torch.load(ckpt_file, map_location=model_config.device)
    net.load_state_dict(saved["model"])
    return net.eval()
def mbp_scoring(ligand_path, protein_path):
    """Predict IC50 and K affinities for one ligand/protein file pair.

    Every call builds a fresh config from the module-level ``model_dir``,
    featurizes the inputs and reloads ``checkpoint``. ``get_data`` runs before
    ``get_models`` because it writes the feature dims into the config that
    the model is built from.

    Returns:
        ``(ic50, k)`` as Python scalars (via ``.item()``).
    """
    cfg = get_config(model_dir)
    batch = get_data(cfg, ligand_path, protein_path)
    net = get_models(cfg, model_dir, checkpoint)
    # The model returns a 3-tuple; only the middle (IC50, K) pair is used here.
    # ASRP=False is passed as in the original; its meaning is model-defined.
    _, predictions, _ = net(batch, ASRP=False)
    ic50_pred, k_pred = predictions
    return ic50_pred.item(), k_pred.item()
def _uploaded_path(upload):
    """Return a filesystem path for a Gradio upload (str, PathLike, or file wrapper)."""
    # Gradio 4's gr.File passes a path string by default; Gradio 3.x passes a
    # temp-file wrapper exposing the path as `.name`.
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return upload.name


def test(ligand, protein):
    """Gradio callback: score the uploaded ligand/protein pair.

    Args:
        ligand: Uploaded ligand, as a path string/PathLike or a file wrapper
            with ``.name``; ``None`` when nothing was uploaded.
        protein: Uploaded protein structure, same forms as *ligand*.

    Returns:
        ``(ic50_text, k_text)`` formatted with two decimals, or the message
        ``'Please set input correctly'`` in both fields when an input is
        missing or scoring fails.
    """
    import logging

    error_msg = 'Please set input correctly'
    if ligand is None or protein is None:
        # Empty upload slot: report it without treating it as a failure.
        return error_msg, error_msg
    try:
        IC50, K = mbp_scoring(_uploaded_path(ligand), _uploaded_path(protein))
    except Exception:
        # Top-level UI boundary: keep the app responsive, but record why
        # scoring failed instead of silently discarding the exception.
        logging.getLogger(__name__).exception('MBP scoring failed')
        return error_msg, error_msg
    return f'{IC50:.2f}', f'{K:.2f}'
# Gradio UI: two file uploads, two output text boxes and a Submit button.
# Inside a Blocks context, components appear in creation order, so the order
# of the statements below defines the page layout.
with gr.Blocks() as demo:
    ligand = gr.File(label="ligand")
    protein = gr.File(label="Protein, please use .pdb files")
    IC50 = gr.Textbox(label="Predicted IC50 Value")
    K = gr.Textbox(label="Predicted K Value")
    submit_btn = gr.Button("Submit")
    # Clicking Submit runs `test` on the two uploads and fills both text boxes.
    # api_name names this event's API endpoint ("MBP_Scoring").
    submit_btn.click(fn=test, inputs=[ligand, protein], outputs=[IC50, K], api_name="MBP_Scoring")
# Start the web server. There is no __main__ guard, so this also runs when
# the module is imported.
demo.launch()
|