Spaces:
Sleeping
Sleeping
import gradio as gr | |
import argparse | |
import os | |
from UltraFlow.models.sbap import * | |
from UltraFlow import commons | |
import warnings | |
# Suppress every Python warning for the lifetime of the process.
warnings.filterwarnings("ignore")

# Directory that holds the YAML config and the checkpoint file.
model_dir = './workdir/gradio/'
# File name of the checkpoint, resolved relative to model_dir.
checkpoint = 'checkpointbest_valid_1.ckp'
# Running count of scoring requests; incremented by test().
total_num = 0
def get_config(model_dir):
    """Load ``affinity_default.yaml`` from ``model_dir`` and prepare it for inference.

    Args:
        model_dir: Directory containing ``affinity_default.yaml``.

    Returns:
        The config object produced by ``commons.get_config_easydict`` with
        ``device`` set to ``'cpu'``.

    Side effects:
        Seeds the random number generators through ``commons.set_seed``
        using ``config.seed``.
    """
    config_path = os.path.join(model_dir, 'affinity_default.yaml')
    config = commons.get_config_easydict(config_path)

    # GPU selection through commons.get_device is disabled here; the demo
    # always runs on the CPU.
    config.device = 'cpu'

    commons.set_seed(config.seed)
    return config
def load_graph_dim(lig_graph, prot_graph, model_config):
    """Derive the feature widths the model is configured with.

    Args:
        lig_graph: Ligand graph exposing ``ndata['h']``, ``edata['e']`` and,
            when bond features are enabled, ``edata['bond_type']``.
        prot_graph: Protein graph exposing ``ndata['h']`` and ``edata['e']``.
        model_config: Config whose ``data.add_chemical_bond_feats`` and
            ``data.use_mean_node_features`` flags are read.

    Returns:
        Tuple ``(lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim,
        inter_edge_dim)``.
    """
    data_cfg = model_config.data

    ligand_nodes = lig_graph.ndata['h'].shape[1]
    ligand_edges = lig_graph.edata['e'].shape[1]
    if data_cfg.add_chemical_bond_feats:
        # Bond-type columns get appended to the ligand edge features.
        ligand_edges = ligand_edges + lig_graph.edata['bond_type'].shape[1]

    protein_nodes = prot_graph.ndata['h'].shape[1]
    protein_edges = prot_graph.edata['e'].shape[1]

    # Fixed width of the ligand-protein interaction edge features.
    interaction_edges = 15

    if data_cfg.use_mean_node_features:
        # Five extra node columns on both graphs
        # (presumably the width of 'mu_r_norm' -- TODO confirm).
        ligand_nodes = ligand_nodes + 5
        protein_nodes = protein_nodes + 5

    return ligand_nodes, ligand_edges, protein_nodes, protein_edges, interaction_edges
def trans_device(data, device):
    """Move every non-list entry of ``data`` to ``device``.

    Args:
        data: Iterable of entries; each is either a plain ``list`` or an
            object with a ``.to(device)`` method.
        device: Target device forwarded to ``.to``.

    Returns:
        A new list in the same order: lists are passed through unchanged,
        everything else is replaced by the result of ``.to(device)``.
    """
    moved = []
    for entry in data:
        if isinstance(entry, list):
            # Plain Python lists have no .to(); keep them as they are.
            moved.append(entry)
        else:
            moved.append(entry.to(device))
    return moved
