"""Gradio demo that predicts protein-ligand binding affinity with UltraFlow.

The user uploads a ligand file and a protein ``.pdb`` file. The app reads the
model configuration from ``model_dir``, featurizes both inputs into graphs,
restores the checkpoint and reports the predicted IC50 and K values.
"""
import argparse
import logging
import os
import warnings

import gradio as gr
import torch

from UltraFlow.models.sbap import *
from UltraFlow import commons

warnings.filterwarnings("ignore")

logger = logging.getLogger(__name__)

# Directory holding ``affinity_default.yaml`` and the checkpoint file.
model_dir = './workdir/gradio/'
# Checkpoint file name, resolved relative to ``model_dir``.
checkpoint = 'checkpointbest_valid_1.ckp'

# Message shown in both output boxes when scoring cannot be performed.
_INPUT_ERROR = 'Please set input correctly'


def get_config(model_dir):
    """Load the inference configuration from ``<model_dir>/affinity_default.yaml``.

    Args:
        model_dir: Directory containing ``affinity_default.yaml``.

    Returns:
        The config object produced by ``commons.get_config_easydict``, with
        ``device`` forced to ``'cpu'``. The config's ``seed`` has already been
        applied via ``commons.set_seed``.
    """
    config = commons.get_config_easydict(os.path.join(model_dir, 'affinity_default.yaml'))
    # The demo always runs on CPU. Automatic GPU selection
    # (commons.get_device(config.train.gpus, config.train.gpu_memory_need))
    # is intentionally disabled.
    config.device = 'cpu'
    commons.set_seed(config.seed)
    return config


def load_graph_dim(lig_graph, prot_graph, model_config):
    """Compute the feature dimensions the model must be built with.

    Args:
        lig_graph: Ligand graph with ``ndata['h']`` and ``edata['e']`` (and
            ``edata['bond_type']`` when chemical bond features are enabled).
        prot_graph: Protein graph with ``ndata['h']`` and ``edata['e']``.
        model_config: Config whose ``data`` section selects optional features.

    Returns:
        Tuple ``(lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim,
        inter_edge_dim)``.
    """
    lig_node_dim = lig_graph.ndata['h'].shape[1]
    lig_edge_dim = lig_graph.edata['e'].shape[1]
    if model_config.data.add_chemical_bond_feats:
        # get_data later concatenates bond_type onto the ligand edge features.
        lig_edge_dim += lig_graph.edata['bond_type'].shape[1]
    pro_node_dim = prot_graph.ndata['h'].shape[1]
    pro_edge_dim = prot_graph.edata['e'].shape[1]
    # NOTE(review): fixed interaction-edge feature width; presumably matches
    # commons.get_interact_graph_knn_v2 — confirm.
    inter_edge_dim = 15

    if model_config.data.use_mean_node_features:
        # get_data later concatenates ndata['mu_r_norm'] onto node features;
        # assumes that tensor has 5 columns — TODO confirm.
        lig_node_dim += 5
        pro_node_dim += 5

    return lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim, inter_edge_dim


def trans_device(data, device):
    """Move every non-list element of ``data`` to ``device``; lists pass through unchanged."""
    return [x if isinstance(x, list) else x.to(device) for x in data]


def get_data(model_config, ligand_path, protein_path):
    """Featurize a ligand/protein pair into the model's input tuple.

    Side effect: writes the computed feature dimensions into
    ``model_config.model`` (``lig_node_dim``, ``lig_edge_dim``,
    ``pro_node_dim``, ``pro_edge_dim``, ``inter_edge_dim``). ``get_models``
    relies on these values, so this must run before it.

    Args:
        model_config: Config returned by ``get_config``.
        ligand_path: Path to the ligand file.
        protein_path: Path to the protein file.

    Returns:
        A list ``[lig_graph, prot_graph, inter_graph, label, item, assay_des,
        IC50_f, K_f]`` with tensors/graphs moved to ``model_config.device``.
    """
    molecular_representation = commons.read_molecules_inference(
        ligand_path, protein_path,
        model_config.data.prot_graph_type, model_config.data.chaincut)

    lig_coords, lig_features, lig_edges, lig_node_type, \
        prot_coords, prot_features, prot_edges, prot_node_type, \
        sec_features, alpha_c_coords, c_coords, n_coords, \
        ca_res_number_valid, chain_index_valid = molecular_representation

    lig_graph = commons.get_lig_graph_equibind(
        lig_coords, lig_features, lig_edges, lig_node_type,
        max_neighbors=model_config.data.lig_max_neighbors,
        cutoff=model_config.data.ligcut)

    prot_graph = commons.get_prot_alpha_c_graph_equibind(
        prot_coords, prot_features, prot_node_type, sec_features,
        alpha_c_coords, c_coords, n_coords,
        max_neighbor=model_config.data.prot_max_neighbors,
        cutoff=model_config.data.protcut)

    prot_graph.ndata['res_number'] = torch.tensor(ca_res_number_valid)
    prot_graph.chain_index = chain_index_valid

    inter_graph = commons.get_interact_graph_knn_v2(
        lig_coords, prot_coords,
        max_neighbor=model_config.data.inter_max_neighbors,
        min_neighbor=model_config.data.inter_min_neighbors,
        cutoff=model_config.data.intercut)

    # Record feature dimensions so the model is built to match this input.
    lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim, inter_edge_dim = \
        load_graph_dim(lig_graph, prot_graph, model_config)
    model_config.model.lig_node_dim, model_config.model.lig_edge_dim = lig_node_dim, lig_edge_dim
    model_config.model.pro_node_dim, model_config.model.pro_edge_dim = pro_node_dim, pro_edge_dim
    model_config.model.inter_edge_dim = inter_edge_dim

    if model_config.data.add_chemical_bond_feats:
        lig_graph.edata['e'] = torch.cat(
            [lig_graph.edata['e'], lig_graph.edata['bond_type']], dim=-1)

    if model_config.data.use_mean_node_features:
        lig_graph.ndata['h'] = torch.cat(
            [lig_graph.ndata['h'], lig_graph.ndata['mu_r_norm']], dim=-1)
        prot_graph.ndata['h'] = torch.cat(
            [prot_graph.ndata['h'], prot_graph.ndata['mu_r_norm']], dim=-1)

    # Placeholder label/metadata: inference has no ground truth. The -100
    # label and both-True task flags are passed through as in the original.
    label = torch.tensor(-100).unsqueeze(dim=-1)
    item = [0]
    assay_des = torch.zeros(0)
    IC50_f, K_f = [True], [True]

    data = (lig_graph, prot_graph, inter_graph, label, item,
            assay_des.unsqueeze(dim=0), IC50_f, K_f)
    return trans_device(data, model_config.device)


