"""Gradio app that predicts per-residue binding-site scores for a protein chain.

A structure is fetched from RCSB (or uploaded), the chosen chain's sequence is
scored by the model returned from ``load_model``, and the scores are exported
as PyMOL commands, a text report and a PDB file with scores in the B-factor
column.
"""
import copy
import os
import re
import shutil
from datetime import datetime
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from Bio.PDB import MMCIFParser, PDBIO, PDBParser, Select
from Bio.PDB.Polypeptide import is_aa
from Bio.SeqUtils import seq1
from datasets import Dataset
from gradio_molecule3d import Molecule3D
from scipy.special import expit
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForTokenClassification

from model_loader import load_model

# Load model and move to device.
# Alternative checkpoints that were used previously:
#   'ThorbenF/prot_t5_xl_uniref50'
#   'ThorbenF/prot_t5_xl_uniref50_cryptic'
#   'ThorbenF/prot_t5_xl_uniref50_database'
checkpoint = 'ThorbenF/prot_t5_xl_uniref50_full'
max_length = 1500
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# Score brackets shared by the text report, the PyMOL script and the 3D viewer.
# Insertion order (lowest to highest) is relied upon below.
SCORE_BRACKETS = {
    "0.0-0.2": (0.0, 0.2),
    "0.2-0.4": (0.2, 0.4),
    "0.4-0.6": (0.4, 0.6),
    "0.6-0.8": (0.6, 0.8),
    "0.8-1.0": (0.8, 1.0),
}

# PyMOL colour name for each bracket.
BRACKET_PYMOL_COLORS = {
    "0.0-0.2": "white",
    "0.2-0.4": "lightorange",
    "0.4-0.6": "yelloworange",
    "0.6-0.8": "orange",
    "0.8-1.0": "red",
}

# Viewer stick style for each bracket: (JS variable name, colour, opacity).
BRACKET_STICK_STYLES = {
    "0.0-0.2": ("class1Model", "0xFFFFFF", 0.5),
    "0.2-0.4": ("class2Model", "0xFFD580", 0.7),
    "0.4-0.6": ("class3Model", "0xFFA500", 1),
    "0.6-0.8": ("class4Model", "0xFF4500", 1),
    "0.8-1.0": ("class5Model", "0xFF0000", 1),
}

# The PDB ID is interpolated into a local file path and a URL, so only plain
# identifier characters are accepted (no path separators, dots or whitespace).
_PDB_ID_PATTERN = re.compile(r"[A-Za-z0-9_]+")

_CHAIN_MODEL_JS = """
    // Load the original model and apply white cartoon style
    let chainModel = viewer.addModel(pdb, "pdb");
    chainModel.setStyle({}, {});
    chainModel.setStyle(
        {"chain": "%(chain)s"},
        {"cartoon": {"color": "white"}}
    );
"""

_STICK_MODEL_JS = """
    // Create a new model for high-scoring residues and apply red sticks style
    let %(model)s = viewer.addModel(pdb, "pdb");
    %(model)s.setStyle({}, {});
    %(model)s.setStyle(
        {"chain": "%(chain)s", "resi": [%(resi)s]},
        {"stick": {"color": "%(color)s", "opacity": %(opacity)s}}
    );
"""

APP_CSS = """
/* Customize Gradio button colors */
#visualize-btn, #predict-btn {
    background-color: #FF7300; /* Deep orange */
    color: white;
    border-radius: 5px;
    padding: 10px;
    font-weight: bold;
}
#visualize-btn:hover, #predict-btn:hover {
    background-color: #CC5C00; /* Darkened orange on hover */
}
"""


def normalize_scores(scores):
    """Min-max scale ``scores`` to [0, 1].

    If all values are equal the input is returned unchanged, which avoids a
    division by zero.
    """
    min_score = np.min(scores)
    max_score = np.max(scores)
    if max_score > min_score:
        return (scores - min_score) / (max_score - min_score)
    return scores


def read_mol(pdb_path):
    """Read PDB file and return its content as a string."""
    with open(pdb_path, 'r') as f:
        return f.read()


def fetch_structure(pdb_id: str, output_dir: str = ".") -> Optional[str]:
    """Return the local path of the structure file for ``pdb_id``.

    Delegates to :func:`download_structure`, which prefers CIF over PDB and
    reuses a file that already exists in ``output_dir``. Returns ``None`` when
    no file could be obtained.
    """
    return download_structure(pdb_id, output_dir)


