from datetime import datetime
import copy
import html
import json
import os
import re
from typing import NamedTuple, Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from Bio.PDB import PDBParser, MMCIFParser, PDBIO, Select
from Bio.PDB.Polypeptide import is_aa
from Bio.SeqUtils import seq1
from datasets import Dataset
from gradio_molecule3d import Molecule3D
from scipy.special import expit
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForTokenClassification

from model_loader import load_model

# Load model and move to device
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()


class ScoreBracket(NamedTuple):
    """A band of normalised prediction scores and how it is displayed."""

    label: str
    lower: float
    upper: float
    pymol_color: str      # colour name used in the PyMOL commands
    stick_color: str      # colour string used for 3Dmol.js sticks
    stick_opacity: float  # opacity of the 3Dmol.js sticks


# Shared by the text report, the PyMOL commands and the 3Dmol.js view so the
# three outputs always agree. Each band is [lower, upper) except the last,
# which also includes 1.0 -- the score of the top residue after min-max
# normalisation, which would otherwise fall into no band at all.
SCORE_BRACKETS = (
    ScoreBracket("0.0-0.2", 0.0, 0.2, "white", "0xFFFFFF", 0.5),
    ScoreBracket("0.2-0.4", 0.2, 0.4, "lightorange", "0xFFD580", 0.7),
    ScoreBracket("0.4-0.6", 0.4, 0.6, "orange", "0xFFA500", 1),
    ScoreBracket("0.6-0.8", 0.6, 0.8, "orangered", "0xFF4500", 1),
    ScoreBracket("0.8-1.0", 0.8, 1.0, "red", "0xFF0000", 1),
)

# Accepted PDB IDs are plain alphanumerics; anything else (e.g. "/" or "..")
# is rejected before it is used in a file path or a download URL.
_PDB_ID_PATTERN = re.compile(r"[A-Za-z0-9_]+")


def _is_valid_pdb_id(pdb_id: str) -> bool:
    """Return True if pdb_id contains only letters, digits and underscores."""
    return bool(_PDB_ID_PATTERN.fullmatch(pdb_id))


def _score_bracket(score) -> Optional[str]:
    """Return the label of the SCORE_BRACKETS band containing score, or None."""
    last = SCORE_BRACKETS[-1]
    for bracket in SCORE_BRACKETS:
        if bracket.lower <= score < bracket.upper:
            return bracket.label
        if bracket is last and score == bracket.upper:
            return bracket.label
    return None


def _residues_by_bracket(residue_scores) -> dict:
    """
    Group residue numbers by score band.

    Args:
        residue_scores: iterable of (residue number, normalised score) pairs.

    Returns:
        Dict mapping every SCORE_BRACKETS label (in order) to a list of residue
        numbers; scores outside every band (e.g. NaN) are left out.
    """
    grouped = {bracket.label: [] for bracket in SCORE_BRACKETS}
    for resi, score in residue_scores:
        label = _score_bracket(score)
        if label is not None:
            grouped[label].append(resi)
    return grouped


def normalize_scores(scores):
    """
    Min-max normalise scores to [0, 1].

    If all scores are equal, or there are none, they are returned unchanged
    (as a float array) because there is no range to rescale by.
    """
    scores = np.asarray(scores, dtype=float)
    if scores.size == 0:
        return scores
    min_score = scores.min()
    max_score = scores.max()
    if max_score > min_score:
        return (scores - min_score) / (max_score - min_score)
    return scores


def read_mol(pdb_path):
    """Read PDB file and return its content as a string"""
    with open(pdb_path, 'r', encoding='utf-8') as f:
        return f.read()


def fetch_structure(pdb_id: str, output_dir: str = ".") -> Optional[str]:
    """
    Fetch the structure file for a given PDB ID.

    Prioritizes CIF files. If a structure file already exists locally, it is
    used instead of downloading. Returns the file path, or None on failure.
    """
    return download_structure(pdb_id, output_dir)


def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:
    """
    Attempt to download the structure file in CIF or PDB format.

    A locally cached ``<id>.cif`` or ``<id>.pdb`` in output_dir is reused
    before any download is attempted (CIF preferred in both cases).

    Returns:
        The path to the structure file, or None if the ID is not a plain
        alphanumeric identifier or every download fails.
    """
    pdb_id = pdb_id.strip()
    if not _is_valid_pdb_id(pdb_id):
        print(f"Invalid PDB ID: {pdb_id!r}")
        return None

    extensions = ['.cif', '.pdb']
    for ext in extensions:
        file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
        if os.path.exists(file_path):
            return file_path

    for ext in extensions:
        file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
        url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
        try:
            response = requests.get(url, timeout=10)
            if response.status_code == 200:
                with open(file_path, 'wb') as f:
                    f.write(response.content)
                return file_path
        except (requests.RequestException, OSError) as e:
            print(f"Download error for {pdb_id}{ext}: {e}")
    return None


