diff --git "a/.ipynb_checkpoints/test4-checkpoint.ipynb" "b/.ipynb_checkpoints/test4-checkpoint.ipynb" new file mode 100644--- /dev/null +++ "b/.ipynb_checkpoints/test4-checkpoint.ipynb" @@ -0,0 +1,2830 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 15, + "id": "2c614d31-f96a-4164-8293-1cec9b0b2cd0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7870\n", + "* Running on public URL: https://c83ac6a1bebcdb4528.gradio.live\n", + "\n", + "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from datetime import datetime\n", + "import gradio as gr\n", + "import requests\n", + "from Bio.PDB import PDBParser, MMCIFParser, PDBIO\n", + "from Bio.PDB.Polypeptide import is_aa\n", + "from Bio.SeqUtils import seq1\n", + "from typing import Optional, Tuple\n", + "import numpy as np\n", + "import os\n", + "from gradio_molecule3d import Molecule3D\n", + "\n", + "import re\n", + "import pandas as pd\n", + "import copy\n", + "\n", + "from scipy.special import expit\n", + "\n", + "\n", + "\n", + "def normalize_scores(scores):\n", + " min_score = np.min(scores)\n", + " max_score = np.max(scores)\n", + " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n", + "\n", + "def read_mol(pdb_path):\n", + " \"\"\"Read PDB file and return its content as a string\"\"\"\n", + " with open(pdb_path, 'r') as f:\n", + " return f.read()\n", + "\n", + "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n", + " \"\"\"\n", + " Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n", + " If a structure file already exists locally, it uses that.\n", + " \"\"\"\n", + " file_path = download_structure(pdb_id, output_dir)\n", + " if file_path:\n", + " return file_path\n", + " else:\n", + " return None\n", + "\n", + "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n", + " \"\"\"\n", + " Attempt to download the structure file in CIF or PDB format.\n", + " Returns the path to the downloaded file, or None if download fails.\n", + " \"\"\"\n", + " for ext in ['.cif', '.pdb']:\n", + " file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n", + " if os.path.exists(file_path):\n", + " return file_path\n", + " url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n", + " try:\n", + " response = requests.get(url, timeout=10)\n", + " if response.status_code == 200:\n", + " with open(file_path, 'wb') as f:\n", + " f.write(response.content)\n", + " return file_path\n", + " except Exception as e:\n", + " print(f\"Download error for {pdb_id}{ext}: {e}\")\n", + " return None\n", + "\n", + "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n", + " \"\"\"\n", + " Convert a CIF file to PDB format using BioPython and return the PDB file path.\n", + " \"\"\"\n", + " pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n", + " parser = MMCIFParser(QUIET=True)\n", + " structure = parser.get_structure('protein', cif_path)\n", + " io = PDBIO()\n", + " io.set_structure(structure)\n", + " io.save(pdb_path)\n", + " return pdb_path\n", + "\n", + "def fetch_pdb(pdb_id):\n", + " pdb_path = fetch_structure(pdb_id)\n", + " if not pdb_path:\n", + " return None\n", + " _, ext = os.path.splitext(pdb_path)\n", + " if ext == '.cif':\n", + " pdb_path = convert_cif_to_pdb(pdb_path)\n", + " return pdb_path\n", + "\n", + "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list) -> str:\n", + " \"\"\"\n", + " Create a PDB file with only the specified chain and replace B-factor with prediction scores\n", + " \"\"\"\n", + " # Read the original PDB file\n", + " parser = PDBParser(QUIET=True)\n", + " structure = parser.get_structure('protein', input_pdb)\n", + " \n", + " # Prepare a new structure with only the specified chain\n", + " new_structure = structure.copy()\n", + " for model in new_structure:\n", + " # Remove all chains except the specified one\n", + " chains_to_remove = [chain for chain in model if chain.id != chain_id]\n", + " for chain in chains_to_remove:\n", + " model.detach_child(chain.id)\n", + " \n", + " # Create a modified PDB with scores in B-factor\n", + " scores_dict = {resi: score for resi, score in residue_scores}\n", + " for model in new_structure:\n", + " for chain in model:\n", + " for residue in chain:\n", + " if residue.id[1] in scores_dict:\n", + " for atom in residue:\n", + " atom.bfactor = scores_dict[residue.id[1]] #* 100 # Scale score to B-factor range\n", + " \n", + " # Save the modified structure\n", + " output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_scored.pdb\"\n", + " io = PDBIO()\n", + " io.set_structure(new_structure)\n", + " io.save(output_pdb)\n", + " \n", + " return output_pdb\n", + "\n", + "def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):\n", + " \"\"\"\n", + " Calculate the geometric center of high-scoring residues\n", + " \"\"\"\n", + " parser = PDBParser(QUIET=True)\n", + " structure = parser.get_structure('protein', pdb_path)\n", + " \n", + " # Collect coordinates of CA atoms from high-scoring residues\n", + " coords = []\n", + " for model in structure:\n", + " for chain in model:\n", + " if chain.id == chain_id:\n", + " for residue in chain:\n", + " if residue.id[1] in high_score_residues:\n", + " if 'CA' in residue: # Use alpha carbon as representative\n", + " ca_atom = residue['CA']\n", + " coords.append(ca_atom.coord)\n", + " \n", + " # Calculate geometric center\n", + " if coords:\n", + " center = np.mean(coords, axis=0)\n", + " return center\n", + " return None\n", + "\n", + "\n", + "\n", + "def process_pdb(pdb_id_or_file, segment):\n", + " # Determine if input is a PDB ID or file path\n", + " if pdb_id_or_file.endswith('.pdb'):\n", + " pdb_path = pdb_id_or_file\n", + " pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n", + " else:\n", + " pdb_id = pdb_id_or_file\n", + " pdb_path = fetch_pdb(pdb_id)\n", + " \n", + " if not pdb_path:\n", + " return \"Failed to fetch PDB file\", None, None\n", + " \n", + " # Determine the file format and choose the appropriate parser\n", + " _, ext = os.path.splitext(pdb_path)\n", + " parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n", + " \n", + " try:\n", + " # Parse the structure file\n", + " structure = parser.get_structure('protein', pdb_path)\n", + " except Exception as e:\n", + " return f\"Error parsing structure file: {e}\", None, None\n", + " \n", + " # Extract the specified chain\n", + " try:\n", + " chain = structure[0][segment]\n", + " except KeyError:\n", + " return \"Invalid Chain ID\", None, None\n", + " \n", + " protein_residues = [res for res in chain if is_aa(res)]\n", + " sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n", + " sequence_id = [res.id[1] for res in protein_residues]\n", + " \n", + " scores = np.random.rand(len(sequence))\n", + " normalized_scores = normalize_scores(scores)\n", + " \n", + " # Zip residues with scores to track the residue ID and score\n", + " residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n", + "\n", + " # Identify high and mid scoring residues\n", + " high_score_residues = [resi for resi, score in residue_scores if score > 0.75]\n", + " mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]\n", + "\n", + " # Calculate geometric center of high-scoring residues\n", + " geo_center = calculate_geometric_center(pdb_path, high_score_residues, segment)\n", + " pymol_selection = f\"select high_score_residues, resi {'+'.join(map(str, high_score_residues))} and chain {segment}\"\n", + " pymol_center_cmd = f\"show spheres, resi {'+'.join(map(str, high_score_residues))} and chain {segment}\" if geo_center is not None else \"\"\n", + "\n", + " # Generate the result string\n", + " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + " result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n", + " result_str += \"Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\\n\\n\"\n", + " result_str += \"\\n\".join([\n", + " f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n", + " for i, res in enumerate(protein_residues)])\n", + " \n", + " # Create prediction and scored PDB files\n", + " prediction_file = f\"{pdb_id}_predictions.txt\"\n", + " with open(prediction_file, \"w\") as f:\n", + " f.write(result_str)\n", + "\n", + " # Create chain-specific PDB with scores in B-factor\n", + " scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores)\n", + "\n", + " # Molecule visualization with updated script\n", + " mol_vis = molecule(pdb_path, residue_scores, segment)\n", + "\n", + " # Construct PyMOL command suggestions\n", + " pymol_commands = f\"\"\"\n", + "PyMOL Visualization Commands:\n", + "1. Load PDB: load {os.path.abspath(pdb_path)}\n", + "2. Select high-scoring residues: {pymol_selection}\n", + "3. Highlight high-scoring residues: show sticks, high_score_residues\n", + "{pymol_center_cmd}\n", + "\"\"\"\n", + " \n", + " return result_str + \"\\n\\n\" + pymol_commands, mol_vis, [prediction_file, scored_pdb]\n", + "\n", + "\n", + "def molecule(input_pdb, residue_scores=None, segment='A'):\n", + " mol = read_mol(input_pdb) # Read PDB file content\n", + "\n", + " # Prepare high-scoring residues script if scores are provided\n", + " high_score_script = \"\"\n", + " if residue_scores is not None:\n", + " # Filter residues based on their scores\n", + " high_score_residues = [resi for resi, score in residue_scores if score > 0.75]\n", + " mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]\n", + " \n", + " high_score_script = \"\"\"\n", + " // Load the original model and apply white cartoon style\n", + " let chainModel = viewer.addModel(pdb, \"pdb\");\n", + " chainModel.setStyle({}, {});\n", + " chainModel.setStyle(\n", + " {\"chain\": \"%s\"}, \n", + " {\"cartoon\": {\"color\": \"white\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let highScoreModel = viewer.addModel(pdb, \"pdb\");\n", + " highScoreModel.setStyle({}, {});\n", + " highScoreModel.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"red\"}}\n", + " );\n", + "\n", + " // Create a new model for medium-scoring residues and apply orange sticks style\n", + " let midScoreModel = viewer.addModel(pdb, \"pdb\");\n", + " midScoreModel.setStyle({}, {});\n", + " midScoreModel.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"orange\"}}\n", + " );\n", + " \"\"\" % (\n", + " segment,\n", + " segment,\n", + " \", \".join(str(resi) for resi in high_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in mid_score_residues)\n", + " )\n", + " \n", + " # Generate the full HTML content\n", + " html_content = f\"\"\"\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \"\"\"\n", + " \n", + " # Return the HTML content within an iframe safely encoded for special characters\n", + " return f''\n", + "\n", + "\n", + "# Gradio UI\n", + "with gr.Blocks() as demo:\n", + " gr.Markdown(\"# Protein Binding Site Prediction\")\n", + " \n", + " with gr.Row():\n", + " pdb_input = gr.Textbox(value=\"4BDU\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", + " visualize_btn = gr.Button(\"Visualize Structure\")\n", + "\n", + " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n", + " {\n", + " \"model\": 0,\n", + " \"style\": \"cartoon\",\n", + " \"color\": \"whiteCarbon\",\n", + " \"residue_range\": \"\",\n", + " \"around\": 0,\n", + " \"byres\": False,\n", + " }\n", + " ])\n", + "\n", + " with gr.Row():\n", + " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n", + " prediction_btn = gr.Button(\"Predict Binding Site\")\n", + "\n", + "\n", + " molecule_output = gr.HTML(label=\"Protein Structure\")\n", + " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n", + " download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n", + " \n", + " prediction_btn.click(\n", + " process_pdb, \n", + " inputs=[\n", + " pdb_input, \n", + " segment_input\n", + " ], \n", + " outputs=[predictions_output, molecule_output, download_output]\n", + " )\n", + "\n", + " visualize_btn.click(\n", + " fetch_pdb, \n", + " inputs=[pdb_input], \n", + " outputs=molecule_output2\n", + " )\n", + "\n", + " gr.Markdown(\"## Examples\")\n", + " gr.Examples(\n", + " examples=[\n", + " [\"7RPZ\", \"A\"],\n", + " [\"2IWI\", \"B\"],\n", + " [\"2F6V\", \"A\"]\n", + " ],\n", + " inputs=[pdb_input, segment_input],\n", + " outputs=[predictions_output, molecule_output, download_output]\n", + " )\n", + "\n", + "demo.launch(share=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "db0b4763-5368-4d73-b5f6-d1c168f7fcd8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7864\n", + "* Running on public URL: https://060e61e5b829d9fb6e.gradio.live\n", + "\n", + "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from datetime import datetime\n", + "import gradio as gr\n", + "import requests\n", + "from Bio.PDB import PDBParser, MMCIFParser, PDBIO\n", + "from Bio.PDB.Polypeptide import is_aa\n", + "from Bio.SeqUtils import seq1\n", + "from Bio.PDB import Select\n", + "from typing import Optional, Tuple\n", + "import numpy as np\n", + "import os\n", + "from gradio_molecule3d import Molecule3D\n", + "\n", + "import re\n", + "import pandas as pd\n", + "import copy\n", + "\n", + "from scipy.special import expit\n", + "\n", + "def normalize_scores(scores):\n", + " min_score = np.min(scores)\n", + " max_score = np.max(scores)\n", + " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n", + "\n", + "def read_mol(pdb_path):\n", + " \"\"\"Read PDB file and return its content as a string\"\"\"\n", + " with open(pdb_path, 'r') as f:\n", + " return f.read()\n", + "\n", + "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n", + " \"\"\"\n", + " Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n", + " If a structure file already exists locally, it uses that.\n", + " \"\"\"\n", + " file_path = download_structure(pdb_id, output_dir)\n", + " if file_path:\n", + " return file_path\n", + " else:\n", + " return None\n", + "\n", + "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n", + " \"\"\"\n", + " Attempt to download the structure file in CIF or PDB format.\n", + " Returns the path to the downloaded file, or None if download fails.\n", + " \"\"\"\n", + " for ext in ['.cif', '.pdb']:\n", + " file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n", + " if os.path.exists(file_path):\n", + " return file_path\n", + " url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n", + " try:\n", + " response = requests.get(url, timeout=10)\n", + " if response.status_code == 200:\n", + " with open(file_path, 'wb') as f:\n", + " f.write(response.content)\n", + " return file_path\n", + " except Exception as e:\n", + " print(f\"Download error for {pdb_id}{ext}: {e}\")\n", + " return None\n", + "\n", + "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n", + " \"\"\"\n", + " Convert a CIF file to PDB format using BioPython and return the PDB file path.\n", + " \"\"\"\n", + " pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n", + " parser = MMCIFParser(QUIET=True)\n", + " structure = parser.get_structure('protein', cif_path)\n", + " io = PDBIO()\n", + " io.set_structure(structure)\n", + " io.save(pdb_path)\n", + " return pdb_path\n", + "\n", + "def fetch_pdb(pdb_id):\n", + " pdb_path = fetch_structure(pdb_id)\n", + " if not pdb_path:\n", + " return None\n", + " _, ext = os.path.splitext(pdb_path)\n", + " if ext == '.cif':\n", + " pdb_path = convert_cif_to_pdb(pdb_path)\n", + " return pdb_path\n", + "\n", + "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n", + " \"\"\"\n", + " Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores\n", + " \"\"\"\n", + " # Read the original PDB file\n", + " parser = PDBParser(QUIET=True)\n", + " structure = parser.get_structure('protein', input_pdb)\n", + " \n", + " # Prepare a new structure with only the specified chain and selected residues\n", + " output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_scored.pdb\"\n", + " \n", + " # Create scores dictionary for easy lookup\n", + " scores_dict = {resi: score for resi, score in residue_scores}\n", + "\n", + " # Create a custom Select class\n", + " class ResidueSelector(Select):\n", + " def __init__(self, chain_id, selected_residues, scores_dict):\n", + " self.chain_id = chain_id\n", + " self.selected_residues = selected_residues\n", + " self.scores_dict = scores_dict\n", + " \n", + " def accept_chain(self, chain):\n", + " return chain.id == self.chain_id\n", + " \n", + " def accept_residue(self, residue):\n", + " return residue.id[1] in self.selected_residues\n", + "\n", + " def accept_atom(self, atom):\n", + " if atom.parent.id[1] in self.scores_dict:\n", + " atom.bfactor = self.scores_dict[atom.parent.id[1]] * 100\n", + " return True\n", + "\n", + " # Prepare output PDB with selected chain and residues, modified B-factors\n", + " io = PDBIO()\n", + " selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n", + " \n", + " io.set_structure(structure[0])\n", + " io.save(output_pdb, selector)\n", + " \n", + " return output_pdb\n", + "\n", + "def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):\n", + " \"\"\"\n", + " Calculate the geometric center of high-scoring residues\n", + " \"\"\"\n", + " parser = PDBParser(QUIET=True)\n", + " structure = parser.get_structure('protein', pdb_path)\n", + " \n", + " # Collect coordinates of CA atoms from high-scoring residues\n", + " coords = []\n", + " for model in structure:\n", + " for chain in model:\n", + " if chain.id == chain_id:\n", + " for residue in chain:\n", + " if residue.id[1] in high_score_residues:\n", + " if 'CA' in residue: # Use alpha carbon as representative\n", + " ca_atom = residue['CA']\n", + " coords.append(ca_atom.coord)\n", + " \n", + " # Calculate geometric center\n", + " if coords:\n", + " center = np.mean(coords, axis=0)\n", + " return center\n", + " return None\n", + "\n", + "def process_pdb(pdb_id_or_file, segment):\n", + " # Determine if input is a PDB ID or file path\n", + " if pdb_id_or_file.endswith('.pdb'):\n", + " pdb_path = pdb_id_or_file\n", + " pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n", + " else:\n", + " pdb_id = pdb_id_or_file\n", + " pdb_path = fetch_pdb(pdb_id)\n", + " \n", + " if not pdb_path:\n", + " return \"Failed to fetch PDB file\", None, None\n", + " \n", + " # Determine the file format and choose the appropriate parser\n", + " _, ext = os.path.splitext(pdb_path)\n", + " parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n", + " \n", + " try:\n", + " # Parse the structure file\n", + " structure = parser.get_structure('protein', pdb_path)\n", + " except Exception as e:\n", + " return f\"Error parsing structure file: {e}\", None, None\n", + " \n", + " # Extract the specified chain\n", + " try:\n", + " chain = structure[0][segment]\n", + " except KeyError:\n", + " return \"Invalid Chain ID\", None, None\n", + " \n", + " protein_residues = [res for res in chain if is_aa(res)]\n", + " sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n", + " sequence_id = [res.id[1] for res in protein_residues]\n", + " \n", + " scores = np.random.rand(len(sequence))\n", + " normalized_scores = normalize_scores(scores)\n", + " \n", + " # Zip residues with scores to track the residue ID and score\n", + " residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n", + "\n", + " # More granular scoring for visualization\n", + " def score_to_color(score):\n", + " if score <= 0.6:\n", + " return \"blue\"\n", + " elif score <= 0.7:\n", + " return \"lightblue\"\n", + " elif score <= 0.8:\n", + " return \"white\"\n", + " elif score <= 0.9:\n", + " return \"orange\"\n", + " elif score > 0.9:\n", + " return \"red\"\n", + "\n", + " color_map = {resi: score_to_color(score) for resi, score in residue_scores}\n", + " \n", + " # Identify high scoring residues (> 0.7)\n", + " high_score_residues = [resi for resi, score in residue_scores if score > 0.7]\n", + " mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.7]\n", + "\n", + " # Calculate geometric center of high-scoring residues\n", + " geo_center = calculate_geometric_center(pdb_path, high_score_residues, segment)\n", + "\n", + " # Preparing the result: only print high scoring residues\n", + " result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\n\"\n", + " result_str += \"High-scoring Residues (Score > 0.7):\\n\"\n", + " result_str += \"\\n\".join([\n", + " f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n", + " for i, res in enumerate(protein_residues) if res.id[1] in high_score_residues\n", + " ])\n", + " \n", + " # Create chain-specific PDB with scores in B-factor\n", + " scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n", + "\n", + " # Molecule visualization with updated script with color mapping\n", + " mol_vis = molecule(pdb_path, residue_scores, segment) #, color_map)\n", + "\n", + " # Improved PyMOL command suggestions\n", + " pymol_commands = f\"\"\"\n", + "# PyMOL Visualization Commands\n", + "load {os.path.abspath(pdb_path)}, protein\n", + "hide everything, all\n", + "show cartoon, chain {segment}\n", + "color white, chain {segment}\n", + "\"\"\"\n", + " \n", + " # Color specific residues\n", + " for score_range, color in [\n", + " (high_score_residues, \"red\"), \n", + " (mid_score_residues, \"orange\")\n", + " ]:\n", + " if score_range:\n", + " resi_list = '+'.join(map(str, score_range))\n", + " pymol_commands += f\"\"\"\n", + "select high_score_residues, resi {resi_list} and chain {segment}\n", + "show sticks, high_score_residues\n", + "color {color}, high_score_residues\n", + "\"\"\"\n", + " \n", + " return result_str, mol_vis, [scored_pdb]\n", + "\n", + "def molecule(input_pdb, residue_scores=None, segment='A'):\n", + " mol = read_mol(input_pdb) # Read PDB file content\n", + "\n", + " # Prepare high-scoring residues script if scores are provided\n", + " high_score_script = \"\"\n", + " if residue_scores is not None:\n", + " # Filter residues based on their scores\n", + " high_score_residues = [resi for resi, score in residue_scores if score > 0.75]\n", + " mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]\n", + " \n", + " high_score_script = \"\"\"\n", + " // Load the original model and apply white cartoon style\n", + " let chainModel = viewer.addModel(pdb, \"pdb\");\n", + " chainModel.setStyle({}, {});\n", + " chainModel.setStyle(\n", + " {\"chain\": \"%s\"}, \n", + " {\"cartoon\": {\"color\": \"white\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let highScoreModel = viewer.addModel(pdb, \"pdb\");\n", + " highScoreModel.setStyle({}, {});\n", + " highScoreModel.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"red\"}}\n", + " );\n", + "\n", + " // Create a new model for medium-scoring residues and apply orange sticks style\n", + " let midScoreModel = viewer.addModel(pdb, \"pdb\");\n", + " midScoreModel.setStyle({}, {});\n", + " midScoreModel.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"orange\"}}\n", + " );\n", + " \"\"\" % (\n", + " segment,\n", + " segment,\n", + " \", \".join(str(resi) for resi in high_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in mid_score_residues)\n", + " )\n", + " \n", + " # Generate the full HTML content\n", + " html_content = f\"\"\"\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \"\"\"\n", + " \n", + " # Return the HTML content within an iframe safely encoded for special characters\n", + " return f''\n", + "\n", + "# Gradio UI\n", + "with gr.Blocks() as demo:\n", + " gr.Markdown(\"# Protein Binding Site Prediction\")\n", + " \n", + " with gr.Row():\n", + " pdb_input = gr.Textbox(value=\"4BDU\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", + " visualize_btn = gr.Button(\"Visualize Structure\")\n", + "\n", + " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n", + " {\n", + " \"model\": 0,\n", + " \"style\": \"cartoon\",\n", + " \"color\": \"whiteCarbon\",\n", + " \"residue_range\": \"\",\n", + " \"around\": 0,\n", + " \"byres\": False,\n", + " }\n", + " ])\n", + "\n", + " with gr.Row():\n", + " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n", + " prediction_btn = gr.Button(\"Predict Binding Site\")\n", + "\n", + " molecule_output = gr.HTML(label=\"Protein Structure\")\n", + " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n", + " download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n", + " \n", + " prediction_btn.click(\n", + " process_pdb, \n", + " inputs=[\n", + " pdb_input, \n", + " segment_input\n", + " ], \n", + " outputs=[predictions_output, molecule_output, download_output]\n", + " )\n", + "\n", + " visualize_btn.click(\n", + " fetch_pdb, \n", + " inputs=[pdb_input], \n", + " outputs=molecule_output2\n", + " )\n", + "\n", + " gr.Markdown(\"## Examples\")\n", + " gr.Examples(\n", + " examples=[\n", + " [\"7RPZ\", \"A\"],\n", + " [\"2IWI\", \"B\"],\n", + " [\"2F6V\", \"A\"]\n", + " ],\n", + " inputs=[pdb_input, segment_input],\n", + " outputs=[predictions_output, molecule_output, download_output]\n", + " )\n", + "\n", + "demo.launch(share=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d0d50415-1304-462d-a176-b58f394e79b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7866\n", + "* Running on public URL: https://a9ff499df0a5f7be8c.gradio.live\n", + "\n", + "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from datetime import datetime\n", + "import gradio as gr\n", + "import requests\n", + "from Bio.PDB import PDBParser, MMCIFParser, PDBIO\n", + "from Bio.PDB.Polypeptide import is_aa\n", + "from Bio.SeqUtils import seq1\n", + "from Bio.PDB import Select\n", + "from typing import Optional, Tuple\n", + "import numpy as np\n", + "import os\n", + "from gradio_molecule3d import Molecule3D\n", + "\n", + "import re\n", + "import pandas as pd\n", + "import copy\n", + "\n", + "from scipy.special import expit\n", + "\n", + "def normalize_scores(scores):\n", + " min_score = np.min(scores)\n", + " max_score = np.max(scores)\n", + " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n", + "\n", + "def read_mol(pdb_path):\n", + " \"\"\"Read PDB file and return its content as a string\"\"\"\n", + " with open(pdb_path, 'r') as f:\n", + " return f.read()\n", + "\n", + "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n", + " \"\"\"\n", + " Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n", + " If a structure file already exists locally, it uses that.\n", + " \"\"\"\n", + " file_path = download_structure(pdb_id, output_dir)\n", + " if file_path:\n", + " return file_path\n", + " else:\n", + " return None\n", + "\n", + "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n", + " \"\"\"\n", + " Attempt to download the structure file in CIF or PDB format.\n", + " Returns the path to the downloaded file, or None if download fails.\n", + " \"\"\"\n", + " for ext in ['.cif', '.pdb']:\n", + " file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n", + " if os.path.exists(file_path):\n", + " return file_path\n", + " url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n", + " try:\n", + " response = requests.get(url, timeout=10)\n", + " if response.status_code == 200:\n", + " with open(file_path, 'wb') as f:\n", + " f.write(response.content)\n", + " return file_path\n", + " except Exception as e:\n", + " print(f\"Download error for {pdb_id}{ext}: {e}\")\n", + " return None\n", + "\n", + "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n", + " \"\"\"\n", + " Convert a CIF file to PDB format using BioPython and return the PDB file path.\n", + " \"\"\"\n", + " pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n", + " parser = MMCIFParser(QUIET=True)\n", + " structure = parser.get_structure('protein', cif_path)\n", + " io = PDBIO()\n", + " io.set_structure(structure)\n", + " io.save(pdb_path)\n", + " return pdb_path\n", + "\n", + "def fetch_pdb(pdb_id):\n", + " pdb_path = fetch_structure(pdb_id)\n", + " if not pdb_path:\n", + " return None\n", + " _, ext = os.path.splitext(pdb_path)\n", + " if ext == '.cif':\n", + " pdb_path = convert_cif_to_pdb(pdb_path)\n", + " return pdb_path\n", + "\n", + "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n", + " \"\"\"\n", + " Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores\n", + " \"\"\"\n", + " # Read the original PDB file\n", + " parser = PDBParser(QUIET=True)\n", + " structure = parser.get_structure('protein', input_pdb)\n", + " \n", + " # Prepare a new structure with only the specified chain and selected residues\n", + " output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_scored.pdb\"\n", + " \n", + " # Create scores dictionary for easy lookup\n", + " scores_dict = {resi: score for resi, score in residue_scores}\n", + "\n", + " # Create a custom Select class\n", + " class ResidueSelector(Select):\n", + " def __init__(self, chain_id, selected_residues, scores_dict):\n", + " self.chain_id = chain_id\n", + " self.selected_residues = selected_residues\n", + " self.scores_dict = scores_dict\n", + " \n", + " def accept_chain(self, chain):\n", + " return chain.id == self.chain_id\n", + " \n", + " def accept_residue(self, residue):\n", + " return residue.id[1] in self.selected_residues\n", + "\n", + " def accept_atom(self, atom):\n", + " if atom.parent.id[1] in self.scores_dict:\n", + " atom.bfactor = self.scores_dict[atom.parent.id[1]] * 100\n", + " return True\n", + "\n", + " # Prepare output PDB with selected chain and residues, modified B-factors\n", + " io = PDBIO()\n", + " selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n", + " \n", + " io.set_structure(structure[0])\n", + " io.save(output_pdb, selector)\n", + " \n", + " return output_pdb\n", + "\n", + "def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):\n", + " \"\"\"\n", + " Calculate the geometric center of high-scoring residues\n", + " \"\"\"\n", + " parser = PDBParser(QUIET=True)\n", + " structure = parser.get_structure('protein', pdb_path)\n", + " \n", + " # Collect coordinates of CA atoms from high-scoring residues\n", + " coords = []\n", + " for model in structure:\n", + " for chain in model:\n", + " if chain.id == chain_id:\n", + " for residue in chain:\n", + " if residue.id[1] in high_score_residues:\n", + " if 'CA' in residue: # Use alpha carbon as representative\n", + " ca_atom = residue['CA']\n", + " coords.append(ca_atom.coord)\n", + " \n", + " # Calculate geometric center\n", + " if coords:\n", + " center = np.mean(coords, axis=0)\n", + " return center\n", + " return None\n", + "\n", + "def process_pdb(pdb_id_or_file, segment):\n", + " # Determine if input is a PDB ID or file path\n", + " if pdb_id_or_file.endswith('.pdb'):\n", + " pdb_path = pdb_id_or_file\n", + " pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n", + " else:\n", + " pdb_id = pdb_id_or_file\n", + " pdb_path = fetch_pdb(pdb_id)\n", + " \n", + " if not pdb_path:\n", + " return \"Failed to fetch PDB file\", None, None\n", + " \n", + " # Determine the file format and choose the appropriate parser\n", + " _, ext = os.path.splitext(pdb_path)\n", + " parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n", + " \n", + " try:\n", + " # Parse the structure file\n", + " structure = parser.get_structure('protein', pdb_path)\n", + " except Exception as e:\n", + " return f\"Error parsing structure file: {e}\", None, None\n", + " \n", + " # Extract the specified chain\n", + " try:\n", + " chain = structure[0][segment]\n", + " except KeyError:\n", + " return \"Invalid Chain ID\", None, None\n", + " \n", + " protein_residues = [res for res in chain if is_aa(res)]\n", + " sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n", + " sequence_id = [res.id[1] for res in protein_residues]\n", + " \n", + " scores = np.random.rand(len(sequence))\n", + " normalized_scores = normalize_scores(scores)\n", + " \n", + " # Zip residues with scores to track the residue ID and score\n", + " residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n", + "\n", + " # More granular scoring for visualization\n", + " def score_to_color(score):\n", + " if score <= 0.6:\n", + " return \"blue\"\n", + " elif score <= 0.7:\n", + " return \"lightblue\"\n", + " elif score <= 0.8:\n", + " return \"white\"\n", + " elif score <= 0.9:\n", + " return \"orange\"\n", + " elif score > 0.9:\n", + " return \"red\"\n", + "\n", + " color_map = {resi: score_to_color(score) for resi, score in residue_scores}\n", + " \n", + " # Identify high scoring residues (> 0.7)\n", + " high_score_residues = [resi for resi, score in residue_scores if score > 0.7]\n", + " mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.7]\n", + "\n", + " # Calculate geometric center of high-scoring residues\n", + " geo_center = calculate_geometric_center(pdb_path, high_score_residues, segment)\n", + "\n", + " # Preparing the result: only print high scoring residues\n", + " result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\n\"\n", + " result_str += \"High-scoring Residues (Score > 0.7):\\n\"\n", + " result_str += \"\\n\".join([\n", + " f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n", + " for i, res in enumerate(protein_residues) if res.id[1] in high_score_residues\n", + " ])\n", + "\n", + " # Create prediction and scored PDB files\n", + " prediction_file = f\"{pdb_id}_predictions.txt\"\n", + " with open(prediction_file, \"w\") as f:\n", + " f.write(result_str)\n", + " \n", + " # Create chain-specific PDB with scores in B-factor\n", + " scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n", + "\n", + " # Molecule visualization with updated script with color mapping\n", + " mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)\n", + "\n", + " # Improved PyMOL command suggestions\n", + " pymol_commands = f\"\"\"\n", + "# PyMOL Visualization Commands\n", + "load {os.path.abspath(pdb_path)}, protein\n", + "hide everything, all\n", + "show cartoon, chain {segment}\n", + "color white, chain {segment}\n", + "\"\"\"\n", + " \n", + " # Color specific residues\n", + " for score_range, color in [\n", + " (high_score_residues, \"red\"), \n", + " (mid_score_residues, \"orange\")\n", + " ]:\n", + " if score_range:\n", + " resi_list = '+'.join(map(str, score_range))\n", + " pymol_commands += f\"\"\"\n", + "select high_score_residues, resi {resi_list} and chain {segment}\n", + "show sticks, high_score_residues\n", + "color {color}, high_score_residues\n", + "\"\"\"\n", + " \n", + " return result_str, mol_vis, [prediction_file,scored_pdb]\n", + "\n", + "def molecule(input_pdb, residue_scores=None, segment='A'):\n", + " mol = read_mol(input_pdb) # Read PDB file content\n", + "\n", + " # Prepare high-scoring residues script if scores are provided\n", + " high_score_script = \"\"\n", + " if residue_scores is not None:\n", + " # Filter residues based on their scores\n", + " high_score_residues = [resi for resi, score in residue_scores if score > 0.75]\n", + " mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]\n", + " \n", + " high_score_script = \"\"\"\n", + " // Load the original model and apply white cartoon style\n", + " let chainModel = viewer.addModel(pdb, \"pdb\");\n", + " chainModel.setStyle({}, {});\n", + " chainModel.setStyle(\n", + " {\"chain\": \"%s\"}, \n", + " {\"cartoon\": {\"color\": \"white\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let highScoreModel = viewer.addModel(pdb, \"pdb\");\n", + " highScoreModel.setStyle({}, {});\n", + " highScoreModel.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"red\"}}\n", + " );\n", + "\n", + " // Create a new model for medium-scoring residues and apply orange sticks style\n", + " let midScoreModel = viewer.addModel(pdb, \"pdb\");\n", + " midScoreModel.setStyle({}, {});\n", + " midScoreModel.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"orange\"}}\n", + " );\n", + " \"\"\" % (\n", + " segment,\n", + " segment,\n", + " \", \".join(str(resi) for resi in high_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in mid_score_residues)\n", + " )\n", + " \n", + " # Generate the full HTML content\n", + " html_content = f\"\"\"\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \"\"\"\n", + " \n", + " # Return the HTML content within an iframe safely encoded for special characters\n", + " return f''\n", + "\n", + "# Gradio UI\n", + "with gr.Blocks() as demo:\n", + " gr.Markdown(\"# Protein Binding Site Prediction\")\n", + " \n", + " with gr.Row():\n", + " pdb_input = gr.Textbox(value=\"4BDU\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", + " visualize_btn = gr.Button(\"Visualize Structure\")\n", + "\n", + " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n", + " {\n", + " \"model\": 0,\n", + " \"style\": \"cartoon\",\n", + " \"color\": \"whiteCarbon\",\n", + " \"residue_range\": \"\",\n", + " \"around\": 0,\n", + " \"byres\": False,\n", + " }\n", + " ])\n", + "\n", + " with gr.Row():\n", + " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n", + " prediction_btn = gr.Button(\"Predict Binding Site\")\n", + "\n", + " molecule_output = gr.HTML(label=\"Protein Structure\")\n", + " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n", + " download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n", + " \n", + " prediction_btn.click(\n", + " process_pdb, \n", + " inputs=[\n", + " pdb_input, \n", + " segment_input\n", + " ], \n", + " outputs=[predictions_output, molecule_output, download_output]\n", + " )\n", + "\n", + " visualize_btn.click(\n", + " fetch_pdb, \n", + " inputs=[pdb_input], \n", + " outputs=molecule_output2\n", + " )\n", + "\n", + " gr.Markdown(\"## Examples\")\n", + " gr.Examples(\n", + " examples=[\n", + " [\"7RPZ\", \"A\"],\n", + " [\"2IWI\", \"B\"],\n", + " [\"2F6V\", \"A\"]\n", + " ],\n", + " inputs=[pdb_input, segment_input],\n", + " outputs=[predictions_output, molecule_output, download_output]\n", + " )\n", + "\n", + "demo.launch(share=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "004ab20c-5273-44b9-bc69-d41f236296e4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7890\n", + "* Running on public URL: https://a7f63d297aa65a70de.gradio.live\n", + "\n", + "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from datetime import datetime\n", + "import gradio as gr\n", + "import requests\n", + "from Bio.PDB import PDBParser, MMCIFParser, PDBIO\n", + "from Bio.PDB.Polypeptide import is_aa\n", + "from Bio.SeqUtils import seq1\n", + "from Bio.PDB import Select\n", + "from typing import Optional, Tuple\n", + "import numpy as np\n", + "import os\n", + "from gradio_molecule3d import Molecule3D\n", + "\n", + "import re\n", + "import pandas as pd\n", + "import copy\n", + "\n", + "from scipy.special import expit\n", + "\n", + "def normalize_scores(scores):\n", + " min_score = np.min(scores)\n", + " max_score = np.max(scores)\n", + " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n", + "\n", + "def read_mol(pdb_path):\n", + " \"\"\"Read PDB file and return its content as a string\"\"\"\n", + " with open(pdb_path, 'r') as f:\n", + " return f.read()\n", + "\n", + "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n", + " \"\"\"\n", + " Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n", + " If a structure file already exists locally, it uses that.\n", + " \"\"\"\n", + " file_path = download_structure(pdb_id, output_dir)\n", + " if file_path:\n", + " return file_path\n", + " else:\n", + " return None\n", + "\n", + "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n", + " \"\"\"\n", + " Attempt to download the structure file in CIF or PDB format.\n", + " Returns the path to the downloaded file, or None if download fails.\n", + " \"\"\"\n", + " for ext in ['.cif', '.pdb']:\n", + " file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n", + " if os.path.exists(file_path):\n", + " return file_path\n", + " url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n", + " try:\n", + " response = requests.get(url, timeout=10)\n", + " if response.status_code == 200:\n", + " with open(file_path, 'wb') as f:\n", + " f.write(response.content)\n", + " return file_path\n", + " except Exception as e:\n", + " print(f\"Download error for {pdb_id}{ext}: {e}\")\n", + " return None\n", + "\n", + "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n", + " \"\"\"\n", + " Convert a CIF file to PDB format using BioPython and return the PDB file path.\n", + " \"\"\"\n", + " pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n", + " parser = MMCIFParser(QUIET=True)\n", + " structure = parser.get_structure('protein', cif_path)\n", + " io = PDBIO()\n", + " io.set_structure(structure)\n", + " io.save(pdb_path)\n", + " return pdb_path\n", + "\n", + "def fetch_pdb(pdb_id):\n", + " pdb_path = fetch_structure(pdb_id)\n", + " if not pdb_path:\n", + " return None\n", + " _, ext = os.path.splitext(pdb_path)\n", + " if ext == '.cif':\n", + " pdb_path = convert_cif_to_pdb(pdb_path)\n", + " return pdb_path\n", + "\n", + "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n", + " \"\"\"\n", + " Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores\n", + " \"\"\"\n", + " # Read the original PDB file\n", + " parser = PDBParser(QUIET=True)\n", + " structure = parser.get_structure('protein', input_pdb)\n", + " \n", + " # Prepare a new structure with only the specified chain and selected residues\n", + " output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_scored.pdb\"\n", + " \n", + " # Create scores dictionary for easy lookup\n", + " scores_dict = {resi: score for resi, score in residue_scores}\n", + "\n", + " # Create a custom Select class\n", + " class ResidueSelector(Select):\n", + " def __init__(self, chain_id, selected_residues, scores_dict):\n", + " self.chain_id = chain_id\n", + " self.selected_residues = selected_residues\n", + " self.scores_dict = scores_dict\n", + " \n", + " def accept_chain(self, chain):\n", + " return chain.id == self.chain_id\n", + " \n", + " def accept_residue(self, residue):\n", + " return residue.id[1] in self.selected_residues\n", + "\n", + " def accept_atom(self, atom):\n", + " if atom.parent.id[1] in self.scores_dict:\n", + " atom.bfactor = self.scores_dict[atom.parent.id[1]] * 100\n", + " return True\n", + "\n", + " # Prepare output PDB with selected chain and residues, modified B-factors\n", + " io = PDBIO()\n", + " selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n", + " \n", + " io.set_structure(structure[0])\n", + " io.save(output_pdb, selector)\n", + " \n", + " return output_pdb\n", + "\n", + "def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):\n", + " \"\"\"\n", + " Calculate the geometric center of high-scoring residues\n", + " \"\"\"\n", + " parser = PDBParser(QUIET=True)\n", + " structure = parser.get_structure('protein', pdb_path)\n", + " \n", + " # Collect coordinates of CA atoms from high-scoring residues\n", + " coords = []\n", + " for model in structure:\n", + " for chain in model:\n", + " if chain.id == chain_id:\n", + " for residue in chain:\n", + " if residue.id[1] in high_score_residues:\n", + " if 'CA' in residue: # Use alpha carbon as representative\n", + " ca_atom = residue['CA']\n", + " coords.append(ca_atom.coord)\n", + " \n", + " # Calculate geometric center\n", + " if coords:\n", + " center = np.mean(coords, axis=0)\n", + " return center\n", + " return None\n", + "\n", + "def process_pdb(pdb_id_or_file, segment):\n", + " # Determine if input is a PDB ID or file path\n", + " if pdb_id_or_file.endswith('.pdb'):\n", + " pdb_path = pdb_id_or_file\n", + " pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n", + " else:\n", + " pdb_id = pdb_id_or_file\n", + " pdb_path = fetch_pdb(pdb_id)\n", + " \n", + " if not pdb_path:\n", + " return \"Failed to fetch PDB file\", None, None\n", + " \n", + " # Determine the file format and choose the appropriate parser\n", + " _, ext = os.path.splitext(pdb_path)\n", + " parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n", + " \n", + " try:\n", + " # Parse the structure file\n", + " structure = parser.get_structure('protein', pdb_path)\n", + " except Exception as e:\n", + " return f\"Error parsing structure file: {e}\", None, None\n", + " \n", + " # Extract the specified chain\n", + " try:\n", + " chain = structure[0][segment]\n", + " except KeyError:\n", + " return \"Invalid Chain ID\", None, None\n", + " \n", + " protein_residues = [res for res in chain if is_aa(res)]\n", + " sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n", + " sequence_id = [res.id[1] for res in protein_residues]\n", + " \n", + " scores = np.random.rand(len(sequence))\n", + " normalized_scores = normalize_scores(scores)\n", + " \n", + " # Zip residues with scores to track the residue ID and score\n", + " residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n", + "\n", + " \n", + " # Identify high scoring residues (> 0.5)\n", + " high_score_residues = [resi for resi, score in residue_scores if score > 0.5]\n", + " \n", + " # Preparing the result: only print high scoring residues\n", + " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + " result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n", + " result_str += \"High-scoring Residues (Score > 0.5):\\n\"\n", + " result_str += \"Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\\n\\n\"\n", + " result_str += \"\\n\".join([\n", + " f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n", + " for i, res in enumerate(protein_residues) if res.id[1] in high_score_residues\n", + " ])\n", + "\n", + " # Create chain-specific PDB with scores in B-factor\n", + " scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n", + "\n", + " # Molecule visualization with updated script with color mapping\n", + " mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)\n", + "\n", + " # Improved PyMOL command suggestions\n", + " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + " pymol_commands = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n", + " \n", + " pymol_commands += f\"\"\"\n", + "# PyMOL Visualization Commands\n", + "load {os.path.abspath(pdb_path)}, protein\n", + "hide everything, all\n", + "show cartoon, chain {segment}\n", + "color white, chain {segment}\n", + "\"\"\"\n", + " \n", + " # Color specific residues\n", + " for score_range, color in [\n", + " (high_score_residues, \"red\")\n", + " ]:\n", + " if score_range:\n", + " resi_list = '+'.join(map(str, score_range))\n", + " pymol_commands += f\"\"\"\n", + "select high_score_residues, resi {resi_list} and chain {segment}\n", + "show sticks, high_score_residues\n", + "color {color}, high_score_residues\n", + "\"\"\"\n", + " # Create prediction and scored PDB files\n", + " prediction_file = f\"{pdb_id}_predictions.txt\"\n", + " with open(prediction_file, \"w\") as f:\n", + " f.write(result_str)\n", + " \n", + " return pymol_commands, mol_vis, [prediction_file,scored_pdb]\n", + "\n", + "def molecule(input_pdb, residue_scores=None, segment='A'):\n", + " # More granular scoring for visualization\n", + " mol = read_mol(input_pdb) # Read PDB file content\n", + "\n", + " # Prepare high-scoring residues script if scores are provided\n", + " high_score_script = \"\"\n", + " if residue_scores is not None:\n", + " # Filter residues based on their scores\n", + " class1_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.6]\n", + " class2_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.7]\n", + " class3_score_residues = [resi for resi, score in residue_scores if 0.7 < score <= 0.8]\n", + " class4_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 0.9]\n", + " class5_score_residues = [resi for resi, score in residue_scores if 0.9 < score <= 1.0]\n", + " \n", + " high_score_script = \"\"\"\n", + " // Load the original model and apply white cartoon style\n", + " let chainModel = viewer.addModel(pdb, \"pdb\");\n", + " chainModel.setStyle({}, {});\n", + " chainModel.setStyle(\n", + " {\"chain\": \"%s\"}, \n", + " {\"cartoon\": {\"color\": \"white\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class1Model = viewer.addModel(pdb, \"pdb\");\n", + " class1Model.setStyle({}, {});\n", + " class1Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"blue\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class2Model = viewer.addModel(pdb, \"pdb\");\n", + " class2Model.setStyle({}, {});\n", + " class2Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"lightblue\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class3Model = viewer.addModel(pdb, \"pdb\");\n", + " class3Model.setStyle({}, {});\n", + " class3Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"white\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class4Model = viewer.addModel(pdb, \"pdb\");\n", + " class4Model.setStyle({}, {});\n", + " class4Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"orange\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class5Model = viewer.addModel(pdb, \"pdb\");\n", + " class5Model.setStyle({}, {});\n", + " class5Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"red\"}}\n", + " );\n", + "\n", + " \"\"\" % (\n", + " segment,\n", + " segment,\n", + " \", \".join(str(resi) for resi in class1_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in class2_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in class3_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in class4_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in class5_score_residues)\n", + " )\n", + " \n", + " # Generate the full HTML content\n", + " html_content = f\"\"\"\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \"\"\"\n", + " \n", + " # Return the HTML content within an iframe safely encoded for special characters\n", + " return f''\n", + "\n", + "# Gradio UI\n", + "with gr.Blocks() as demo:\n", + " gr.Markdown(\"# Protein Binding Site Prediction\")\n", + " \n", + " with gr.Row():\n", + " pdb_input = gr.Textbox(value=\"4BDU\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", + " visualize_btn = gr.Button(\"Visualize Structure\")\n", + "\n", + " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n", + " {\n", + " \"model\": 0,\n", + " \"style\": \"cartoon\",\n", + " \"color\": \"whiteCarbon\",\n", + " \"residue_range\": \"\",\n", + " \"around\": 0,\n", + " \"byres\": False,\n", + " }\n", + " ])\n", + "\n", + " with gr.Row():\n", + " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n", + " prediction_btn = gr.Button(\"Predict Binding Site\")\n", + "\n", + " molecule_output = gr.HTML(label=\"Protein Structure\")\n", + " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n", + " download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n", + " \n", + " prediction_btn.click(\n", + " process_pdb, \n", + " inputs=[\n", + " pdb_input, \n", + " segment_input\n", + " ], \n", + " outputs=[predictions_output, molecule_output, download_output]\n", + " )\n", + "\n", + " visualize_btn.click(\n", + " fetch_pdb, \n", + " inputs=[pdb_input], \n", + " outputs=molecule_output2\n", + " )\n", + "\n", + " gr.Markdown(\"## Examples\")\n", + " gr.Examples(\n", + " examples=[\n", + " [\"7RPZ\", \"A\"],\n", + " [\"2IWI\", \"B\"],\n", + " [\"2F6V\", \"A\"]\n", + " ],\n", + " inputs=[pdb_input, segment_input],\n", + " outputs=[predictions_output, molecule_output, download_output]\n", + " )\n", + "\n", + "demo.launch(share=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "a492c5c5-e0aa-4445-9375-64cfdb963e04", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7891\n", + "* Running on public URL: https://339346b4ad32f608d0.gradio.live\n", + "\n", + "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from datetime import datetime\n", + "import gradio as gr\n", + "import requests\n", + "from Bio.PDB import PDBParser, MMCIFParser, PDBIO\n", + "from Bio.PDB.Polypeptide import is_aa\n", + "from Bio.SeqUtils import seq1\n", + "from Bio.PDB import Select\n", + "from typing import Optional, Tuple\n", + "import numpy as np\n", + "import os\n", + "from gradio_molecule3d import Molecule3D\n", + "\n", + "import re\n", + "import pandas as pd\n", + "import copy\n", + "\n", + "from scipy.special import expit\n", + "\n", + "def normalize_scores(scores):\n", + " min_score = np.min(scores)\n", + " max_score = np.max(scores)\n", + " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n", + "\n", + "def read_mol(pdb_path):\n", + " \"\"\"Read PDB file and return its content as a string\"\"\"\n", + " with open(pdb_path, 'r') as f:\n", + " return f.read()\n", + "\n", + "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n", + " \"\"\"\n", + " Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n", + " If a structure file already exists locally, it uses that.\n", + " \"\"\"\n", + " file_path = download_structure(pdb_id, output_dir)\n", + " if file_path:\n", + " return file_path\n", + " else:\n", + " return None\n", + "\n", + "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n", + " \"\"\"\n", + " Attempt to download the structure file in CIF or PDB format.\n", + " Returns the path to the downloaded file, or None if download fails.\n", + " \"\"\"\n", + " for ext in ['.cif', '.pdb']:\n", + " file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n", + " if os.path.exists(file_path):\n", + " return file_path\n", + " url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n", + " try:\n", + " response = requests.get(url, timeout=10)\n", + " if response.status_code == 200:\n", + " with open(file_path, 'wb') as f:\n", + " f.write(response.content)\n", + " return file_path\n", + " except Exception as e:\n", + " print(f\"Download error for {pdb_id}{ext}: {e}\")\n", + " return None\n", + "\n", + "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n", + " \"\"\"\n", + " Convert a CIF file to PDB format using BioPython and return the PDB file path.\n", + " \"\"\"\n", + " pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n", + " parser = MMCIFParser(QUIET=True)\n", + " structure = parser.get_structure('protein', cif_path)\n", + " io = PDBIO()\n", + " io.set_structure(structure)\n", + " io.save(pdb_path)\n", + " return pdb_path\n", + "\n", + "def fetch_pdb(pdb_id):\n", + " pdb_path = fetch_structure(pdb_id)\n", + " if not pdb_path:\n", + " return None\n", + " _, ext = os.path.splitext(pdb_path)\n", + " if ext == '.cif':\n", + " pdb_path = convert_cif_to_pdb(pdb_path)\n", + " return pdb_path\n", + "\n", + "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n", + " \"\"\"\n", + " Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores\n", + " \"\"\"\n", + " # Read the original PDB file\n", + " parser = PDBParser(QUIET=True)\n", + " structure = parser.get_structure('protein', input_pdb)\n", + " \n", + " # Prepare a new structure with only the specified chain and selected residues\n", + " output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb\"\n", + " \n", + " # Create scores dictionary for easy lookup\n", + " scores_dict = {resi: score for resi, score in residue_scores}\n", + "\n", + " # Create a custom Select class\n", + " class ResidueSelector(Select):\n", + " def __init__(self, chain_id, selected_residues, scores_dict):\n", + " self.chain_id = chain_id\n", + " self.selected_residues = selected_residues\n", + " self.scores_dict = scores_dict\n", + " \n", + " def accept_chain(self, chain):\n", + " return chain.id == self.chain_id\n", + " \n", + " def accept_residue(self, residue):\n", + " return residue.id[1] in self.selected_residues\n", + "\n", + " def accept_atom(self, atom):\n", + " if atom.parent.id[1] in self.scores_dict:\n", + " atom.bfactor = self.scores_dict[atom.parent.id[1]] * 100\n", + " return True\n", + "\n", + " # Prepare output PDB with selected chain and residues, modified B-factors\n", + " io = PDBIO()\n", + " selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n", + " \n", + " io.set_structure(structure[0])\n", + " io.save(output_pdb, selector)\n", + " \n", + " return output_pdb\n", + "\n", + "def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):\n", + " \"\"\"\n", + " Calculate the geometric center of high-scoring residues\n", + " \"\"\"\n", + " parser = PDBParser(QUIET=True)\n", + " structure = parser.get_structure('protein', pdb_path)\n", + " \n", + " # Collect coordinates of CA atoms from high-scoring residues\n", + " coords = []\n", + " for model in structure:\n", + " for chain in model:\n", + " if chain.id == chain_id:\n", + " for residue in chain:\n", + " if residue.id[1] in high_score_residues:\n", + " if 'CA' in residue: # Use alpha carbon as representative\n", + " ca_atom = residue['CA']\n", + " coords.append(ca_atom.coord)\n", + " \n", + " # Calculate geometric center\n", + " if coords:\n", + " center = np.mean(coords, axis=0)\n", + " return center\n", + " return None\n", + "\n", + "def process_pdb(pdb_id_or_file, segment):\n", + " # Determine if input is a PDB ID or file path\n", + " if pdb_id_or_file.endswith('.pdb'):\n", + " pdb_path = pdb_id_or_file\n", + " pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n", + " else:\n", + " pdb_id = pdb_id_or_file\n", + " pdb_path = fetch_pdb(pdb_id)\n", + " \n", + " if not pdb_path:\n", + " return \"Failed to fetch PDB file\", None, None\n", + " \n", + " # Determine the file format and choose the appropriate parser\n", + " _, ext = os.path.splitext(pdb_path)\n", + " parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n", + " \n", + " try:\n", + " # Parse the structure file\n", + " structure = parser.get_structure('protein', pdb_path)\n", + " except Exception as e:\n", + " return f\"Error parsing structure file: {e}\", None, None\n", + " \n", + " # Extract the specified chain\n", + " try:\n", + " chain = structure[0][segment]\n", + " except KeyError:\n", + " return \"Invalid Chain ID\", None, None\n", + " \n", + " protein_residues = [res for res in chain if is_aa(res)]\n", + " sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n", + " sequence_id = [res.id[1] for res in protein_residues]\n", + " \n", + " scores = np.random.rand(len(sequence))\n", + " normalized_scores = normalize_scores(scores)\n", + " \n", + " # Zip residues with scores to track the residue ID and score\n", + " residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n", + "\n", + " \n", + " # Identify high scoring residues (> 0.5)\n", + " high_score_residues = [resi for resi, score in residue_scores if score > 0.5]\n", + " \n", + " # Preparing the result: only print high scoring residues\n", + " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + " result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n", + " result_str += \"High-scoring Residues (Score > 0.5):\\n\"\n", + " result_str += \"Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\\n\\n\"\n", + " result_str += \"\\n\".join([\n", + " f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n", + " for i, res in enumerate(protein_residues) if res.id[1] in high_score_residues\n", + " ])\n", + "\n", + " # Create chain-specific PDB with scores in B-factor\n", + " scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n", + "\n", + " # Molecule visualization with updated script with color mapping\n", + " mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)\n", + "\n", + " # Improved PyMOL command suggestions\n", + " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + " pymol_commands = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n", + " \n", + " pymol_commands += f\"\"\"\n", + "# PyMOL Visualization Commands\n", + "load {os.path.abspath(pdb_path)}, protein\n", + "hide everything, all\n", + "show cartoon, chain {segment}\n", + "color white, chain {segment}\n", + "\"\"\"\n", + " \n", + " # Color specific residues\n", + " for score_range, color in [\n", + " (high_score_residues, \"red\")\n", + " ]:\n", + " if score_range:\n", + " resi_list = '+'.join(map(str, score_range))\n", + " pymol_commands += f\"\"\"\n", + "select high_score_residues, resi {resi_list} and chain {segment}\n", + "show sticks, high_score_residues\n", + "color {color}, high_score_residues\n", + "\"\"\"\n", + " # Create prediction and scored PDB files\n", + " prediction_file = f\"{pdb_id}_binding_site_residues.txt\"\n", + " with open(prediction_file, \"w\") as f:\n", + " f.write(result_str)\n", + " \n", + " return pymol_commands, mol_vis, [prediction_file,scored_pdb]\n", + "\n", + "def molecule(input_pdb, residue_scores=None, segment='A'):\n", + " # More granular scoring for visualization\n", + " mol = read_mol(input_pdb) # Read PDB file content\n", + "\n", + " # Prepare high-scoring residues script if scores are provided\n", + " high_score_script = \"\"\n", + " if residue_scores is not None:\n", + " # Filter residues based on their scores\n", + " class1_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.6]\n", + " class2_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.7]\n", + " class3_score_residues = [resi for resi, score in residue_scores if 0.7 < score <= 0.8]\n", + " class4_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 0.9]\n", + " class5_score_residues = [resi for resi, score in residue_scores if 0.9 < score <= 1.0]\n", + " \n", + " high_score_script = \"\"\"\n", + " // Load the original model and apply white cartoon style\n", + " let chainModel = viewer.addModel(pdb, \"pdb\");\n", + " chainModel.setStyle({}, {});\n", + " chainModel.setStyle(\n", + " {\"chain\": \"%s\"}, \n", + " {\"cartoon\": {\"color\": \"white\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class1Model = viewer.addModel(pdb, \"pdb\");\n", + " class1Model.setStyle({}, {});\n", + " class1Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"blue\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class2Model = viewer.addModel(pdb, \"pdb\");\n", + " class2Model.setStyle({}, {});\n", + " class2Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"lightblue\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class3Model = viewer.addModel(pdb, \"pdb\");\n", + " class3Model.setStyle({}, {});\n", + " class3Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"white\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class4Model = viewer.addModel(pdb, \"pdb\");\n", + " class4Model.setStyle({}, {});\n", + " class4Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"orange\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class5Model = viewer.addModel(pdb, \"pdb\");\n", + " class5Model.setStyle({}, {});\n", + " class5Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"red\"}}\n", + " );\n", + "\n", + " \"\"\" % (\n", + " segment,\n", + " segment,\n", + " \", \".join(str(resi) for resi in class1_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in class2_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in class3_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in class4_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in class5_score_residues)\n", + " )\n", + " \n", + " # Generate the full HTML content\n", + " html_content = f\"\"\"\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \"\"\"\n", + " \n", + " # Return the HTML content within an iframe safely encoded for special characters\n", + " return f''\n", + "\n", + "# Gradio UI\n", + "with gr.Blocks() as demo:\n", + " gr.Markdown(\"# Protein Binding Site Prediction\")\n", + " \n", + " with gr.Row():\n", + " pdb_input = gr.Textbox(value=\"4BDU\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", + " visualize_btn = gr.Button(\"Visualize Structure\")\n", + "\n", + " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n", + " {\n", + " \"model\": 0,\n", + " \"style\": \"cartoon\",\n", + " \"color\": \"whiteCarbon\",\n", + " \"residue_range\": \"\",\n", + " \"around\": 0,\n", + " \"byres\": False,\n", + " }\n", + " ])\n", + "\n", + " with gr.Row():\n", + " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n", + " prediction_btn = gr.Button(\"Predict Binding Site\")\n", + "\n", + " molecule_output = gr.HTML(label=\"Protein Structure\")\n", + " explanation_vis = gr.Markdown(\"\"\"\n", + " Residues with a score > 0.5 are considered binding sites and represented as sticks with the score dependent colorcoding:\n", + " - 0.5-0.6: blue \n", + " - 0.6–0.7: light blue \n", + " - 0.7–0.8: white\n", + " - 0.8–0.9: orange\n", + " - 0.9–1.0: red\n", + " \"\"\")\n", + " predictions_output = gr.Textbox(label=\"Visualize Prediction with PyMol\")\n", + " gr.Markdown(\"### Download:\\n- List of predicted binding site residues\\n- PDB with score in beta factor column\")\n", + " download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n", + " \n", + " prediction_btn.click(\n", + " process_pdb, \n", + " inputs=[\n", + " pdb_input, \n", + " segment_input\n", + " ], \n", + " outputs=[predictions_output, molecule_output, download_output]\n", + " )\n", + "\n", + " visualize_btn.click(\n", + " fetch_pdb, \n", + " inputs=[pdb_input], \n", + " outputs=molecule_output2\n", + " )\n", + "\n", + " gr.Markdown(\"## Examples\")\n", + " gr.Examples(\n", + " examples=[\n", + " [\"7RPZ\", \"A\"],\n", + " [\"2IWI\", \"B\"],\n", + " [\"2F6V\", \"A\"]\n", + " ],\n", + " inputs=[pdb_input, segment_input],\n", + " outputs=[predictions_output, molecule_output, download_output]\n", + " )\n", + "\n", + "demo.launch(share=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "99d18e7c-3ec1-48f2-b368-958b66bb1782", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7923\n", + "* Running on public URL: https://ad5916147a5fd9c4b5.gradio.live\n", + "\n", + "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 71, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ".cif\n", + "./7RPZ.pdb\n", + ".pdb\n", + "/private/var/folders/tm/ym2tckv54b96ws82y3b7cqhh0000gn/T/gradio/6b7bd3b706f978096c02bacdbf7b38529f0a5233f7570f758063b6e78f62771d/2F6V.pdb\n" + ] + } + ], + "source": [ + "from datetime import datetime\n", + "import gradio as gr\n", + "import requests\n", + "from Bio.PDB import PDBParser, MMCIFParser, PDBIO\n", + "from Bio.PDB.Polypeptide import is_aa\n", + "from Bio.SeqUtils import seq1\n", + "from Bio.PDB import Select\n", + "from typing import Optional, Tuple\n", + "import numpy as np\n", + "import os\n", + "from gradio_molecule3d import Molecule3D\n", + "\n", + "import re\n", + "import pandas as pd\n", + "import copy\n", + "\n", + "from scipy.special import expit\n", + "\n", + "def normalize_scores(scores):\n", + " min_score = np.min(scores)\n", + " max_score = np.max(scores)\n", + " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n", + "\n", + "def read_mol(pdb_path):\n", + " \"\"\"Read PDB file and return its content as a string\"\"\"\n", + " with open(pdb_path, 'r') as f:\n", + " return f.read()\n", + "\n", + "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n", + " \"\"\"\n", + " Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n", + " If a structure file already exists locally, it uses that.\n", + " \"\"\"\n", + " file_path = download_structure(pdb_id, output_dir)\n", + " if file_path:\n", + " return file_path\n", + " else:\n", + " return None\n", + "\n", + "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n", + " \"\"\"\n", + " Attempt to download the structure file in CIF or PDB format.\n", + " Returns the path to the downloaded file, or None if download fails.\n", + " \"\"\"\n", + " for ext in ['.cif', '.pdb']:\n", + " file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n", + " if os.path.exists(file_path):\n", + " return file_path\n", + " url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n", + " try:\n", + " response = requests.get(url, timeout=10)\n", + " if response.status_code == 200:\n", + " with open(file_path, 'wb') as f:\n", + " f.write(response.content)\n", + " return file_path\n", + " except Exception as e:\n", + " print(f\"Download error for {pdb_id}{ext}: {e}\")\n", + " return None\n", + "\n", + "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n", + " \"\"\"\n", + " Convert a CIF file to PDB format using BioPython and return the PDB file path.\n", + " \"\"\"\n", + " pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n", + " parser = MMCIFParser(QUIET=True)\n", + " structure = parser.get_structure('protein', cif_path)\n", + " io = PDBIO()\n", + " io.set_structure(structure)\n", + " io.save(pdb_path)\n", + " return pdb_path\n", + "\n", + "def fetch_pdb(pdb_id):\n", + " pdb_path = fetch_structure(pdb_id)\n", + " if not pdb_path:\n", + " return None\n", + " _, ext = os.path.splitext(pdb_path)\n", + " if ext == '.cif':\n", + " pdb_path = convert_cif_to_pdb(pdb_path)\n", + " return pdb_path\n", + "\n", + "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n", + " \"\"\"\n", + " Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores\n", + " \"\"\"\n", + " # Read the original PDB file\n", + " parser = PDBParser(QUIET=True)\n", + " structure = parser.get_structure('protein', input_pdb)\n", + " \n", + " # Prepare a new structure with only the specified chain and selected residues\n", + " output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb\"\n", + " \n", + " # Create scores dictionary for easy lookup\n", + " scores_dict = {resi: score for resi, score in residue_scores}\n", + "\n", + " # Create a custom Select class\n", + " class ResidueSelector(Select):\n", + " def __init__(self, chain_id, selected_residues, scores_dict):\n", + " self.chain_id = chain_id\n", + " self.selected_residues = selected_residues\n", + " self.scores_dict = scores_dict\n", + " \n", + " def accept_chain(self, chain):\n", + " return chain.id == self.chain_id\n", + " \n", + " def accept_residue(self, residue):\n", + " return residue.id[1] in self.selected_residues\n", + "\n", + " def accept_atom(self, atom):\n", + " if atom.parent.id[1] in self.scores_dict:\n", + " atom.bfactor = self.scores_dict[atom.parent.id[1]] * 100\n", + " return True\n", + "\n", + " # Prepare output PDB with selected chain and residues, modified B-factors\n", + " io = PDBIO()\n", + " selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n", + " \n", + " io.set_structure(structure[0])\n", + " io.save(output_pdb, selector)\n", + " \n", + " return output_pdb\n", + "\n", + "def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):\n", + " \"\"\"\n", + " Calculate the geometric center of high-scoring residues\n", + " \"\"\"\n", + " parser = PDBParser(QUIET=True)\n", + " structure = parser.get_structure('protein', pdb_path)\n", + " \n", + " # Collect coordinates of CA atoms from high-scoring residues\n", + " coords = []\n", + " for model in structure:\n", + " for chain in model:\n", + " if chain.id == chain_id:\n", + " for residue in chain:\n", + " if residue.id[1] in high_score_residues:\n", + " if 'CA' in residue: # Use alpha carbon as representative\n", + " ca_atom = residue['CA']\n", + " coords.append(ca_atom.coord)\n", + " \n", + " # Calculate geometric center\n", + " if coords:\n", + " center = np.mean(coords, axis=0)\n", + " return center\n", + " return None\n", + "\n", + "def process_pdb(pdb_id_or_file, segment):\n", + " # Determine if input is a PDB ID or file path\n", + " if pdb_id_or_file.endswith('.pdb'):\n", + " pdb_path = pdb_id_or_file\n", + " pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n", + " else:\n", + " pdb_id = pdb_id_or_file\n", + " pdb_path = fetch_pdb(pdb_id)\n", + " \n", + " if not pdb_path:\n", + " return \"Failed to fetch PDB file\", None, None\n", + " \n", + " # Determine the file format and choose the appropriate parser\n", + " _, ext = os.path.splitext(pdb_path)\n", + " parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n", + " \n", + " try:\n", + " # Parse the structure file\n", + " structure = parser.get_structure('protein', pdb_path)\n", + " except Exception as e:\n", + " return f\"Error parsing structure file: {e}\", None, None\n", + " \n", + " # Extract the specified chain\n", + " try:\n", + " chain = structure[0][segment]\n", + " except KeyError:\n", + " return \"Invalid Chain ID\", None, None\n", + " \n", + " protein_residues = [res for res in chain if is_aa(res)]\n", + " sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n", + " sequence_id = [res.id[1] for res in protein_residues]\n", + " \n", + " scores = np.random.rand(len(sequence))\n", + " normalized_scores = normalize_scores(scores)\n", + " \n", + " # Zip residues with scores to track the residue ID and score\n", + " residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n", + "\n", + " \n", + " # Identify high scoring residues (> 0.5)\n", + " high_score_residues = [resi for resi, score in residue_scores if score > 0.5]\n", + " \n", + " # Preparing the result: only print high scoring residues\n", + " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + " result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n", + " result_str += \"High-scoring Residues (Score > 0.5):\\n\"\n", + " result_str += \"Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\\n\\n\"\n", + " result_str += \"\\n\".join([\n", + " f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n", + " for i, res in enumerate(protein_residues) if res.id[1] in high_score_residues\n", + " ])\n", + "\n", + " # Create chain-specific PDB with scores in B-factor\n", + " scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n", + "\n", + " # Molecule visualization with updated script with color mapping\n", + " mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)\n", + "\n", + " # Improved PyMOL command suggestions\n", + " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + " pymol_commands = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n", + " \n", + " pymol_commands += f\"\"\"\n", + "# PyMOL Visualization Commands\n", + "load {os.path.abspath(pdb_path)}, protein\n", + "hide everything, all\n", + "show cartoon, chain {segment}\n", + "color white, chain {segment}\n", + "\"\"\"\n", + " \n", + " # Color specific residues\n", + " for score_range, color in [\n", + " (high_score_residues, \"red\")\n", + " ]:\n", + " if score_range:\n", + " resi_list = '+'.join(map(str, score_range))\n", + " pymol_commands += f\"\"\"\n", + "select high_score_residues, resi {resi_list} and chain {segment}\n", + "show sticks, high_score_residues\n", + "color {color}, high_score_residues\n", + "\"\"\"\n", + " # Create prediction and scored PDB files\n", + " prediction_file = f\"{pdb_id}_binding_site_residues.txt\"\n", + " with open(prediction_file, \"w\") as f:\n", + " f.write(result_str)\n", + " \n", + " return pymol_commands, mol_vis, [prediction_file,scored_pdb]\n", + "\n", + "def molecule(input_pdb, residue_scores=None, segment='A'):\n", + " # More granular scoring for visualization\n", + " mol = read_mol(input_pdb) # Read PDB file content\n", + "\n", + " # Prepare high-scoring residues script if scores are provided\n", + " high_score_script = \"\"\n", + " if residue_scores is not None:\n", + " # Filter residues based on their scores\n", + " class1_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.6]\n", + " class2_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.7]\n", + " class3_score_residues = [resi for resi, score in residue_scores if 0.7 < score <= 0.8]\n", + " class4_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 0.9]\n", + " class5_score_residues = [resi for resi, score in residue_scores if 0.9 < score <= 1.0]\n", + " \n", + " high_score_script = \"\"\"\n", + " // Load the original model and apply white cartoon style\n", + " let chainModel = viewer.addModel(pdb, \"pdb\");\n", + " chainModel.setStyle({}, {});\n", + " chainModel.setStyle(\n", + " {\"chain\": \"%s\"}, \n", + " {\"cartoon\": {\"color\": \"white\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class1Model = viewer.addModel(pdb, \"pdb\");\n", + " class1Model.setStyle({}, {});\n", + " class1Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"blue\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class2Model = viewer.addModel(pdb, \"pdb\");\n", + " class2Model.setStyle({}, {});\n", + " class2Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"lightblue\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class3Model = viewer.addModel(pdb, \"pdb\");\n", + " class3Model.setStyle({}, {});\n", + " class3Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"white\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class4Model = viewer.addModel(pdb, \"pdb\");\n", + " class4Model.setStyle({}, {});\n", + " class4Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"orange\"}}\n", + " );\n", + "\n", + " // Create a new model for high-scoring residues and apply red sticks style\n", + " let class5Model = viewer.addModel(pdb, \"pdb\");\n", + " class5Model.setStyle({}, {});\n", + " class5Model.setStyle(\n", + " {\"chain\": \"%s\", \"resi\": [%s]}, \n", + " {\"stick\": {\"color\": \"red\"}}\n", + " );\n", + "\n", + " \"\"\" % (\n", + " segment,\n", + " segment,\n", + " \", \".join(str(resi) for resi in class1_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in class2_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in class3_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in class4_score_residues),\n", + " segment,\n", + " \", \".join(str(resi) for resi in class5_score_residues)\n", + " )\n", + " \n", + " # Generate the full HTML content\n", + " html_content = f\"\"\"\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \"\"\"\n", + " \n", + " # Return the HTML content within an iframe safely encoded for special characters\n", + " return f''\n", + "\n", + "# Gradio UI\n", + "with gr.Blocks() as demo:\n", + " gr.Markdown(\"# Protein Binding Site Prediction\")\n", + " \n", + " # Mode selection\n", + " mode = gr.Radio(\n", + " choices=[\"PDB ID\", \"Upload File\"],\n", + " value=\"PDB ID\",\n", + " label=\"Input Mode\",\n", + " info=\"Choose whether to input a PDB ID or upload a PDB/CIF file.\"\n", + " )\n", + "\n", + " # Input components based on mode\n", + " pdb_input = gr.Textbox(value=\"4BDU\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", + " pdb_file = gr.File(label=\"Upload PDB/CIF File\", visible=False)\n", + " visualize_btn = gr.Button(\"Visualize Structure\")\n", + "\n", + " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n", + " {\n", + " \"model\": 0,\n", + " \"style\": \"cartoon\",\n", + " \"color\": \"whiteCarbon\",\n", + " \"residue_range\": \"\",\n", + " \"around\": 0,\n", + " \"byres\": False,\n", + " }\n", + " ])\n", + "\n", + " with gr.Row():\n", + " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n", + " prediction_btn = gr.Button(\"Predict Binding Site\")\n", + "\n", + " molecule_output = gr.HTML(label=\"Protein Structure\")\n", + " explanation_vis = gr.Markdown(\"\"\"\n", + " Residues with a score > 0.5 are considered binding sites and represented as sticks with the score dependent colorcoding:\n", + " - 0.5-0.6: blue \n", + " - 0.6–0.7: light blue \n", + " - 0.7–0.8: white\n", + " - 0.8–0.9: orange\n", + " - 0.9–1.0: red\n", + " \"\"\")\n", + " predictions_output = gr.Textbox(label=\"Visualize Prediction with PyMol\")\n", + " gr.Markdown(\"### Download:\\n- List of predicted binding site residues\\n- PDB with score in beta factor column\")\n", + " download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n", + " \n", + " def process_interface(mode, pdb_id, pdb_file, chain_id):\n", + " if mode == \"PDB ID\":\n", + " return process_pdb(pdb_id, chain_id)\n", + " elif mode == \"Upload File\":\n", + " _, ext = os.path.splitext(pdb_file.name)\n", + " file_path = os.path.join('./', f\"{_}{ext}\")\n", + " if ext == '.cif':\n", + " pdb_path = convert_cif_to_pdb(file_path)\n", + " else:\n", + " pdb_path= file_path\n", + " return process_pdb(pdb_path, chain_id)\n", + " else:\n", + " return \"Error: Invalid mode selected\", None, None\n", + "\n", + " def fetch_interface(mode, pdb_id, pdb_file):\n", + " if mode == \"PDB ID\":\n", + " return fetch_pdb(pdb_id)\n", + " elif mode == \"Upload File\":\n", + " _, ext = os.path.splitext(pdb_file.name)\n", + " file_path = os.path.join('./', f\"{_}{ext}\")\n", + " #print(ext)\n", + " if ext == '.cif':\n", + " pdb_path = convert_cif_to_pdb(file_path)\n", + " else:\n", + " pdb_path= file_path\n", + " #print(pdb_path)\n", + " return pdb_path\n", + " else:\n", + " return \"Error: Invalid mode selected\"\n", + "\n", + " def toggle_mode(selected_mode):\n", + " if selected_mode == \"PDB ID\":\n", + " return gr.update(visible=True), gr.update(visible=False)\n", + " else:\n", + " return gr.update(visible=False), gr.update(visible=True)\n", + "\n", + " mode.change(\n", + " toggle_mode,\n", + " inputs=[mode],\n", + " outputs=[pdb_input, pdb_file]\n", + " )\n", + "\n", + " prediction_btn.click(\n", + " process_interface, \n", + " inputs=[mode, pdb_input, pdb_file, segment_input], \n", + " outputs=[predictions_output, molecule_output, download_output]\n", + " )\n", + "\n", + " visualize_btn.click(\n", + " fetch_interface, \n", + " inputs=[mode, pdb_input, pdb_file], \n", + " outputs=molecule_output2\n", + " )\n", + "\n", + " gr.Markdown(\"## Examples\")\n", + " gr.Examples(\n", + " examples=[\n", + " [\"7RPZ\", \"A\"],\n", + " [\"2IWI\", \"B\"],\n", + " [\"2F6V\", \"A\"]\n", + " ],\n", + " inputs=[pdb_input, segment_input],\n", + " outputs=[predictions_output, molecule_output, download_output]\n", + " )\n", + "\n", + "demo.launch(share=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9496cf4f-9e5f-4b0b-bb0d-7aebbb748ae6", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (LLM)", + "language": "python", + "name": "llm" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}