def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:
    """Download the structure for ``pdb_id`` as CIF, falling back to PDB.

    A file already present in ``output_dir`` is reused without a network call.

    Returns:
        Path to the local file, or ``None`` if neither format returned
        HTTP 200.

    Raises:
        ValueError: if ``pdb_id`` contains characters other than letters,
            digits and underscores.
    """
    if not isinstance(pdb_id, str) or not _PDB_ID_PATTERN.fullmatch(pdb_id):
        raise ValueError(f"Invalid PDB ID: {pdb_id!r}")

    for ext in ['.cif', '.pdb']:
        file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
        if os.path.exists(file_path):
            return file_path
        url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            with open(file_path, 'wb') as f:
                f.write(response.content)
            return file_path
    return None


def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
    """Convert a CIF file to PDB format using BioPython and return the PDB path.

    The output keeps the input's base name and is written to ``output_dir``.
    """
    # splitext only swaps the real extension; str.replace would also rewrite
    # a '.cif' that happens to occur earlier in the file name.
    stem = os.path.splitext(os.path.basename(cif_path))[0]
    pdb_path = os.path.join(output_dir, f"{stem}.pdb")
    structure = MMCIFParser(QUIET=True).get_structure('protein', cif_path)
    io = PDBIO()
    io.set_structure(structure)
    io.save(pdb_path)
    return pdb_path


def fetch_pdb(pdb_id):
    """Return a local PDB-format file for ``pdb_id``, converting from CIF if needed.

    Raises:
        ValueError: if the ID is malformed or no structure could be downloaded.
    """
    pdb_path = fetch_structure(pdb_id)
    if pdb_path is None:
        # Previously this fell through to os.path.splitext(None) -> TypeError.
        raise ValueError(f"Could not download a structure for PDB ID {pdb_id!r}")
    if os.path.splitext(pdb_path)[1] == '.cif':
        pdb_path = convert_cif_to_pdb(pdb_path)
    return pdb_path


def _bracket_for(score):
    """Return the label of the bracket containing ``score``, or ``None``.

    Brackets are half-open ``[lower, upper)`` except the last one, whose upper
    edge is inclusive so that a score of exactly 1.0 (always present after
    min-max normalisation) is not dropped.
    """
    for label, (lower, upper) in SCORE_BRACKETS.items():
        if lower <= score < upper:
            return label
    top_label = list(SCORE_BRACKETS)[-1]
    if score == SCORE_BRACKETS[top_label][1]:
        return top_label
    return None


def group_residues_by_bracket(residue_scores):
    """Group residue numbers by score bracket.

    Args:
        residue_scores: iterable of ``(residue_number, score)`` pairs.

    Returns:
        Dict mapping every label in ``SCORE_BRACKETS`` to a list of residue
        numbers (possibly empty), in input order.
    """
    residues_by_bracket = {label: [] for label in SCORE_BRACKETS}
    for resi, score in residue_scores:
        label = _bracket_for(score)
        if label is not None:
            residues_by_bracket[label].append(resi)
    return residues_by_bracket


def _load_chain_residues(pdb_path, segment):
    """Parse ``pdb_path`` and return ``(amino-acid residues, one-letter sequence)``.

    Only the first model is read. The parser is chosen from the file extension.

    Raises:
        KeyError: if chain ``segment`` is not present in the first model.
        ValueError: if the chain contains no amino-acid residues.
    """
    ext = os.path.splitext(pdb_path)[1]
    parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
    first_model = parser.get_structure('protein', pdb_path)[0]
    if segment not in first_model:
        raise KeyError(f"Chain {segment!r} not found in {os.path.basename(pdb_path)}")
    protein_residues = [res for res in first_model[segment] if is_aa(res)]
    if not protein_residues:
        raise ValueError(f"Chain {segment!r} contains no amino-acid residues")
    sequence = "".join(seq1(res.resname) for res in protein_residues)
    return protein_residues, sequence