def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
    """
    Convert a CIF file to PDB format using BioPython and return the PDB file path.

    The output is ``<output_dir>/<cif file stem>.pdb``.
    """
    stem, _ = os.path.splitext(os.path.basename(cif_path))
    pdb_path = os.path.join(output_dir, f"{stem}.pdb")
    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure('protein', cif_path)
    io = PDBIO()
    io.set_structure(structure)
    io.save(pdb_path)
    return pdb_path


def fetch_pdb(pdb_id):
    """
    Fetch a structure by PDB ID and return a path to a PDB-format file.

    CIF files are converted to PDB. Returns None if fetching fails.
    """
    pdb_path = fetch_structure(pdb_id)
    if not pdb_path:
        return None
    _, ext = os.path.splitext(pdb_path)
    if ext.lower() == '.cif':
        pdb_path = convert_cif_to_pdb(pdb_path)
    return pdb_path


def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:
    """
    Write a PDB file with only the selected chain and residues, replacing the
    B-factor of every atom of a scored residue.

    Args:
        input_pdb: path of a PDB-format structure file.
        chain_id: chain to keep.
        residue_scores: (residue number, normalised score) pairs.
        protein_residues: Biopython residues to keep (matched by residue number).

    Returns:
        Path of the written file, ``<input stem>_<chain_id>_predictions_scores.pdb``.
        Only the first model of the structure is written.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', input_pdb)
    output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"

    scores_dict = {resi: score for resi, score in residue_scores}
    selected_residues = {res.id[1] for res in protein_residues}

    class ResidueSelector(Select):
        """Keep one chain and the selected residues; rewrite scored B-factors."""

        def accept_chain(self, chain):
            return chain.id == chain_id

        def accept_residue(self, residue):
            return residue.id[1] in selected_residues

        def accept_atom(self, atom):
            resi = atom.get_parent().id[1]
            if resi in scores_dict:
                # NOTE(review): this writes (1 - score) * 100, i.e. the inverted
                # score, while the UI describes the file as having the "score in
                # beta factor column" -- confirm which is intended.
                atom.bfactor = np.absolute(1 - scores_dict[resi]) * 100
            return True

    io = PDBIO()
    io.set_structure(structure[0])
    io.save(output_pdb, ResidueSelector())
    return output_pdb


def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):
    """
    Return the mean CA coordinate of the given residues of chain ``chain_id``.

    All models in the file are scanned and residues without a CA atom are
    skipped. Returns None if no CA atom was found.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_path)
    wanted = set(high_score_residues)

    coords = []
    for struct_model in structure:
        for chain in struct_model:
            if chain.id != chain_id:
                continue
            for residue in chain:
                # Use the alpha carbon as the representative atom.
                if residue.id[1] in wanted and 'CA' in residue:
                    coords.append(residue['CA'].coord)

    if coords:
        return np.mean(coords, axis=0)
    return None


def _predict_scores(sequence: str) -> np.ndarray:
    """
    Run the model on a one-letter sequence and return per-residue scores.

    The score is sigmoid(logit[1] - logit[0]), min-max normalised over the
    residue positions only. The result has at most len(sequence) entries.
    """
    # Residues are passed space-separated, one token per residue.
    input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        logits = model(input_ids).logits.detach().cpu().numpy()
    # Flatten any batch dimension to (tokens, labels) -- robust even for a
    # one-residue chain -- and drop rows past the sequence (e.g. a trailing
    # special token added by the tokenizer) so they cannot shift min/max.
    logits = logits.reshape(-1, logits.shape[-1])[:len(sequence)]
    scores = expit(logits[:, 1] - logits[:, 0])
    return normalize_scores(scores)


def _pymol_commands(pdb_path: str, segment: str, residues_by_bracket: dict) -> str:
    """Build PyMOL commands that colour chain `segment` by score bracket."""
    lines = [
        "# PyMOL Visualization Commands",
        f"load {os.path.abspath(pdb_path)}, protein",
        "hide everything, all",
        f"show cartoon, chain {segment}",
        f"color white, chain {segment}",
    ]
    for bracket in SCORE_BRACKETS:
        residues = residues_by_bracket[bracket.label]
        if not residues:  # Only add commands if there are residues in this bracket
            continue
        selection = f"bracket_{bracket.label.replace('.', '').replace('-', '_')}"
        resi_list = '+'.join(map(str, residues))
        lines += [
            f"select {selection}, resi {resi_list} and chain {segment}",
            f"show sticks, {selection}",
            f"color {bracket.pymol_color}, {selection}",
        ]
    return "\n".join(lines) + "\n"


