"""Gradio demo that scores a protein-ligand pair with a pre-trained MBP model.

The script loads the configuration and checkpoint found in ``model_dir``,
builds a small Gradio UI with two file inputs (ligand, protein) and two text
outputs (predicted IC50 and K values), and launches the app.
"""
import gradio as gr
import argparse
import os
import traceback
import warnings

import torch

from UltraFlow.models.sbap import *
from UltraFlow import commons

warnings.filterwarnings("ignore")

model_dir = './workdir/gradio/'
checkpoint = 'checkpointbest_valid_1.ckp'
# Number of scoring requests handled since the process started (log only).
total_num = 0


def get_config(model_dir):
    """Load ``affinity_default.yaml`` from *model_dir* and prepare it for inference.

    The device is forced to ``'cpu'`` and the random seed stored in the config
    is applied through ``commons.set_seed``.

    Returns:
        The config object produced by ``commons.get_config_easydict``.
    """
    # get config
    config = commons.get_config_easydict(os.path.join(model_dir, 'affinity_default.yaml'))

    # get device
    # config.device = commons.get_device(config.train.gpus, config.train.gpu_memory_need)
    config.device = 'cpu'

    # set random seed
    commons.set_seed(config.seed)

    return config


def load_graph_dim(lig_graph, prot_graph, model_config):
    """Derive the feature dimensions the model needs from one example.

    Node/edge dimensions are read from the second axis of the ``'h'`` and
    ``'e'`` tensors of each graph. The ligand edge dimension grows by the
    width of ``'bond_type'`` when ``add_chemical_bond_feats`` is set, and both
    node dimensions grow by 5 when ``use_mean_node_features`` is set.

    Returns:
        ``(lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim, inter_edge_dim)``
    """
    lig_node_dim = lig_graph.ndata['h'].shape[1]
    lig_edge_dim = lig_graph.edata['e'].shape[1]
    if model_config.data.add_chemical_bond_feats:
        lig_edge_dim += lig_graph.edata['bond_type'].shape[1]

    pro_node_dim = prot_graph.ndata['h'].shape[1]
    pro_edge_dim = prot_graph.edata['e'].shape[1]

    # Fixed width of the interaction-edge features.
    inter_edge_dim = 15

    if model_config.data.use_mean_node_features:
        # NOTE(review): presumably the width of 'mu_r_norm', which get_data
        # concatenates onto 'h' under the same flag - confirm.
        lig_node_dim += 5
        pro_node_dim += 5

    return lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim, inter_edge_dim


def trans_device(data, device):
    """Return a list with every non-list item of *data* moved to *device*.

    Plain Python lists are passed through untouched; everything else must
    provide a ``.to(device)`` method.
    """
    return [x if isinstance(x, list) else x.to(device) for x in data]


def get_data(model_config, ligand_path, protein_path):
    """Build the model input for a single ligand/protein file pair.

    Reads both files through ``commons.read_molecules_inference``, builds the
    ligand, protein and interaction graphs, and packs them with placeholder
    label/assay entries in the tuple layout the model is called with.

    Side effect: writes the detected feature dimensions into
    ``model_config.model`` (``lig_node_dim``, ``lig_edge_dim``,
    ``pro_node_dim``, ``pro_edge_dim``, ``inter_edge_dim``).

    Returns:
        A list ``[lig_graph, prot_graph, inter_graph, label, item, assay_des,
        IC50_f, K_f]`` with non-list entries moved to ``model_config.device``.
    """
    molecular_representation = commons.read_molecules_inference(
        ligand_path,
        protein_path,
        model_config.data.prot_graph_type,
        model_config.data.chaincut,
    )

    lig_coords, lig_features, lig_edges, lig_node_type, \
        prot_coords, prot_features, prot_edges, prot_node_type, \
        sec_features, alpha_c_coords, c_coords, n_coords, \
        ca_res_number_valid, chain_index_valid = molecular_representation

    lig_graph = commons.get_lig_graph_equibind(
        lig_coords, lig_features, lig_edges, lig_node_type,
        max_neighbors=model_config.data.lig_max_neighbors,
        cutoff=model_config.data.ligcut,
    )

    prot_graph = commons.get_prot_alpha_c_graph_equibind(
        prot_coords, prot_features, prot_node_type, sec_features,
        alpha_c_coords, c_coords, n_coords,
        max_neighbor=model_config.data.prot_max_neighbors,
        cutoff=model_config.data.protcut,
    )

    prot_graph.ndata['res_number'] = torch.tensor(ca_res_number_valid)
    prot_graph.chain_index = chain_index_valid

    inter_graph = commons.get_interact_graph_knn_v2(
        lig_coords, prot_coords,
        max_neighbor=model_config.data.inter_max_neighbors,
        min_neighbor=model_config.data.inter_min_neighbors,
        cutoff=model_config.data.intercut,
    )

    # set feats dim
    lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim, inter_edge_dim = \
        load_graph_dim(lig_graph, prot_graph, model_config)
    model_config.model.lig_node_dim, model_config.model.lig_edge_dim = lig_node_dim, lig_edge_dim
    model_config.model.pro_node_dim, model_config.model.pro_edge_dim = pro_node_dim, pro_edge_dim
    model_config.model.inter_edge_dim = inter_edge_dim

    # The dims above were computed for the concatenated features, so the
    # concatenation itself happens only afterwards.
    if model_config.data.add_chemical_bond_feats:
        lig_graph.edata['e'] = torch.cat(
            [lig_graph.edata['e'], lig_graph.edata['bond_type']], dim=-1)

    if model_config.data.use_mean_node_features:
        lig_graph.ndata['h'] = torch.cat(
            [lig_graph.ndata['h'], lig_graph.ndata['mu_r_norm']], dim=-1)
        prot_graph.ndata['h'] = torch.cat(
            [prot_graph.ndata['h'], prot_graph.ndata['mu_r_norm']], dim=-1)

    # Placeholder entries: no ground-truth label or assay description exists
    # at inference time.
    label = torch.tensor(-100).unsqueeze(dim=-1)
    item = [0]
    assay_des = torch.zeros(0)
    IC50_f, K_f = [True], [True]

    data = (lig_graph, prot_graph, inter_graph, label, item,
            assay_des.unsqueeze(dim=0), IC50_f, K_f)

    return trans_device(data, model_config.device)