def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:
    """Write a PDB containing only the selected chain/residues, with scores as B-factors.

    Args:
        input_pdb: path of the source PDB file.
        chain_id: chain to keep.
        residue_scores: list of ``(residue_number, score)`` pairs.
        protein_residues: residues to keep (matched by residue number).

    Returns:
        Path of the written file, ``<input stem>_<chain_id>_predictions_scores.pdb``.
    """
    structure = PDBParser(QUIET=True).get_structure('protein', input_pdb)
    output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"

    # Create scores dictionary for easy lookup
    scores_dict = dict(residue_scores)

    class ResidueSelector(Select):
        """Keep one chain and a set of residue numbers; rewrite atom B-factors."""

        def __init__(self, chain_id, selected_residues, scores_dict):
            self.chain_id = chain_id
            self.selected_residues = set(selected_residues)
            self.scores_dict = scores_dict

        def accept_chain(self, chain):
            return chain.id == self.chain_id

        def accept_residue(self, residue):
            return residue.id[1] in self.selected_residues

        def accept_atom(self, atom):
            resi = atom.parent.id[1]
            if resi in self.scores_dict:
                # B-factor stores the inverted score scaled to 0-100.
                atom.bfactor = np.absolute(1 - self.scores_dict[resi]) * 100
            return True

    # Prepare output PDB with selected chain and residues, modified B-factors
    io = PDBIO()
    selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)
    io.set_structure(structure[0])
    io.save(output_pdb, selector)
    return output_pdb


def generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, score_type):
    """Generate PyMOL commands that colour the chain's residues by score bracket.

    Brackets without residues are skipped.
    """
    pymol_commands = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n"
    pymol_commands += f"""
# PyMOL Visualization Commands
fetch {pdb_id}, protein
hide everything, all
show cartoon, chain {segment}
color white, chain {segment}
"""

    # Add PyMOL commands for each score bracket
    for bracket, residues in residues_by_bracket.items():
        if not residues:
            continue
        color = BRACKET_PYMOL_COLORS[bracket]
        resi_list = '+'.join(map(str, residues))
        selection = f"bracket_{bracket.replace('.', '').replace('-', '_')}"
        pymol_commands += f"""
select {selection}, resi {resi_list} and chain {segment}
show sticks, {selection}
color {color}, {selection}
"""
    return pymol_commands


def generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence, scores, current_time, score_type):
    """Generate the plain-text report listing residues per score bracket.

    ``sequence`` and ``scores`` are indexed in parallel with ``protein_residues``.
    """
    result_str = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n"
    result_str += "Residues by Score Brackets:\n\n"

    # Add residues for each bracket
    for bracket, residues in residues_by_bracket.items():
        wanted = set(residues)  # O(1) membership instead of scanning the list
        result_str += f"Bracket {bracket}:\n"
        result_str += f"Columns: Residue Name, Residue Number, One-letter Code, {score_type} Score\n"
        result_str += "\n".join(
            f"{res.resname} {res.id[1]} {sequence[i]} {scores[i]:.2f}"
            for i, res in enumerate(protein_residues)
            if res.id[1] in wanted
        )
        result_str += "\n\n"
    return result_str


def _write_prediction_outputs(pdb_id, pdb_path, segment, protein_residues, sequence, residue_scores, score_type):
    """Build the PyMOL script and write the downloadable files.

    Args:
        residue_scores: list of ``(residue_number, score)`` pairs to report.
        score_type: ``'normalized'`` or anything else for raw scores.

    Returns:
        ``(pymol_commands, [report_path, scored_pdb_path])``; both files are
        written to the current working directory.
    """
    residues_by_bracket = group_residues_by_bracket(residue_scores)
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    display_score_type = "Normalized" if score_type == 'normalized' else "Raw"
    scores = [score for _, score in residue_scores]

    result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues,
                                       sequence, scores, current_time, display_score_type)
    pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket,
                                             current_time, display_score_type)

    # Create chain-specific PDB with scores in B-factor
    scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)

    prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
    with open(prediction_file, "w") as f:
        f.write(result_str)

    scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
    # shutil.move also works when the source sits on another filesystem (e.g.
    # an upload in a temp directory), where os.rename raises OSError.
    shutil.move(scored_pdb, scored_pdb_name)

    return pymol_commands, [prediction_file, scored_pdb_name]