def process_pdb(pdb_id_or_file, segment):
    """
    Predict binding-site scores for one chain and build every UI output.

    Args:
        pdb_id_or_file: a path ending in ``.pdb`` (used as-is) or a PDB ID to fetch.
        segment: chain ID to analyse.

    Returns:
        (pymol_commands, viewer_html, [report_path, scored_pdb_path]) on success,
        or (error_message, None, None) on failure.
    """
    # Determine if input is a PDB ID or file path
    if pdb_id_or_file.lower().endswith('.pdb'):
        pdb_path = pdb_id_or_file
        pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]
    else:
        pdb_id = pdb_id_or_file
        pdb_path = fetch_pdb(pdb_id)
        if not pdb_path:
            return "Failed to fetch PDB file", None, None

    _, ext = os.path.splitext(pdb_path)
    parser = MMCIFParser(QUIET=True) if ext.lower() == '.cif' else PDBParser(QUIET=True)
    try:
        structure = parser.get_structure('protein', pdb_path)
    except Exception as e:
        return f"Error parsing structure file: {e}", None, None

    try:
        chain = structure[0][segment]
    except KeyError:
        return "Invalid Chain ID", None, None

    protein_residues = [res for res in chain if is_aa(res)]
    if not protein_residues:
        return f"No amino-acid residues found in chain {segment}", None, None
    sequence = "".join(seq1(res.resname) for res in protein_residues)

    normalized_scores = _predict_scores(sequence)
    # zip() stops at the shortest input, so residues without a model output
    # are left unscored instead of raising IndexError.
    scored = list(zip(protein_residues, sequence, normalized_scores))
    residue_scores = [(res.id[1], score) for res, _, score in scored]
    residues_by_bracket = _residues_by_bracket(residue_scores)

    report_rows = {bracket.label: [] for bracket in SCORE_BRACKETS}
    for res, one_letter, score in scored:
        label = _score_bracket(score)
        if label is not None:
            report_rows[label].append(f"{res.resname} {res.id[1]} {one_letter} {score:.2f}")

    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    header = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\n\n"

    report_parts = [header, "Residues by Score Brackets:\n\n"]
    for label, rows in report_rows.items():
        report_parts.append(f"Bracket {label}:\n")
        report_parts.append("Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\n")
        report_parts.append("\n".join(rows))
        report_parts.append("\n\n")
    result_str = "".join(report_parts)

    # Chain-specific PDB with scores in the B-factor column
    scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)
    mol_vis = molecule(pdb_path, residue_scores, segment)
    pymol_commands = header + _pymol_commands(pdb_path, segment, residues_by_bracket)

    prediction_file = f"{pdb_id}_binding_site_residues.txt"
    with open(prediction_file, "w", encoding="utf-8") as f:
        f.write(result_str)

    return pymol_commands, mol_vis, [prediction_file, scored_pdb]


def molecule(input_pdb, residue_scores=None, segment='A'):
    """
    Build an HTML iframe holding an interactive 3Dmol.js view of a structure.

    Args:
        input_pdb: path of a PDB-format file.
        residue_scores: optional (residue number, normalised score) pairs. When
            given, chain `segment` is drawn as a white cartoon and the scored
            residues as sticks coloured by SCORE_BRACKETS; otherwise the whole
            structure is drawn as a white cartoon.
        segment: chain ID to highlight.

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` holds the escaped page.
    """
    mol = read_mol(input_pdb)

    if residue_scores is None:
        view_script = 'viewer.addModel(pdb, "pdb").setStyle({}, {"cartoon": {"color": "white"}});'
    else:
        residues_by_bracket = _residues_by_bracket(residue_scores)
        chain_selection = json.dumps({"chain": segment})
        cartoon_style = json.dumps({"cartoon": {"color": "white"}})
        script_lines = [
            "// Load the original model and apply white cartoon style",
            'let chainModel = viewer.addModel(pdb, "pdb");',
            "chainModel.setStyle({}, {});",
            f"chainModel.setStyle({chain_selection}, {cartoon_style});",
        ]
        for number, bracket in enumerate(SCORE_BRACKETS, start=1):
            residues = residues_by_bracket[bracket.label]
            if not residues:
                continue
            # One extra model per bracket, presumably so each bracket can carry
            # its own stick opacity.
            name = f"class{number}Model"
            selection = json.dumps({"chain": segment, "resi": residues})
            style = json.dumps({"stick": {"color": bracket.stick_color,
                                          "opacity": bracket.stick_opacity}})
            script_lines += [
                f"// Score bracket {bracket.label}",
                f'let {name} = viewer.addModel(pdb, "pdb");',
                f"{name}.setStyle({{}}, {{}});",
                f"{name}.setStyle({selection}, {style});",
            ]
        view_script = "\n".join(script_lines)

    # The PDB text becomes a JS string literal; "</" is escaped so the data can
    # never terminate the enclosing <script> element.
    pdb_js = json.dumps(mol).replace("</", "<\\/")

    # Minimal page defining the `pdb` and `viewer` names used by view_script.
    html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<script src="https://3Dmol.org/build/3Dmol-min.js"></script>
