{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "f3b7f6b0-6685-4a5c-9529-45e0ca905a3b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "* Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import gradio as gr\n", "import requests\n", "from Bio.PDB import PDBParser\n", "import numpy as np\n", "import os\n", "from gradio_molecule3d import Molecule3D\n", "\n", "def read_mol(pdb_path):\n", " \"\"\"Read PDB file and return its content as a string\"\"\"\n", " with open(pdb_path, 'r') as f:\n", " return f.read()\n", "\n", "def fetch_pdb(pdb_id):\n", " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n", " pdb_path = f'{pdb_id}.pdb'\n", " response = requests.get(pdb_url)\n", " if response.status_code == 200:\n", " with open(pdb_path, 'wb') as f:\n", " f.write(response.content)\n", " return pdb_path\n", " else:\n", " return None\n", "\n", "def process_pdb(pdb_id, segment):\n", " pdb_path = fetch_pdb(pdb_id)\n", " if not pdb_path:\n", " return \"Failed to fetch PDB file\", None, None\n", " \n", " parser = PDBParser(QUIET=1)\n", " structure = parser.get_structure('protein', pdb_path)\n", " \n", " try:\n", " chain = structure[0][segment]\n", " except KeyError:\n", " return \"Invalid Chain ID\", None, None\n", " \n", " # Comprehensive amino acid mapping\n", " aa_dict = {\n", " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n", " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n", " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n", " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n", " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n", " }\n", " \n", " # Exclude non-amino acid residues\n", " sequence = [\n", " residue for residue in chain \n", " if residue.get_resname().strip() in aa_dict\n", " ]\n", " \n", " random_scores = np.random.rand(len(sequence))\n", " result_str = \"\\n\".join(\n", " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n", " for res, score in 
zip(sequence, random_scores)\n", " )\n", " \n", " # Save the predictions to a file\n", " prediction_file = f\"{pdb_id}_predictions.txt\"\n", " with open(prediction_file, \"w\") as f:\n", " f.write(result_str)\n", " \n", " return result_str, molecule(pdb_path, random_scores, segment), prediction_file\n", "\n", "def molecule(input_pdb, scores=None, segment='A'):\n", " mol = read_mol(input_pdb) # Read PDB file content\n", " \n", " # Prepare high-scoring residues script if scores are provided\n", " high_score_script = \"\"\n", " if scores is not None:\n", " high_score_script = \"\"\"\n", " // Reset all styles first\n", " viewer.getModel(0).setStyle({}, {});\n", " \n", " // Show only the selected chain\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\"}, \n", " { cartoon: {colorscheme:\"whiteCarbon\"} }\n", " );\n", " \n", " // Highlight high-scoring residues only for the selected chain\n", " let highScoreResidues = [%s];\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n", " {\"stick\": {\"color\": \"red\"}}\n", " );\n", " \"\"\" % (segment, \n", " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n", " segment)\n", " \n", " html_content = f\"\"\"\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \"\"\"\n", " \n", " # Return the HTML content within an iframe safely encoded for special characters\n", " return f''\n", "\n", "reps = [\n", " {\n", " \"model\": 0,\n", " \"style\": \"cartoon\",\n", " \"color\": \"whiteCarbon\",\n", " \"residue_range\": \"\",\n", " \"around\": 0,\n", " \"byres\": False,\n", " }\n", " ]\n", "# Gradio UI\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n", " with gr.Row():\n", " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " visualize_btn = gr.Button(\"Visualize Structure\")\n", "\n", " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n", "\n", " with gr.Row():\n", " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n", " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n", "\n", " molecule_output = gr.HTML(label=\"Protein Structure\")\n", " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n", " download_output = gr.File(label=\"Download Predictions\")\n", " \n", " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n", " \n", " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n", " \n", " gr.Markdown(\"## Examples\")\n", " gr.Examples(\n", " examples=[\n", " [\"2IWI\", \"A\"],\n", " [\"7RPZ\", \"B\"],\n", " [\"3TJN\", \"C\"]\n", " ],\n", " inputs=[pdb_input, segment_input],\n", " outputs=[predictions_output, molecule_output, download_output]\n", " )\n", "\n", "demo.launch()" ] }, { "cell_type": "code", "execution_count": 6, "id": "28f8f28c-48d3-4e35-9766-3de9882179b5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "* Running on local URL: 
http://127.0.0.1:7864\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import gradio as gr\n", "import requests\n", "from Bio.PDB import PDBParser\n", "import numpy as np\n", "import os\n", "from gradio_molecule3d import Molecule3D\n", "\n", "def read_mol(pdb_path):\n", " \"\"\"Read PDB file and return its content as a string\"\"\"\n", " with open(pdb_path, 'r') as f:\n", " return f.read()\n", "\n", "def fetch_pdb(pdb_id):\n", " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n", " pdb_path = f'{pdb_id}.pdb'\n", " response = requests.get(pdb_url)\n", " if response.status_code == 200:\n", " with open(pdb_path, 'wb') as f:\n", " f.write(response.content)\n", " return pdb_path\n", " else:\n", " return None\n", "\n", "def process_pdb(pdb_id, segment):\n", " pdb_path = fetch_pdb(pdb_id)\n", " if not pdb_path:\n", " return \"Failed to fetch PDB file\", None, None\n", " \n", " parser = PDBParser(QUIET=1)\n", " structure = parser.get_structure('protein', pdb_path)\n", " \n", " try:\n", " chain = structure[0][segment]\n", " except KeyError:\n", " return \"Invalid Chain ID\", None, None\n", " \n", " # Comprehensive amino acid mapping\n", " aa_dict = {\n", " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n", " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n", " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n", " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n", " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n", " }\n", " \n", " # Exclude non-amino acid residues\n", " sequence = [\n", " residue for residue in chain \n", " if residue.get_resname().strip() in aa_dict\n", " ]\n", " \n", " random_scores = np.random.rand(len(sequence))\n", " result_str = \"\\n\".join(\n", " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n", " for res, score in 
zip(sequence, random_scores)\n", " )\n", " \n", " # Save the predictions to a file\n", " prediction_file = f\"{pdb_id}_predictions.txt\"\n", " with open(prediction_file, \"w\") as f:\n", " f.write(result_str)\n", " \n", " return result_str, molecule(pdb_path, random_scores, segment), prediction_file\n", "\n", "def molecule(input_pdb, scores=None, segment='A'):\n", " mol = read_mol(input_pdb) # Read PDB file content\n", " \n", " # Prepare high-scoring residues script if scores are provided\n", " high_score_script = \"\"\n", " if scores is not None:\n", " high_score_script = \"\"\"\n", " // Reset all styles first\n", " viewer.getModel(0).setStyle({}, {});\n", " \n", " // Show only the selected chain\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\"}, \n", " { cartoon: {colorscheme:\"whiteCarbon\"} }\n", " );\n", " \n", " // Highlight high-scoring residues only for the selected chain\n", " let highScoreResidues = [%s];\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n", " {\"stick\": {\"color\": \"red\"}}\n", " );\n", " \"\"\" % (segment, \n", " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n", " segment)\n", " \n", " html_content = f\"\"\"\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \"\"\"\n", " \n", " # Return the HTML content within an iframe safely encoded for special characters\n", " return f''\n", "\n", "reps = [\n", " {\n", " \"model\": 0,\n", " \"style\": \"cartoon\",\n", " \"color\": \"whiteCarbon\",\n", " \"residue_range\": \"\",\n", " \"around\": 0,\n", " \"byres\": False,\n", " }\n", " ]\n", "\n", "# Gradio UI\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n", " with gr.Row():\n", " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " visualize_btn = gr.Button(\"Visualize Structure\")\n", "\n", " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n", "\n", " with gr.Row():\n", " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n", " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n", "\n", " molecule_output = gr.HTML(label=\"Protein Structure\")\n", " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n", " download_output = gr.File(label=\"Download Predictions\")\n", " \n", " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n", " \n", " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n", " \n", " gr.Markdown(\"## Examples\")\n", " gr.Examples(\n", " examples=[\n", " [\"2IWI\", \"A\"],\n", " [\"7RPZ\", \"B\"],\n", " [\"3TJN\", \"C\"]\n", " ],\n", " inputs=[pdb_input, segment_input],\n", " outputs=[predictions_output, molecule_output, download_output]\n", " )\n", "\n", "demo.launch()" ] }, { "cell_type": "code", "execution_count": null, "id": "517a2fe7-419f-4d0b-a9ed-62a22c1c1284", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 1, "id": 
"d62be1b5-762e-4b69-aed4-e4ba2a44482f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "* Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import gradio as gr\n", "import requests\n", "from Bio.PDB import PDBParser\n", "import numpy as np\n", "import os\n", "from gradio_molecule3d import Molecule3D\n", "\n", "def read_mol(pdb_path):\n", " \"\"\"Read PDB file and return its content as a string\"\"\"\n", " with open(pdb_path, 'r') as f:\n", " return f.read()\n", "\n", "def fetch_pdb(pdb_id):\n", " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n", " pdb_path = f'{pdb_id}.pdb'\n", " response = requests.get(pdb_url)\n", " if response.status_code == 200:\n", " with open(pdb_path, 'wb') as f:\n", " f.write(response.content)\n", " return pdb_path\n", " else:\n", " return None\n", "\n", "def process_pdb(pdb_id, segment):\n", " pdb_path = fetch_pdb(pdb_id)\n", " if not pdb_path:\n", " return \"Failed to fetch PDB file\", None, None\n", " \n", " parser = PDBParser(QUIET=1)\n", " structure = parser.get_structure('protein', pdb_path)\n", " \n", " try:\n", " chain = structure[0][segment]\n", " except KeyError:\n", " return \"Invalid Chain ID\", None, None\n", " \n", " # Comprehensive amino acid mapping\n", " aa_dict = {\n", " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n", " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n", " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n", " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n", " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n", " }\n", " \n", " # Exclude non-amino acid residues\n", " sequence = [\n", " residue for residue in chain \n", " if residue.get_resname().strip() in aa_dict\n", " ]\n", " \n", " random_scores = np.random.rand(len(sequence))\n", " result_str = \"\\n\".join(\n", " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n", " for res, score in 
zip(sequence, random_scores)\n", " )\n", " \n", " # Save the predictions to a file\n", " prediction_file = f\"{pdb_id}_predictions.txt\"\n", " with open(prediction_file, \"w\") as f:\n", " f.write(result_str)\n", " \n", " return result_str, molecule(pdb_path, random_scores, segment), prediction_file\n", "\n", "def molecule(input_pdb, scores=None, segment='A'):\n", " mol = read_mol(input_pdb) # Read PDB file content\n", " \n", " # Prepare high-scoring residues script if scores are provided\n", " high_score_script = \"\"\n", " if scores is not None:\n", " high_score_script = \"\"\"\n", " // Reset all styles first\n", " viewer.getModel(0).setStyle({}, {});\n", " \n", " // Show only the selected chain\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\"}, \n", " { cartoon: {colorscheme:\"whiteCarbon\"} }\n", " );\n", " \n", " // Highlight high-scoring residues only for the selected chain\n", " let highScoreResidues = [%s];\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n", " {\"stick\": {\"color\": \"red\"}}\n", " );\n", "\n", " // Highlight high-scoring residues only for the selected chain\n", " let highScoreResidues2 = [%s];\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\", \"resi\": highScoreResidues2}, \n", " {\"stick\": {\"color\": \"orange\"}}\n", " );\n", " \"\"\" % (segment, \n", " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n", " segment,\n", " \", \".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),\n", " segment)\n", " \n", " html_content = f\"\"\"\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \"\"\"\n", " \n", " # Return the HTML content within an iframe safely encoded for special characters\n", " return f''\n", "\n", "reps = [\n", " {\n", " \"model\": 0,\n", " \"style\": \"cartoon\",\n", " \"color\": \"whiteCarbon\",\n", " \"residue_range\": \"\",\n", " \"around\": 0,\n", " \"byres\": False,\n", " }\n", " ]\n", "\n", "# Gradio UI\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n", " with gr.Row():\n", " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " visualize_btn = gr.Button(\"Visualize Structure\")\n", "\n", " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n", "\n", " with gr.Row():\n", " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n", " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n", "\n", " molecule_output = gr.HTML(label=\"Protein Structure\")\n", " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n", " download_output = gr.File(label=\"Download Predictions\")\n", " \n", " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n", " \n", " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n", " \n", " gr.Markdown(\"## Examples\")\n", " gr.Examples(\n", " examples=[\n", " [\"2IWI\", \"A\"],\n", " [\"7RPZ\", \"B\"],\n", " [\"3TJN\", \"C\"]\n", " ],\n", " inputs=[pdb_input, segment_input],\n", " outputs=[predictions_output, molecule_output, download_output]\n", " )\n", "\n", "demo.launch()" ] }, { "cell_type": "code", "execution_count": 6, "id": "30f35243-852f-4771-9a4b-5cdd198552b5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "* Running on local URL: 
http://127.0.0.1:7865\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import gradio as gr\n", "import requests\n", "from Bio.PDB import PDBParser\n", "import numpy as np\n", "import os\n", "from gradio_molecule3d import Molecule3D\n", "\n", "def read_mol(pdb_path):\n", " \"\"\"Read PDB file and return its content as a string\"\"\"\n", " with open(pdb_path, 'r') as f:\n", " return f.read()\n", "\n", "def fetch_pdb(pdb_id):\n", " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n", " pdb_path = f'{pdb_id}.pdb'\n", " response = requests.get(pdb_url)\n", " if response.status_code == 200:\n", " with open(pdb_path, 'wb') as f:\n", " f.write(response.content)\n", " return pdb_path\n", " else:\n", " return None\n", "\n", "def process_pdb(pdb_id, segment):\n", " pdb_path = fetch_pdb(pdb_id)\n", " if not pdb_path:\n", " return \"Failed to fetch PDB file\", None, None\n", " \n", " parser = PDBParser(QUIET=1)\n", " structure = parser.get_structure('protein', pdb_path)\n", " \n", " try:\n", " chain = structure[0][segment]\n", " except KeyError:\n", " return \"Invalid Chain ID\", None, None\n", " \n", " # Comprehensive amino acid mapping\n", " aa_dict = {\n", " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n", " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n", " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n", " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n", " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n", " }\n", " \n", " # Exclude non-amino acid residues and create a list of (resi, score) pairs\n", " sequence = [\n", " (res.id[1], res) for res in chain\n", " if res.get_resname().strip() in aa_dict\n", " ]\n", "\n", " random_scores = np.random.rand(len(sequence))\n", " \n", " # Zip residues with scores to track the residue ID and score\n", " residue_scores 
= [(resi, score) for (resi, _), score in zip(sequence, random_scores)]\n", " \n", " result_str = \"\\n\".join(\n", " f\"{aa_dict[chain[resi].get_resname()]} {resi} {score:.2f}\"\n", " for resi, score in residue_scores\n", " )\n", " \n", " # Save the predictions to a file\n", " prediction_file = f\"{pdb_id}_predictions.txt\"\n", " with open(prediction_file, \"w\") as f:\n", " f.write(result_str)\n", " \n", " return result_str, molecule(pdb_path, residue_scores, segment), prediction_file\n", "\n", "def molecule(input_pdb, residue_scores=None, segment='A'):\n", " mol = read_mol(input_pdb) # Read PDB file content\n", " \n", " # Prepare high-scoring residues script if scores are provided\n", " high_score_script = \"\"\n", " if residue_scores is not None:\n", " # Sort residues based on their scores\n", " high_score_residues = [resi for resi, score in residue_scores if score > 0.9]\n", " mid_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 0.9]\n", " \n", " high_score_script = \"\"\"\n", " // Reset all styles first\n", " viewer.getModel(0).setStyle({}, {});\n", " \n", " // Show only the selected chain\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\"}, \n", " { cartoon: {colorscheme:\"whiteCarbon\"} }\n", " );\n", " \n", " // Highlight high-scoring residues only for the selected chain\n", " let highScoreResidues = [%s];\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n", " {\"stick\": {\"color\": \"red\"}}\n", " );\n", "\n", " // Highlight medium-scoring residues only for the selected chain\n", " let midScoreResidues = [%s];\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\", \"resi\": midScoreResidues}, \n", " {\"stick\": {\"color\": \"orange\"}}\n", " );\n", " \"\"\" % (segment, \n", " \", \".join(str(resi) for resi in high_score_residues),\n", " segment,\n", " \", \".join(str(resi) for resi in mid_score_residues),\n", " segment)\n", " \n", " html_content = f\"\"\"\n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \"\"\"\n", " \n", " # Return the HTML content within an iframe safely encoded for special characters\n", " return f''\n", "\n", "reps = [\n", " {\n", " \"model\": 0,\n", " \"style\": \"cartoon\",\n", " \"color\": \"whiteCarbon\",\n", " \"residue_range\": \"\",\n", " \"around\": 0,\n", " \"byres\": False,\n", " }\n", " ]\n", "\n", "# Gradio UI\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n", " with gr.Row():\n", " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " visualize_btn = gr.Button(\"Visualize Structure\")\n", "\n", " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n", "\n", " with gr.Row():\n", " #pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n", " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n", "\n", " molecule_output = gr.HTML(label=\"Protein Structure\")\n", " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n", " download_output = gr.File(label=\"Download Predictions\")\n", " \n", " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n", " \n", " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n", " \n", " gr.Markdown(\"## Examples\")\n", " gr.Examples(\n", " examples=[\n", " [\"7RPZ\", \"A\"],\n", " [\"2IWI\", \"B\"],\n", " [\"2F6V\", \"A\"]\n", " ],\n", " inputs=[pdb_input, segment_input],\n", " outputs=[predictions_output, molecule_output, download_output]\n", " )\n", "\n", "demo.launch()" ] }, { "cell_type": "code", "execution_count": null, "id": "6f17feec-0347-4f9d-acd4-ae681c3ed425", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": 
"63201f38-adde-4b12-a8d3-f23474d045cf", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "5ccbf398-5ef2-4955-98db-99f904f8daa4", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "4c61bac4-4f2e-4f4a-aa1f-30dca209747c", "metadata": {}, "outputs": [], "source": [
    "import gradio as gr\n",
    "import requests\n",
    "from Bio.PDB import PDBParser\n",
    "import numpy as np\n",
    "import os\n",
    "from gradio_molecule3d import Molecule3D\n",
    "\n",
    "\n",
    "from model_loader import load_model\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import re\n",
    "import pandas as pd\n",
    "import copy\n",
    "\n",
    "import transformers, datasets\n",
    "from transformers import AutoTokenizer\n",
    "from transformers import DataCollatorForTokenClassification\n",
    "\n",
    "from datasets import Dataset\n",
    "\n",
    "from scipy.special import expit\n",
    "\n",
    "# Load model and move to device\n",
    "checkpoint = 'ThorbenF/prot_t5_xl_uniref50'\n",
    "max_length = 1500\n",
    "model, tokenizer = load_model(checkpoint, max_length)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model.to(device)\n",
    "model.eval()\n",
    "\n",
    "def normalize_scores(scores):\n",
    "    min_score = np.min(scores)\n",
    "    max_score = np.max(scores)\n",
    "    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
    "    \n",
    "def read_mol(pdb_path):\n",
    "    \"\"\"Read PDB file and return its content as a string\"\"\"\n",
    "    with open(pdb_path, 'r') as f:\n",
    "        return f.read()\n",
    "\n",
    "def fetch_pdb(pdb_id):\n",
    "    \"\"\"Download `{pdb_id}.pdb` from RCSB into the working directory.\n",
    "\n",
    "    Returns the local file path on HTTP 200, or None when the ID contains\n",
    "    anything other than letters, digits or underscores, the request raises,\n",
    "    or the server answers with any other status.\n",
    "    \"\"\"\n",
    "    # pdb_id is raw text from a Gradio textbox (this demo is launched with\n",
    "    # share=True), so refuse anything that could act as a path or URL fragment\n",
    "    # before it is used for both the download URL and the local filename.\n",
    "    if not isinstance(pdb_id, str) or not re.fullmatch(r\"[A-Za-z0-9_]+\", pdb_id):\n",
    "        return None\n",
    "    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
    "    pdb_path = f'{pdb_id}.pdb'\n",
    "    try:\n",
    "        # Without a timeout a stalled connection blocks the UI callback indefinitely.\n",
    "        response = requests.get(pdb_url, timeout=30)\n",
    "    except requests.RequestException:\n",
    "        return None\n",
    "    if response.status_code == 200:\n",
    "        with open(pdb_path, 'wb') as f:\n",
    "            f.write(response.content)\n",
    "        return pdb_path\n",
    "    else:\n",
    "        return None\n",
"\n",
    "def process_pdb(pdb_id, segment):\n",
    "    \"\"\"Predict per-residue binding scores for one chain of a PDB entry.\n",
    "\n",
    "    Returns (prediction text, viewer HTML, prediction file path), or an error\n",
    "    message followed by two Nones when the entry cannot be fetched, the chain\n",
    "    ID is unknown, or the chain contains no residue listed in aa_dict.\n",
    "    \"\"\"\n",
    "    pdb_path = fetch_pdb(pdb_id)\n",
    "    if not pdb_path:\n",
    "        return \"Failed to fetch PDB file\", None, None\n",
    "    \n",
    "    parser = PDBParser(QUIET=1)\n",
    "    structure = parser.get_structure('protein', pdb_path)\n",
    "    \n",
    "    try:\n",
    "        chain = structure[0][segment]\n",
    "    except KeyError:\n",
    "        return \"Invalid Chain ID\", None, None\n",
    "    \n",
    "    \n",
    "    aa_dict = {\n",
    "        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
    "        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
    "        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
    "        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
    "        'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
    "    }\n",
    "    \n",
    "    # Keep only residues with a one-letter code in aa_dict (waters, ligands and\n",
    "    # other hetero groups are dropped). `residues` holds (residue number, Residue)\n",
    "    # pairs in chain order; every per-residue array below is aligned with it.\n",
    "    residues = [\n",
    "        (res.id[1], res) for res in chain\n",
    "        if res.get_resname().strip() in aa_dict\n",
    "    ]\n",
    "    if not residues:\n",
    "        return \"No amino acid residues found in chain\", None, None\n",
    "    sequence = \"\".join(aa_dict[res.get_resname().strip()] for _, res in residues)\n",
    "    \n",
    "    # Prepare input for model prediction (residues are passed space-separated)\n",
    "    input_ids = tokenizer(\" \".join(sequence), return_tensors=\"pt\").input_ids.to(device)\n",
    "    with torch.no_grad():\n",
    "        logits = model(input_ids).logits.detach().cpu().numpy()\n",
    "    # Drop the batch axis explicitly; squeeze() would also drop the token\n",
    "    # axis whenever it has length 1 and break the column indexing below.\n",
    "    logits = logits.reshape(-1, logits.shape[-1])\n",
    "\n",
    "    # Score only the first len(sequence) positions: trailing special tokens the\n",
    "    # tokenizer may append (e.g. T5's end-of-sequence token -- assumed, verify for\n",
    "    # this checkpoint) must not shift the min/max used for normalization.\n",
    "    n_res = len(sequence)\n",
    "    scores = expit(logits[:n_res, 1] - logits[:n_res, 0])\n",
    "    normalized_scores = normalize_scores(scores)\n",
    "\n",
    "    # Pair each residue number with its own score\n",
    "    residue_scores = [(resi, score) for (resi, _), score in zip(residues, normalized_scores)]\n",
    "    \n",
    "    # Walk the filtered residue list itself (not the raw chain) so the residue,\n",
    "    # its one-letter code and its score always refer to the same position.\n",
    "    result_str = \"\\n\".join(\n",
    "        f\"{res.get_resname()} {resi} {aa_dict[res.get_resname().strip()]} {score:.2f}\"\n",
    "        for (resi, res), score in zip(residues, normalized_scores)\n",
    "    )\n",
    "    \n",
    "    # Save the predictions to a file\n",
    "    prediction_file = f\"{pdb_id}_predictions.txt\"\n",
    "    with open(prediction_file, \"w\") 
as f:\n", " f.write(result_str)\n", " \n", " return result_str, molecule(pdb_path, residue_scores, segment), prediction_file\n", "\n", "def molecule(input_pdb, residue_scores=None, segment='A'):\n", " mol = read_mol(input_pdb) # Read PDB file content\n", " \n", " # Prepare high-scoring residues script if scores are provided\n", " high_score_script = \"\"\n", " if residue_scores is not None:\n", " # Sort residues based on their scores\n", " high_score_residues = [resi for resi, score in residue_scores if score > 0.75]\n", " mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]\n", " \n", " high_score_script = \"\"\"\n", " // Reset all styles first\n", " viewer.getModel(0).setStyle({}, {});\n", " \n", " // Show only the selected chain\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\"}, \n", " { cartoon: {colorscheme:\"whiteCarbon\"} }\n", " );\n", " \n", " // Highlight high-scoring residues only for the selected chain\n", " let highScoreResidues = [%s];\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n", " {\"stick\": {\"color\": \"red\"}}\n", " );\n", "\n", " // Highlight medium-scoring residues only for the selected chain\n", " let midScoreResidues = [%s];\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\", \"resi\": midScoreResidues}, \n", " {\"stick\": {\"color\": \"orange\"}}\n", " );\n", " \"\"\" % (segment, \n", " \", \".join(str(resi) for resi in high_score_residues),\n", " segment,\n", " \", \".join(str(resi) for resi in mid_score_residues),\n", " segment)\n", " \n", " html_content = f\"\"\"\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \"\"\"\n", " \n", " # Return the HTML content within an iframe safely encoded for special characters\n", " return f''\n", "\n", "reps = [\n", " {\n", " \"model\": 0,\n", " \"style\": \"cartoon\",\n", " \"color\": \"whiteCarbon\",\n", " \"residue_range\": \"\",\n", " \"around\": 0,\n", " \"byres\": False,\n", " }\n", " ]\n", "\n", "# Gradio UI\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n", " with gr.Row():\n", " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " visualize_btn = gr.Button(\"Visualize Structure\")\n", "\n", " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n", "\n", " with gr.Row():\n", " #pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n", " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n", "\n", " molecule_output = gr.HTML(label=\"Protein Structure\")\n", " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n", " download_output = gr.File(label=\"Download Predictions\")\n", " \n", " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n", " \n", " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n", " \n", " gr.Markdown(\"## Examples\")\n", " gr.Examples(\n", " examples=[\n", " [\"7RPZ\", \"A\"],\n", " [\"2IWI\", \"B\"],\n", " [\"2F6V\", \"A\"]\n", " ],\n", " inputs=[pdb_input, segment_input],\n", " outputs=[predictions_output, molecule_output, download_output]\n", " )\n", "\n", "demo.launch(share=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "b61d06ec-a4ee-4f65-925f-d2688730416a", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, 
"id": "4d67d69f-1f53-4bcc-8905-8d29384c4e20", "metadata": {}, "outputs": [], "source": [
    "import gradio as gr\n",
    "import requests\n",
    "from Bio.PDB import PDBParser\n",
    "import numpy as np\n",
    "import os\n",
    "from gradio_molecule3d import Molecule3D\n",
    "\n",
    "\n",
    "from model_loader import load_model\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import re\n",
    "import pandas as pd\n",
    "import copy\n",
    "\n",
    "import transformers, datasets\n",
    "from transformers import AutoTokenizer\n",
    "from transformers import DataCollatorForTokenClassification\n",
    "\n",
    "from datasets import Dataset\n",
    "\n",
    "from scipy.special import expit\n",
    "\n",
    "# Load model and move to device\n",
    "checkpoint = 'ThorbenF/prot_t5_xl_uniref50'\n",
    "max_length = 1500\n",
    "model, tokenizer = load_model(checkpoint, max_length)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model.to(device)\n",
    "model.eval()\n",
    "\n",
    "def normalize_scores(scores):\n",
    "    min_score = np.min(scores)\n",
    "    max_score = np.max(scores)\n",
    "    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
    "    \n",
    "def read_mol(pdb_path):\n",
    "    \"\"\"Read PDB file and return its content as a string\"\"\"\n",
    "    with open(pdb_path, 'r') as f:\n",
    "        return f.read()\n",
    "\n",
    "def fetch_pdb(pdb_id):\n",
    "    \"\"\"Download `{pdb_id}.pdb` from RCSB into the working directory.\n",
    "\n",
    "    Returns the local file path on HTTP 200, or None when the ID contains\n",
    "    anything other than letters, digits or underscores, the request raises,\n",
    "    or the server answers with any other status.\n",
    "    \"\"\"\n",
    "    # pdb_id is raw text typed into a Gradio textbox, so refuse anything that\n",
    "    # could act as a path or URL fragment before it is used for both the\n",
    "    # download URL and the local filename.\n",
    "    if not isinstance(pdb_id, str) or not re.fullmatch(r\"[A-Za-z0-9_]+\", pdb_id):\n",
    "        return None\n",
    "    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
    "    pdb_path = f'{pdb_id}.pdb'\n",
    "    try:\n",
    "        # Without a timeout a stalled connection blocks the UI callback indefinitely.\n",
    "        response = requests.get(pdb_url, timeout=30)\n",
    "    except requests.RequestException:\n",
    "        return None\n",
    "    if response.status_code == 200:\n",
    "        with open(pdb_path, 'wb') as f:\n",
    "            f.write(response.content)\n",
    "        return pdb_path\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "def process_pdb(pdb_id, segment):\n",
    "    pdb_path = fetch_pdb(pdb_id)\n",
    "    if not pdb_path:\n",
    "        return \"Failed to fetch PDB file\", None, None\n",
    "    \n",
    "    parser = PDBParser(QUIET=1)\n",
    "    structure = parser.get_structure('protein', pdb_path)\n",
    "    \n",
    "    try:\n",
    "        
chain = structure[0][segment]\n", " except KeyError:\n", " return \"Invalid Chain ID\", None, None\n", " \n", " \n", " aa_dict = {\n", " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n", " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n", " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n", " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n", " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n", " }\n", " \n", " # Exclude non-amino acid residues\n", " sequence = \"\".join(\n", " aa_dict[residue.get_resname().strip()] \n", " for residue in chain \n", " if residue.get_resname().strip() in aa_dict\n", " )\n", " sequence2 = [\n", " (res.id[1], res) for res in chain\n", " if res.get_resname().strip() in aa_dict\n", " ]\n", " \n", " # Prepare input for model prediction\n", " input_ids = tokenizer(\" \".join(sequence), return_tensors=\"pt\").input_ids.to(device)\n", " with torch.no_grad():\n", " outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()\n", "\n", " # Calculate scores and normalize them\n", " scores = expit(outputs[:, 1] - outputs[:, 0])\n", " normalized_scores = normalize_scores(scores)\n", "\n", " # Zip residues with scores to track the residue ID and score\n", " residue_scores = [(resi, score) for (resi, _), score in zip(sequence2, normalized_scores)]\n", " \n", " result_str = \"\\n\".join([\n", " f\"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n", " for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict\n", " ])\n", " \n", " # Save the predictions to a file\n", " prediction_file = f\"{pdb_id}_predictions.txt\"\n", " with open(prediction_file, \"w\") as f:\n", " f.write(result_str)\n", " \n", " return result_str, molecule(pdb_path, residue_scores, segment), prediction_file\n", "\n", "def molecule(input_pdb, residue_scores=None, segment='A'):\n", " mol = read_mol(input_pdb) # Read PDB file content\n", " \n", " # Prepare 
high-scoring residues script if scores are provided\n", " high_score_script = \"\"\n", " if residue_scores is not None:\n", " # Sort residues based on their scores\n", " high_score_residues = [resi for resi, score in residue_scores if score > 0.75]\n", " mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]\n", " \n", " high_score_script = \"\"\"\n", " // Reset all styles first\n", " viewer.getModel(0).setStyle({}, {});\n", " \n", " // Show only the selected chain\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\"}, \n", " { cartoon: {colorscheme:\"whiteCarbon\"} }\n", " );\n", " \n", " // Highlight high-scoring residues only for the selected chain\n", " let highScoreResidues = [%s];\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n", " {\"stick\": {\"color\": \"red\"}}\n", " );\n", "\n", " // Highlight medium-scoring residues only for the selected chain\n", " let midScoreResidues = [%s];\n", " viewer.getModel(0).setStyle(\n", " {\"chain\": \"%s\", \"resi\": midScoreResidues}, \n", " {\"stick\": {\"color\": \"orange\"}}\n", " );\n", " \"\"\" % (segment, \n", " \", \".join(str(resi) for resi in high_score_residues),\n", " segment,\n", " \", \".join(str(resi) for resi in mid_score_residues),\n", " segment)\n", " \n", " html_content = f\"\"\"\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \"\"\"\n", " \n", " # Return the HTML content within an iframe safely encoded for special characters\n", " return f''\n", "\n", "reps = [\n", " {\n", " \"model\": 0,\n", " \"style\": \"cartoon\",\n", " \"color\": \"whiteCarbon\",\n", " \"residue_range\": \"\",\n", " \"around\": 0,\n", " \"byres\": False,\n", " }\n", " ]\n", "\n", "# Gradio UI\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"# Protein Binding Site Prediction\")\n", " with gr.Row():\n", " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " visualize_btn = gr.Button(\"Visualize Structure\")\n", "\n", " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n", "\n", " with gr.Row():\n", " #pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n", " prediction_btn = gr.Button(\"Predict Binding Site\")\n", "\n", " molecule_output = gr.HTML(label=\"Protein Structure\")\n", " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n", " download_output = gr.File(label=\"Download Predictions\")\n", " \n", " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n", " \n", " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n", " \n", " gr.Markdown(\"## Examples\")\n", " gr.Examples(\n", " examples=[\n", " [\"7RPZ\", \"A\"],\n", " [\"2IWI\", \"B\"],\n", " [\"2F6V\", \"A\"]\n", " ],\n", " inputs=[pdb_input, segment_input],\n", " outputs=[predictions_output, molecule_output, download_output]\n", " )\n", "\n", "demo.launch(share=True)" ] } ], "metadata": { "kernelspec": { "display_name": "Python (LLM)", "language": "python", "name": "llm" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 5 }