def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
    """Run the model on one chain and produce all prediction outputs.

    Args:
        pdb_id_or_file: a path ending in ``.pdb`` or a PDB ID to download.
        segment: chain ID to score.
        score_type: ``'normalized'`` (default) or ``'raw'``; selects which
            scores are visualised and exported.

    Returns:
        ``(pymol_commands, viewer_html, [report_path, scored_pdb_path],
        raw_residue_scores, norm_residue_scores, pdb_id, segment)`` where the
        two score lists hold ``(residue_number, score)`` pairs.
    """
    # Determine if input is a PDB ID or file path
    if pdb_id_or_file.endswith('.pdb'):
        pdb_path = pdb_id_or_file
        pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]
    else:
        pdb_id = pdb_id_or_file
        pdb_path = fetch_pdb(pdb_id)

    protein_residues, sequence = _load_chain_residues(pdb_path, segment)
    sequence_id = [res.id[1] for res in protein_residues]

    # The tokenizer is fed the sequence as space-separated residues.
    input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()

    # Sigmoid of the logit difference between column 1 and column 0.
    raw_scores = expit(outputs[:, 1] - outputs[:, 0])
    normalized_scores = normalize_scores(raw_scores)

    # zip() stops at the shorter input, so any extra model positions beyond the
    # residue list are ignored (presumably special tokens - TODO confirm).
    raw_residue_scores = list(zip(sequence_id, raw_scores))
    norm_residue_scores = list(zip(sequence_id, normalized_scores))
    residue_scores = norm_residue_scores if score_type == 'normalized' else raw_residue_scores

    pymol_commands, files = _write_prediction_outputs(
        pdb_id, pdb_path, segment, protein_residues, sequence, residue_scores, score_type)

    # Molecule visualization with updated script with color mapping
    mol_vis = molecule(pdb_path, residue_scores, segment)

    return pymol_commands, mol_vis, files, raw_residue_scores, norm_residue_scores, pdb_id, segment


def molecule(input_pdb, residue_scores=None, segment='A'):
    """Build the viewer markup for ``input_pdb`` with residues styled by score.

    Args:
        input_pdb: path of the PDB file to display.
        residue_scores: optional list of ``(residue_number, score)`` pairs.
        segment: chain ID to style.
    """
    # Read PDB file content
    mol = read_mol(input_pdb)

    # Prepare high-scoring residues script if scores are provided
    high_score_script = ""
    if residue_scores is not None:
        # Same bracket boundaries as the text report and the PyMOL script.
        residues_by_bracket = group_residues_by_bracket(residue_scores)
        snippets = [_CHAIN_MODEL_JS % {"chain": segment}]
        for bracket, (model_name, color, opacity) in BRACKET_STICK_STYLES.items():
            snippets.append(_STICK_MODEL_JS % {
                "model": model_name,
                "chain": segment,
                "resi": ", ".join(str(resi) for resi in residues_by_bracket[bracket]),
                "color": color,
                "opacity": opacity,
            })
        high_score_script = "".join(snippets)

    # NOTE(review): the HTML template and the iframe wrapper below are empty in
    # this copy of the file, so `mol` and `high_score_script` are never
    # embedded and the function returns an empty string. The markup appears to
    # have been lost; restore it from version control before relying on the
    # viewer output.
    # Generate the full HTML content
    html_content = f"""
"""
    # Return the HTML content within an iframe safely encoded for special characters
    return f''


def _resolve_structure_path(mode, pdb_id, pdb_file):
    """Return a local PDB-format path for the current input mode.

    Raises:
        ValueError: for an unknown mode, a missing upload, or a failed download.
    """
    if mode == "PDB ID":
        return fetch_pdb(pdb_id)
    if mode == "Upload File":
        if pdb_file is None:
            raise ValueError("Please upload a PDB or CIF file first.")
        # Accept both an object exposing `.name` and a plain path string.
        file_path = getattr(pdb_file, "name", pdb_file)
        if os.path.splitext(file_path)[1] == '.cif':
            return convert_cif_to_pdb(file_path)
        return file_path
    raise ValueError(f"Unknown input mode: {mode!r}")


def _as_ui_error(err):
    """Wrap an input-related exception in ``gr.Error`` so the UI shows its message."""
    message = err.args[0] if err.args else str(err)
    return gr.Error(str(message))