<style>
html, body {{ margin: 0; width: 100%; height: 100%; }}
#container {{ width: 100%; height: 100%; position: relative; }}
</style>
</head>
<body>
<div id="container"></div>
<script>
let pdb = {pdb_js};
let viewer = $3Dmol.createViewer(document.getElementById("container"), {{backgroundColor: "white"}});
{view_script}
viewer.zoomTo();
viewer.render();
</script>
</body>
</html>
"""

    # Return the HTML content within an iframe safely encoded for special characters
    return (
        '<iframe style="width: 100%; height: 700px; border: none;" '
        f'srcdoc="{html.escape(html_content, quote=True)}"></iframe>'
    )


def _uploaded_path(pdb_file):
    """Return a local PDB-format path for an uploaded PDB/CIF file (CIF is converted)."""
    # Depending on the Gradio version the upload arrives as a tempfile-like
    # object exposing .name or as a path string.
    file_path = getattr(pdb_file, "name", pdb_file)
    _, ext = os.path.splitext(file_path)
    if ext.lower() == '.cif':
        return convert_cif_to_pdb(file_path)
    return file_path


def process_interface(mode, pdb_id, pdb_file, chain_id):
    """Gradio callback: run the prediction for a PDB ID or an uploaded file."""
    if mode == "PDB ID":
        pdb_id = (pdb_id or "").strip()
        # Validate here: process_pdb would treat an "ID" ending in .pdb as a
        # local server path.
        if not _is_valid_pdb_id(pdb_id):
            return "Error: Invalid PDB ID", None, None
        return process_pdb(pdb_id, chain_id)
    elif mode == "Upload File":
        if pdb_file is None:
            return "Error: Please upload a PDB or CIF file", None, None
        return process_pdb(_uploaded_path(pdb_file), chain_id)
    else:
        return "Error: Invalid mode selected", None, None


def fetch_interface(mode, pdb_id, pdb_file):
    """Gradio callback: return a PDB path for the Molecule3D viewer, or None."""
    if mode == "PDB ID":
        return fetch_pdb((pdb_id or "").strip())
    elif mode == "Upload File":
        if pdb_file is None:
            return None
        return _uploaded_path(pdb_file)
    else:
        # The viewer expects a file path, so no error string is returned here.
        return None


def toggle_mode(selected_mode):
    """Show the PDB-ID textbox or the file upload, depending on the mode."""
    use_id = selected_mode == "PDB ID"
    return gr.update(visible=use_id), gr.update(visible=not use_id)


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    # Mode selection
    mode = gr.Radio(
        choices=["PDB ID", "Upload File"],
        value="PDB ID",
        label="Input Mode",
        info="Choose whether to input a PDB ID or upload a PDB/CIF file."
    )

    # Input components based on mode
    pdb_input = gr.Textbox(value="4BDU", label="PDB ID", placeholder="Enter PDB ID here...")
    pdb_file = gr.File(label="Upload PDB/CIF File", visible=False)
    visualize_btn = gr.Button("Visualize Structure")

    molecule_output2 = Molecule3D(label="Protein Structure", reps=[
        {
            "model": 0,
            "style": "cartoon",
            "color": "whiteCarbon",
            "residue_range": "",
            "around": 0,
            "byres": False,
        }
    ])

    with gr.Row():
        segment_input = gr.Textbox(value="A", label="Chain ID", placeholder="Enter Chain ID here...")
        prediction_btn = gr.Button("Predict Binding Site")

    molecule_output = gr.HTML(label="Protein Structure")
    explanation_vis = gr.Markdown("""
Score dependent colorcoding:
- 0.0-0.2: white
- 0.2–0.4: light orange
- 0.4–0.6: orange
- 0.6–0.8: orangered
- 0.8–1.0: red
""")
    predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")

    gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
    download_output = gr.File(label="Download Files", file_count="multiple")

    mode.change(
        toggle_mode,
        inputs=[mode],
        outputs=[pdb_input, pdb_file]
    )

    prediction_btn.click(
        process_interface,
        inputs=[mode, pdb_input, pdb_file, segment_input],
        outputs=[predictions_output, molecule_output, download_output]
    )

    visualize_btn.click(
        fetch_interface,
        inputs=[mode, pdb_input, pdb_file],
        outputs=molecule_output2
    )

    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            ["7RPZ", "A"],
            ["2IWI", "B"],
            ["2F6V", "A"]
        ],
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output]
    )

demo.launch(share=True)