def get_data(model_config, ligand_path, protein_path):
    """Build the model input for one ligand/protein pair.

    Args:
        model_config: Config object; ``data.*`` options are read and
            ``model.*_dim`` fields are written (see side effects).
        ligand_path: Path of the ligand structure file.
        protein_path: Path of the protein structure file.

    Returns:
        List ``[lig_graph, prot_graph, inter_graph, label, item, assay_des,
        IC50_f, K_f]`` after ``trans_device`` moved the non-list entries to
        ``model_config.device``.

    Side effects:
        Sets ``lig_node_dim``, ``lig_edge_dim``, ``pro_node_dim``,
        ``pro_edge_dim`` and ``inter_edge_dim`` on ``model_config.model``.
    """
    data_cfg = model_config.data

    (lig_coords, lig_features, lig_edges, lig_node_type,
     prot_coords, prot_features, prot_edges, prot_node_type,
     sec_features, alpha_c_coords, c_coords, n_coords,
     ca_res_number_valid, chain_index_valid) = commons.read_molecules_inference(
        ligand_path, protein_path,
        data_cfg.prot_graph_type,
        data_cfg.chaincut)

    lig_graph = commons.get_lig_graph_equibind(
        lig_coords, lig_features, lig_edges, lig_node_type,
        max_neighbors=data_cfg.lig_max_neighbors,
        cutoff=data_cfg.ligcut)

    prot_graph = commons.get_prot_alpha_c_graph_equibind(
        prot_coords, prot_features, prot_node_type,
        sec_features, alpha_c_coords, c_coords, n_coords,
        max_neighbor=data_cfg.prot_max_neighbors,
        cutoff=data_cfg.protcut)
    prot_graph.ndata['res_number'] = torch.tensor(ca_res_number_valid)
    prot_graph.chain_index = chain_index_valid

    inter_graph = commons.get_interact_graph_knn_v2(
        lig_coords, prot_coords,
        max_neighbor=data_cfg.inter_max_neighbors,
        min_neighbor=data_cfg.inter_min_neighbors,
        cutoff=data_cfg.intercut)

    # Record the feature widths on the config. This must happen before the
    # concatenations below: load_graph_dim adds the extra columns itself.
    model_cfg = model_config.model
    (model_cfg.lig_node_dim, model_cfg.lig_edge_dim,
     model_cfg.pro_node_dim, model_cfg.pro_edge_dim,
     model_cfg.inter_edge_dim) = load_graph_dim(lig_graph, prot_graph, model_config)

    if data_cfg.add_chemical_bond_feats:
        # Append the bond-type columns to the ligand edge features.
        lig_graph.edata['e'] = torch.cat(
            [lig_graph.edata['e'], lig_graph.edata['bond_type']], dim=-1)

    if data_cfg.use_mean_node_features:
        # Append 'mu_r_norm' to the node features of both graphs.
        for graph in (lig_graph, prot_graph):
            graph.ndata['h'] = torch.cat(
                [graph.ndata['h'], graph.ndata['mu_r_norm']], dim=-1)

    # Remaining entries of the sample: a constant label of -100, item [0],
    # an empty assay description and both task flags set to True
    # (presumably placeholders since no ground truth exists at inference --
    # TODO confirm against the model's forward()).
    label = torch.tensor(-100).unsqueeze(dim=-1)
    assay_des = torch.zeros(0).unsqueeze(dim=0)
    sample = (lig_graph, prot_graph, inter_graph, label, [0], assay_des, [True], [True])
    return trans_device(sample, model_config.device)
def get_models(model_config, model_dir, checkpoint):
    """Instantiate the configured model class and load its checkpoint weights.

    Args:
        model_config: Config object; ``model.model_type`` names the class,
            ``train.multi_task`` selects its ``_MTL`` variant and ``device``
            is where the model and the checkpoint tensors are placed.
        model_dir: Directory containing the checkpoint file.
        checkpoint: Checkpoint file name inside ``model_dir``.

    Returns:
        The model with the ``"model"`` state dict loaded, in eval mode.

    Raises:
        KeyError: If no class with the resolved name exists in this module's
            globals.
    """
    # The multi-task variant of a model class carries an '_MTL' suffix.
    suffix = '_MTL' if model_config.train.multi_task else ''
    model_cls = globals()[model_config.model.model_type + suffix]
    model = model_cls(model_config).to(model_config.device)

    checkpoint_path = os.path.join(model_dir, checkpoint)
    print("Load checkpoint from %s" % checkpoint_path)

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=model_config.device)
    model.load_state_dict(state["model"])
    return model.eval()