def get_models(model_config, model_dir, checkpoint):
    """Instantiate the configured model and load its weights in eval mode.

    The model class is looked up by name in this module's globals (the names
    come from the ``UltraFlow.models.sbap`` star import); ``'_MTL'`` is
    appended to ``model_config.model.model_type`` when
    ``model_config.train.multi_task`` is set. Weights are read from the
    ``"model"`` entry of the checkpoint file ``model_dir/checkpoint``.
    """
    model_name = model_config.model.model_type
    if model_config.train.multi_task:
        model_name += '_MTL'
    model = globals()[model_name](model_config).to(model_config.device)

    checkpoint_path = os.path.join(model_dir, checkpoint)
    print("Load checkpoint from %s" % checkpoint_path)

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state = torch.load(checkpoint_path, map_location=model_config.device)
    model.load_state_dict(state["model"])
    model = model.eval()

    return model


def mbp_scoring(ligand_path, protein_path):
    """Predict ``(IC50, K)`` as Python floats for one ligand/protein pair.

    Uses the module-level ``model_config`` and ``model`` created at the bottom
    of this file.
    """
    data_example = get_data(model_config, ligand_path, protein_path)
    _, (affinity_pred_IC50, affinity_pred_K), _ = model(data_example, ASRP=False)

    return affinity_pred_IC50.item(), affinity_pred_K.item()


def _upload_path(upload):
    """Return the filesystem path of a Gradio upload (path string or object with ``.name``)."""
    return upload if isinstance(upload, str) else upload.name


def test(ligand, protein):
    """Gradio callback: score the uploaded files and return two display strings.

    Args:
        ligand: uploaded ligand file (object exposing ``.name``, or a path string).
        protein: uploaded protein file (same forms as *ligand*).

    Returns:
        ``(IC50, K)`` formatted with two decimals on success. On failure both
        entries carry the same human-readable error message; the full
        traceback is printed to the server log instead of being discarded.
    """
    global total_num
    total_num = total_num + 1
    print(f'total num: {total_num}')

    # A missing upload arrives as None; report it explicitly instead of
    # surfacing "'NoneType' object has no attribute 'name'".
    if ligand is None or protein is None:
        message = 'Please upload both a ligand file and a protein file.'
        return message, message

    try:
        ligand_path = _upload_path(ligand)
        protein_path = _upload_path(protein)
        IC50, K = mbp_scoring(ligand_path, protein_path)
        print(f'ligand file name: {os.path.basename(ligand_path)},'
              f' protein file name: {os.path.basename(protein_path)},'
              f' IC50: {IC50}, K: {K}')
        return '{:.2f}'.format(IC50), '{:.2f}'.format(K)
    except Exception as e:
        # Top-level UI boundary: keep the app alive, but leave a trace in the
        # server log and hand the Textbox outputs a string, not an exception.
        traceback.print_exc()
        message = f'{type(e).__name__}: {e}'
        return message, message


with gr.Blocks() as demo:
    # NOTE(review): line breaks inside this Markdown string were reconstructed
    # from a whitespace-collapsed source - confirm against the rendered page.
    gr.Markdown(
        """
        # Multi-task Bioassay Pre-training for Protein-Ligand Binding Affinity Prediction

        ## Welcome to the MBP demo !

        - Feel free to upload your own examples. Please upload an individual ligand 3D file and an individual protein 3D file each time.
        - If you encounter any issues, please reach out to jiaxianyan@mail.ustc.edu.cn.
        - All codes and data are available on the online platform https://github.com/jiaxianyan/MBP.
        """)

    # NOTE(review): the collapsed source does not show whether the two
    # Textbox outputs belong inside this Row - confirm the intended layout.
    with gr.Row():
        ligand = gr.File(label="Ligand 3D file. MBP utilizes openbabel to process ligand files and supports all file types that openbabel can read.")
        protein = gr.File(label="Protein 3D file. Currently, MBP only supports the pdb file type for protein files.")
        IC50 = gr.Textbox(label="Predicted IC50 Value")
        K = gr.Textbox(label="Predicted K Value")

    submit_btn = gr.Button("Submit")
    submit_btn.click(fn=test, inputs=[ligand, protein], outputs=[IC50, K], api_name="MBP_Scoring")

    gr.Markdown("## Input Examples")
    gr.Examples(
        examples=[['./workdir/gradio/1a0q_ligand.sdf', './workdir/gradio/1a0q_protein.pdb']],
        inputs=[ligand, protein],
        # outputs=[IC50, K],
        fn=test,
        cache_examples=False,
    )

model_config = get_config(model_dir)
# get_data must run once before get_models: it writes the feature dimensions
# into model_config.model, which is then passed to the model constructor.
data_example = get_data(model_config, './workdir/gradio/1a0q_ligand.sdf', './workdir/gradio/1a0q_protein.pdb')
model = get_models(model_config, model_dir, checkpoint)

demo.launch(share=False)