def get_models(model_config, model_dir, checkpoint):
    """Build the model named in the config and load its checkpoint weights.

    The model class is looked up by name in this module's globals (populated
    by the ``UltraFlow.models.sbap`` star import). The ``_MTL`` variant is
    used when ``model_config.train.multi_task`` is set.

    Args:
        model_config: Config with feature dims already filled in by ``get_data``.
        model_dir: Directory containing the checkpoint.
        checkpoint: Checkpoint file name inside ``model_dir``.

    Returns:
        The model in eval mode on ``model_config.device``.
    """
    model_name = model_config.model.model_type
    if model_config.train.multi_task:
        model_name += '_MTL'
    model = globals()[model_name](model_config).to(model_config.device)

    checkpoint_path = os.path.join(model_dir, checkpoint)
    print("Load checkpoint from %s" % checkpoint_path)

    state = torch.load(checkpoint_path, map_location=model_config.device)
    model.load_state_dict(state["model"])
    model = model.eval()

    return model


def mbp_scoring(ligand_path, protein_path):
    """Predict IC50 and K affinities for one ligand/protein pair.

    Args:
        ligand_path: Path to the ligand file.
        protein_path: Path to the protein file.

    Returns:
        ``(IC50, K)`` as Python floats.
    """
    model_config = get_config(model_dir)
    # get_data must run first: it sets the feature dims get_models builds with.
    data_example = get_data(model_config, ligand_path, protein_path)
    model = get_models(model_config, model_dir, checkpoint)
    # Pure inference: skip autograd bookkeeping (same outputs, less memory).
    with torch.no_grad():
        _, (affinity_pred_IC50, affinity_pred_K), _ = model(data_example, ASRP=False)
    return affinity_pred_IC50.item(), affinity_pred_K.item()


def _upload_path(upload):
    """Return the filesystem path of a Gradio ``File`` value.

    Depending on the Gradio version, the value is either a temp-file wrapper
    exposing ``.name`` or the path string itself. Both forms are accepted.
    """
    return getattr(upload, 'name', upload)


def test(ligand, protein):
    """Gradio callback: score the uploaded pair and format results for display.

    Args:
        ligand: Gradio ``File`` value for the ligand (``None`` if not uploaded).
        protein: Gradio ``File`` value for the protein (``None`` if not uploaded).

    Returns:
        ``(IC50, K)`` formatted with two decimals. If an input is missing or
        scoring fails, both entries are the generic input-error message; the
        failure is logged with its traceback.
    """
    if ligand is None or protein is None:
        return _INPUT_ERROR, _INPUT_ERROR
    try:
        IC50, K = mbp_scoring(_upload_path(ligand), _upload_path(protein))
    except Exception:
        # Keep the UI message friendly, but don't lose the real cause.
        logger.exception('MBP scoring failed (ligand=%r, protein=%r)', ligand, protein)
        return _INPUT_ERROR, _INPUT_ERROR
    return '{:.2f}'.format(IC50), '{:.2f}'.format(K)


with gr.Blocks() as demo:
    ligand = gr.File(label="ligand")
    protein = gr.File(label="Protein, please use .pdb files")
    IC50 = gr.Textbox(label="Predicted IC50 Value")
    K = gr.Textbox(label="Predicted K Value")
    submit_btn = gr.Button("Submit")
    submit_btn.click(fn=test, inputs=[ligand, protein], outputs=[IC50, K], api_name="MBP_Scoring")

demo.launch()