def mbp_scoring(ligand_path, protein_path):
    """Predict IC50 and K affinities for one ligand/protein pair.

    Uses the module-level ``model_config`` and ``model``.

    Args:
        ligand_path: Path of the ligand structure file.
        protein_path: Path of the protein structure file.

    Returns:
        Tuple ``(IC50, K)`` of Python floats taken from the model output.
    """
    data_example = get_data(model_config, ligand_path, protein_path)

    # Pure inference: disable autograd so no graph is built and kept alive
    # for every request served by the demo.
    with torch.no_grad():
        # The model returns three values; the middle one is the
        # (IC50, K) prediction pair.
        _, (affinity_pred_IC50, affinity_pred_K), _ = model(data_example, ASRP=False)

    return affinity_pred_IC50.item(), affinity_pred_K.item()
def test(ligand, protein):
    """Gradio callback: score one uploaded ligand/protein pair.

    Args:
        ligand: Uploaded ligand file, either an object exposing its path as
            ``.name`` or a plain path string; ``None`` when nothing was
            uploaded.
        protein: Uploaded protein file, same conventions as ``ligand``.

    Returns:
        Tuple of two strings for the IC50 and K text boxes: the predictions
        formatted with two decimals, or the same error message twice when
        an input is missing or scoring fails. The callback never raises.
    """
    global total_num
    total_num = total_num + 1
    print(f'total num: {total_num}')

    # Give a readable hint instead of an AttributeError on a missing upload.
    if ligand is None or protein is None:
        message = 'Please upload both a ligand file and a protein file.'
        return message, message

    # Depending on the Gradio version the File component hands over either a
    # temp-file wrapper (path in .name) or the path string itself.
    ligand_path = getattr(ligand, 'name', ligand)
    protein_path = getattr(protein, 'name', protein)

    try:
        IC50, K = mbp_scoring(ligand_path, protein_path)
        print(f'ligand file name: {os.path.basename(ligand_path)},'
              f' protein file name: {os.path.basename(protein_path)},'
              f' IC50: {IC50}, K: {K}')
        return '{:.2f}'.format(IC50), '{:.2f}'.format(K)
    except Exception as e:
        # Top-level UI boundary: log the failure and show it to the user as
        # text rather than handing a raw exception object to the text boxes.
        print(f'scoring failed: {e!r}')
        message = str(e) or repr(e)
        return message, message
# Build the web UI: two file inputs, two text outputs, a submit button and
# one bundled example.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Multi-task Bioassay Pre-training for Protein-Ligand Binding Affinity Prediction
## Welcome to the MBP demo !
- Feel free to upload your own examples. Please upload an individual ligand 3D file and an individual protein 3D file each time.
- If you encounter any issues, please reach out to [email protected].
- All codes and data are available on the online platform https://github.com/jiaxianyan/MBP.
""")

    with gr.Row():
        ligand = gr.File(
            label="Ligand 3D file. MBP utilizes openbabel to process ligand files"
                  " and supports all file types that openbabel can read.")
        protein = gr.File(
            label="Protein 3D file. Currently, MBP only supports the pdb file type"
                  " for protein files.")

    IC50 = gr.Textbox(label="Predicted IC50 Value")
    K = gr.Textbox(label="Predicted K Value")

    # Pressing the button runs test() on the two uploads and fills both
    # text boxes; the same endpoint is exposed under the API name below.
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=test,
        inputs=[ligand, protein],
        outputs=[IC50, K],
        api_name="MBP_Scoring",
    )

    gr.Markdown("## Input Examples")
    # The example has no outputs attached and is not cached.
    gr.Examples(
        examples=[['./workdir/gradio/1a0q_ligand.sdf', './workdir/gradio/1a0q_protein.pdb']],
        inputs=[ligand, protein],
        fn=test,
        cache_examples=False,
    )
# Start-up sequence: config -> example sample -> model -> web server.
model_config = get_config(model_dir)

# Parsing the bundled example makes get_data() write the feature dimensions
# into model_config.model, so this runs before the model is constructed
# (presumably the model constructor reads those fields -- TODO confirm).
data_example = get_data(
    model_config,
    './workdir/gradio/1a0q_ligand.sdf',
    './workdir/gradio/1a0q_protein.pdb',
)

model = get_models(model_config, model_dir, checkpoint)

demo.launch(share=False)