with gr.Blocks(css=APP_CSS) as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    # Mode selection
    mode = gr.Radio(
        choices=["PDB ID", "Upload File"],
        value="PDB ID",
        label="Input Mode",
        info="Choose whether to input a PDB ID or upload a PDB/CIF file."
    )

    # Input components based on mode
    pdb_input = gr.Textbox(value="2F6V", label="PDB ID", placeholder="Enter PDB ID here...")
    pdb_file = gr.File(label="Upload PDB/CIF File", visible=False)
    visualize_btn = gr.Button("Visualize Structure", elem_id="visualize-btn")

    molecule_output2 = Molecule3D(label="Protein Structure", reps=[
        {
            "model": 0,
            "style": "cartoon",
            "color": "whiteCarbon",
            "residue_range": "",
            "around": 0,
            "byres": False,
        }
    ])

    with gr.Row():
        segment_input = gr.Textbox(value="A", label="Chain ID (protein)", placeholder="Enter Chain ID here...",
                                   info="Choose in which chain to predict binding sites.")
        prediction_btn = gr.Button("Predict Binding Site", elem_id="predict-btn")

    # Add score type selector
    score_type = gr.Radio(
        choices=["Normalized Scores", "Raw Scores"],
        value="Normalized Scores",
        label="Score Visualization Type",
        info="Choose which score type to visualize"
    )

    molecule_output = gr.HTML(label="Protein Structure")
    explanation_vis = gr.Markdown("""
    Score dependent colorcoding:
    - 0.0-0.2: white
    - 0.2–0.4: light orange
    - 0.4–0.6: yellow orange
    - 0.6–0.8: orange
    - 0.8–1.0: red
    """)
    predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
    gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
    download_output = gr.File(label="Download Files", file_count="multiple")

    # Store these as state variables so we can switch between them
    raw_scores_state = gr.State(None)
    norm_scores_state = gr.State(None)
    last_pdb_path = gr.State(None)
    last_segment = gr.State(None)
    last_pdb_id = gr.State(None)

    def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
        """Resolve the input structure, run the prediction and fill outputs and state."""
        selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
        try:
            # Resolve to an actual file path first so the state stores the
            # path, not just the PDB ID.
            pdb_path = _resolve_structure_path(mode, pdb_id, pdb_file)
            pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, _segment = process_pdb(
                pdb_path, chain_id, selected_score_type)
        except (KeyError, ValueError) as err:
            raise _as_ui_error(err) from err
        return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result

    def update_visualization_and_files(score_type_val, raw_scores, norm_scores, pdb_path, segment, pdb_id):
        """Rebuild viewer, PyMOL script and files from stored scores when the score type changes."""
        stored = (raw_scores, norm_scores, pdb_path, segment, pdb_id)
        if any(value is None for value in stored):
            # No prediction has been run yet.
            return None, None, None

        # Choose scores based on radio button selection
        selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
        selected_scores = norm_scores if selected_score_type == 'normalized' else raw_scores

        # Generate visualization with selected scores
        mol_vis = molecule(pdb_path, selected_scores, segment)

        try:
            # Get structure for residue info
            protein_residues, sequence = _load_chain_residues(pdb_path, segment)
        except (KeyError, ValueError) as err:
            raise _as_ui_error(err) from err

        # Generate PyMOL commands and downloadable files
        pymol_commands, files = _write_prediction_outputs(
            pdb_id, pdb_path, segment, protein_residues, sequence, selected_scores, selected_score_type)

        return mol_vis, pymol_commands, files

    def fetch_interface(mode, pdb_id, pdb_file):
        """Return the PDB file path to show in the structure viewer."""
        try:
            return _resolve_structure_path(mode, pdb_id, pdb_file)
        except ValueError as err:
            raise _as_ui_error(err) from err

    def toggle_mode(selected_mode):
        """Show the PDB ID box or the upload widget depending on the input mode."""
        show_id = selected_mode == "PDB ID"
        return gr.update(visible=show_id), gr.update(visible=not show_id)

    mode.change(
        toggle_mode,
        inputs=[mode],
        outputs=[pdb_input, pdb_file]
    )

    prediction_btn.click(
        process_interface,
        inputs=[mode, pdb_input, pdb_file, segment_input, score_type],
        outputs=[predictions_output, molecule_output, download_output,
                 raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id]
    )

    # Update visualization, PyMOL commands, and files when score type changes
    score_type.change(
        update_visualization_and_files,
        inputs=[score_type, raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id],
        outputs=[molecule_output, predictions_output, download_output]
    )

    visualize_btn.click(
        fetch_interface,
        inputs=[mode, pdb_input, pdb_file],
        outputs=molecule_output2
    )

    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            ["7RPZ", "A"],
            ["2IWI", "B"],
            ["7LCJ", "R"],
            ["4OBE", "A"]
        ],
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output]
    )

demo.launch(share=True)