{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2c614d31-f96a-4164-8293-1cec9b0b2cd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Running on local URL:  http://127.0.0.1:7870\n",
      "* Running on public URL: https://c83ac6a1bebcdb4528.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://c83ac6a1bebcdb4528.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datetime import datetime\n",
    "import gradio as gr\n",
    "import requests\n",
    "from Bio.PDB import PDBParser, MMCIFParser, PDBIO\n",
    "from Bio.PDB.Polypeptide import is_aa\n",
    "from Bio.SeqUtils import seq1\n",
    "from typing import Optional, Tuple\n",
    "import numpy as np\n",
    "import os\n",
    "from gradio_molecule3d import Molecule3D\n",
    "\n",
    "import re\n",
    "import pandas as pd\n",
    "import copy\n",
    "\n",
    "from scipy.special import expit\n",
    "\n",
    "\n",
    "\n",
    "def normalize_scores(scores):\n",
    "    \"\"\"\n",
    "    Min-max scale `scores` so the smallest value maps to 0 and the largest to 1.\n",
    "\n",
    "    Accepts any array-like input. Empty input comes back as an empty array,\n",
    "    and input whose values are all equal is returned unscaled because there\n",
    "    is no range to divide by.\n",
    "    \"\"\"\n",
    "    scores = np.asarray(scores)\n",
    "    if scores.size == 0:\n",
    "        # np.min / np.max raise ValueError on an empty array\n",
    "        return scores\n",
    "    min_score = np.min(scores)\n",
    "    max_score = np.max(scores)\n",
    "    if max_score > min_score:\n",
    "        return (scores - min_score) / (max_score - min_score)\n",
    "    return scores\n",
    "\n",
    "def read_mol(pdb_path):\n",
    "    \"\"\"Return the entire text of the structure file at `pdb_path`.\"\"\"\n",
    "    with open(pdb_path, mode='r') as handle:\n",
    "        contents = handle.read()\n",
    "    return contents\n",
    "\n",
    "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Return the path of the structure file for `pdb_id`, or None if unavailable.\n",
    "\n",
    "    All the work is delegated to `download_structure`, which tries the CIF\n",
    "    format before PDB and reuses a matching file already in `output_dir`.\n",
    "    \"\"\"\n",
    "    return download_structure(pdb_id, output_dir) or None\n",
    "\n",
    "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Attempt to download the structure file in CIF or PDB format.\n",
    "\n",
    "    The ID is interpolated into both a local file name and a download URL, so\n",
    "    it is validated first: anything other than letters, digits, underscores\n",
    "    and hyphens (for example path separators or \"..\") is rejected.\n",
    "\n",
    "    Args:\n",
    "        pdb_id: Structure identifier; surrounding whitespace is ignored.\n",
    "        output_dir: Directory used to look up and store structure files.\n",
    "\n",
    "    Returns:\n",
    "        Path to an existing or freshly downloaded file, or None if the ID is\n",
    "        invalid or every download attempt fails.\n",
    "    \"\"\"\n",
    "    pdb_id = pdb_id.strip() if isinstance(pdb_id, str) else \"\"\n",
    "    if not re.fullmatch(r\"[A-Za-z0-9_-]+\", pdb_id):\n",
    "        print(f\"Invalid PDB ID: {pdb_id!r}\")\n",
    "        return None\n",
    "\n",
    "    for ext in ['.cif', '.pdb']:\n",
    "        file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n",
    "        if os.path.exists(file_path):\n",
    "            # Reuse a previously downloaded file instead of hitting the network\n",
    "            return file_path\n",
    "        url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n",
    "        try:\n",
    "            response = requests.get(url, timeout=10)\n",
    "            if response.status_code == 200:\n",
    "                with open(file_path, 'wb') as f:\n",
    "                    f.write(response.content)\n",
    "                return file_path\n",
    "        except Exception as e:\n",
    "            # Best effort: report the failure and fall through to the next format\n",
    "            print(f\"Download error for {pdb_id}{ext}: {e}\")\n",
    "    return None\n",
    "\n",
    "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n",
    "    \"\"\"\n",
    "    Convert a CIF file to PDB format using BioPython and return the PDB file path.\n",
    "\n",
    "    The output is written to `output_dir` under the input's base name with\n",
    "    its final extension swapped for `.pdb`.\n",
    "    \"\"\"\n",
    "    # splitext only touches the final extension; str.replace would also rewrite\n",
    "    # any earlier \".cif\" inside the file name.\n",
    "    stem = os.path.splitext(os.path.basename(cif_path))[0]\n",
    "    pdb_path = os.path.join(output_dir, f\"{stem}.pdb\")\n",
    "    parser = MMCIFParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', cif_path)\n",
    "    io = PDBIO()\n",
    "    io.set_structure(structure)\n",
    "    io.save(pdb_path)\n",
    "    return pdb_path\n",
    "\n",
    "def fetch_pdb(pdb_id):\n",
    "    \"\"\"\n",
    "    Return the path of a PDB-format file for `pdb_id`, or None if nothing was fetched.\n",
    "\n",
    "    A `.cif` result from `fetch_structure` is converted with `convert_cif_to_pdb`.\n",
    "    \"\"\"\n",
    "    structure_file = fetch_structure(pdb_id)\n",
    "    if not structure_file:\n",
    "        return None\n",
    "    if os.path.splitext(structure_file)[1] == '.cif':\n",
    "        return convert_cif_to_pdb(structure_file)\n",
    "    return structure_file\n",
    "\n",
    "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list) -> str:\n",
    "    \"\"\"\n",
    "    Write a copy of `input_pdb` that keeps only chain `chain_id` and stores\n",
    "    each residue's prediction score in the B-factor column.\n",
    "\n",
    "    `residue_scores` is a sequence of (residue number, score) pairs. Returns\n",
    "    the path of the new `<input stem>_<chain_id>_scored.pdb` file.\n",
    "    \"\"\"\n",
    "    source = PDBParser(QUIET=True).get_structure('protein', input_pdb)\n",
    "    trimmed = source.copy()\n",
    "    score_by_resnum = dict(residue_scores)\n",
    "\n",
    "    for model in trimmed:\n",
    "        # Collect the unwanted chain ids first so the model is not modified while iterating it\n",
    "        unwanted_ids = [chain.id for chain in model if chain.id != chain_id]\n",
    "        for unwanted in unwanted_ids:\n",
    "            model.detach_child(unwanted)\n",
    "\n",
    "        # Whatever chain remains gets its scores written into the B-factor field\n",
    "        for chain in model:\n",
    "            for residue in chain:\n",
    "                resnum = residue.id[1]\n",
    "                if resnum not in score_by_resnum:\n",
    "                    continue\n",
    "                for atom in residue:\n",
    "                    atom.bfactor = score_by_resnum[resnum]\n",
    "\n",
    "    stem, _ = os.path.splitext(input_pdb)\n",
    "    output_pdb = f\"{stem}_{chain_id}_scored.pdb\"\n",
    "    writer = PDBIO()\n",
    "    writer.set_structure(trimmed)\n",
    "    writer.save(output_pdb)\n",
    "    return output_pdb\n",
    "\n",
    "def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):\n",
    "    \"\"\"\n",
    "    Calculate the geometric center of high-scoring residues.\n",
    "\n",
    "    Averages the CA-atom coordinates of the residues in chain `chain_id`\n",
    "    whose residue number appears in `high_score_residues`. Every model in the\n",
    "    file contributes to the average.\n",
    "\n",
    "    Returns:\n",
    "        The mean coordinate as a NumPy array, or None if no matching CA atom\n",
    "        was found.\n",
    "    \"\"\"\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', pdb_path)\n",
    "\n",
    "    # A set gives constant-time membership tests; the list was rescanned for every residue\n",
    "    wanted = set(high_score_residues)\n",
    "\n",
    "    # Use the alpha carbon as the representative atom of each residue\n",
    "    coords = [\n",
    "        residue['CA'].coord\n",
    "        for model in structure\n",
    "        for chain in model\n",
    "        if chain.id == chain_id\n",
    "        for residue in chain\n",
    "        if residue.id[1] in wanted and 'CA' in residue\n",
    "    ]\n",
    "\n",
    "    if coords:\n",
    "        return np.mean(coords, axis=0)\n",
    "    return None\n",
    "\n",
    "\n",
    "\n",
    "def process_pdb(pdb_id_or_file, segment):\n",
    "    \"\"\"\n",
    "    Score every amino-acid residue of one chain and build the outputs for the UI.\n",
    "\n",
    "    Args:\n",
    "        pdb_id_or_file: A path ending in `.pdb` (used as-is) or a PDB ID that\n",
    "            is fetched through `fetch_pdb`.\n",
    "        segment: ID of the chain to analyse.\n",
    "\n",
    "    Returns:\n",
    "        A 3-tuple `(text, html, files)`: the per-residue report followed by\n",
    "        PyMOL command suggestions, the HTML produced by `molecule`, and a list\n",
    "        holding the prediction text file and the scored PDB file. On failure\n",
    "        the tuple is `(error message, None, None)`.\n",
    "\n",
    "    Side effects:\n",
    "        Writes `<pdb_id>_predictions.txt` to the working directory and a\n",
    "        scored PDB file via `create_chain_specific_pdb`.\n",
    "    \"\"\"\n",
    "    # Determine if input is a PDB ID or file path\n",
    "    if pdb_id_or_file.endswith('.pdb'):\n",
    "        pdb_path = pdb_id_or_file\n",
    "        pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n",
    "    else:\n",
    "        pdb_id = pdb_id_or_file\n",
    "        pdb_path = fetch_pdb(pdb_id)\n",
    "\n",
    "    if not pdb_path:\n",
    "        return \"Failed to fetch PDB file\", None, None\n",
    "\n",
    "    # Determine the file format and choose the appropriate parser\n",
    "    _, ext = os.path.splitext(pdb_path)\n",
    "    parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n",
    "\n",
    "    try:\n",
    "        # Parse the structure file\n",
    "        structure = parser.get_structure('protein', pdb_path)\n",
    "    except Exception as e:\n",
    "        return f\"Error parsing structure file: {e}\", None, None\n",
    "\n",
    "    # Extract the specified chain from the first model\n",
    "    try:\n",
    "        chain = structure[0][segment]\n",
    "    except KeyError:\n",
    "        return \"Invalid Chain ID\", None, None\n",
    "\n",
    "    protein_residues = [res for res in chain if is_aa(res)]\n",
    "    if not protein_residues:\n",
    "        # Nothing to score; stop before normalize_scores is handed an empty array\n",
    "        return f\"No amino acid residues found in chain {segment}\", None, None\n",
    "\n",
    "    sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n",
    "    sequence_id = [res.id[1] for res in protein_residues]\n",
    "\n",
    "    # NOTE(review): the scores are random numbers, not model output - presumably a\n",
    "    # placeholder for a real predictor; confirm before relying on the results.\n",
    "    scores = np.random.rand(len(sequence))\n",
    "    normalized_scores = normalize_scores(scores)\n",
    "\n",
    "    # Zip residues with scores to track the residue ID and score\n",
    "    residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n",
    "\n",
    "    # Residues above this cut-off are reported as the predicted site\n",
    "    high_score_residues = [resi for resi, score in residue_scores if score > 0.75]\n",
    "\n",
    "    # The center itself is only used to decide whether the spheres command is emitted\n",
    "    geo_center = calculate_geometric_center(pdb_path, high_score_residues, segment)\n",
    "    pymol_selection = f\"select high_score_residues, resi {'+'.join(map(str, high_score_residues))} and chain {segment}\"\n",
    "    pymol_center_cmd = f\"show spheres, resi {'+'.join(map(str, high_score_residues))} and chain {segment}\" if geo_center is not None else \"\"\n",
    "\n",
    "    # Generate the result string\n",
    "    current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
    "    result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n",
    "    result_str += \"Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\\n\\n\"\n",
    "    result_str += \"\\n\".join([\n",
    "        f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\"\n",
    "        for i, res in enumerate(protein_residues)])\n",
    "\n",
    "    # Create prediction and scored PDB files\n",
    "    prediction_file = f\"{pdb_id}_predictions.txt\"\n",
    "    with open(prediction_file, \"w\") as f:\n",
    "        f.write(result_str)\n",
    "\n",
    "    # Create chain-specific PDB with scores in B-factor\n",
    "    scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores)\n",
    "\n",
    "    # Molecule visualization with updated script\n",
    "    mol_vis = molecule(pdb_path, residue_scores, segment)\n",
    "\n",
    "    # Construct PyMOL command suggestions\n",
    "    pymol_commands = f\"\"\"\n",
    "PyMOL Visualization Commands:\n",
    "1. Load PDB: load {os.path.abspath(pdb_path)}\n",
    "2. Select high-scoring residues: {pymol_selection}\n",
    "3. Highlight high-scoring residues: show sticks, high_score_residues\n",
    "{pymol_center_cmd}\n",
    "\"\"\"\n",
    "\n",
    "    return result_str + \"\\n\\n\" + pymol_commands, mol_vis, [prediction_file, scored_pdb]\n",
    "\n",
    "\n",
    "def molecule(input_pdb, residue_scores=None, segment='A'):\n",
    "    \"\"\"\n",
    "    Build an iframe that renders the structure with 3Dmol.js.\n",
    "\n",
    "    Args:\n",
    "        input_pdb: Path of the PDB file to display.\n",
    "        residue_scores: Optional iterable of (residue number, score) pairs.\n",
    "            Residues scoring above 0.75 are drawn as red sticks and those in\n",
    "            (0.5, 0.75] as orange sticks, over a white cartoon of the chain.\n",
    "        segment: ID of the chain the styling is applied to.\n",
    "\n",
    "    Returns:\n",
    "        An HTML string holding an iframe whose `srcdoc` is the viewer page.\n",
    "    \"\"\"\n",
    "    # Function-scope imports: only this function needs these stdlib modules\n",
    "    import html\n",
    "    import json\n",
    "\n",
    "    def js_literal(value):\n",
    "        \"\"\"Serialize `value` as a JavaScript literal that is safe inside a script element.\"\"\"\n",
    "        # JSON text is valid JavaScript. Escaping \"<\" stops sequences such as\n",
    "        # \"</script>\" or \"<!--\" in the data from being seen by the HTML parser.\n",
    "        return json.dumps(value).replace(\"<\", \"\\\\u003c\")\n",
    "\n",
    "    mol = read_mol(input_pdb)  # Read PDB file content\n",
    "\n",
    "    # Prepare high-scoring residues script if scores are provided\n",
    "    high_score_script = \"\"\n",
    "    if residue_scores is not None:\n",
    "        # Filter residues based on their scores\n",
    "        high_score_residues = [resi for resi, score in residue_scores if score > 0.75]\n",
    "        mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]\n",
    "\n",
    "        chain_js = js_literal(segment)\n",
    "        high_score_script = \"\"\"\n",
    "        // Load the original model and apply white cartoon style\n",
    "        let chainModel = viewer.addModel(pdb, \"pdb\");\n",
    "        chainModel.setStyle({}, {});\n",
    "        chainModel.setStyle(\n",
    "            {\"chain\": %s}, \n",
    "            {\"cartoon\": {\"color\": \"white\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let highScoreModel = viewer.addModel(pdb, \"pdb\");\n",
    "        highScoreModel.setStyle({}, {});\n",
    "        highScoreModel.setStyle(\n",
    "            {\"chain\": %s, \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"red\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for medium-scoring residues and apply orange sticks style\n",
    "        let midScoreModel = viewer.addModel(pdb, \"pdb\");\n",
    "        midScoreModel.setStyle({}, {});\n",
    "        midScoreModel.setStyle(\n",
    "            {\"chain\": %s, \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"orange\"}}\n",
    "        );\n",
    "        \"\"\" % (\n",
    "            chain_js,\n",
    "            chain_js,\n",
    "            \", \".join(str(resi) for resi in high_score_residues),\n",
    "            chain_js,\n",
    "            \", \".join(str(resi) for resi in mid_score_residues)\n",
    "        )\n",
    "\n",
    "    # File content is embedded as an escaped string literal, never as raw script text\n",
    "    pdb_js = js_literal(mol)\n",
    "\n",
    "    # Generate the full HTML content\n",
    "    html_content = f\"\"\"\n",
    "    <!DOCTYPE html>\n",
    "    <html>\n",
    "    <head>    \n",
    "        <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
    "        <style>\n",
    "        .mol-container {{\n",
    "            width: 100%;\n",
    "            height: 700px;\n",
    "            position: relative;\n",
    "        }}\n",
    "        </style>\n",
    "        <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
    "        <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
    "    </head>\n",
    "    <body>\n",
    "        <div id=\"container\" class=\"mol-container\"></div>\n",
    "        <script>\n",
    "            let pdb = {pdb_js}; // PDB text as an escaped JavaScript string literal\n",
    "            $(document).ready(function () {{\n",
    "                let element = $(\"#container\");\n",
    "                let config = {{ backgroundColor: \"white\" }};\n",
    "                let viewer = $3Dmol.createViewer(element, config);\n",
    "                \n",
    "                {high_score_script}\n",
    "                \n",
    "                // Add hover functionality\n",
    "                viewer.setHoverable(\n",
    "                    {{}}, \n",
    "                    true, \n",
    "                    function(atom, viewer, event, container) {{\n",
    "                        if (!atom.label) {{\n",
    "                            atom.label = viewer.addLabel(\n",
    "                                atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
    "                                {{\n",
    "                                    position: atom, \n",
    "                                    backgroundColor: 'mintcream', \n",
    "                                    fontColor: 'black',\n",
    "                                    fontSize: 12,\n",
    "                                    padding: 2\n",
    "                                }}\n",
    "                            );\n",
    "                        }}\n",
    "                    }},\n",
    "                    function(atom, viewer) {{\n",
    "                        if (atom.label) {{\n",
    "                            viewer.removeLabel(atom.label);\n",
    "                            delete atom.label;\n",
    "                        }}\n",
    "                    }}\n",
    "                );\n",
    "                \n",
    "                viewer.zoomTo();\n",
    "                viewer.render();\n",
    "                viewer.zoom(0.8, 2000);\n",
    "            }});\n",
    "        </script>\n",
    "    </body>\n",
    "    </html>\n",
    "    \"\"\"\n",
    "\n",
    "    # html.escape encodes &, <, > and both quote characters, so the page survives\n",
    "    # being placed inside the double-quoted srcdoc attribute.\n",
    "    escaped_page = html.escape(html_content, quote=True)\n",
    "    return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{escaped_page}\"></iframe>'\n",
    "\n",
    "\n",
    "# Gradio UI: PDB ID entry with a 3D viewer, plus per-chain binding-site prediction\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# Protein Binding Site Prediction\")\n",
    "    \n",
    "    with gr.Row():\n",
    "        pdb_input = gr.Textbox(value=\"4BDU\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
    "        visualize_btn = gr.Button(\"Visualize Structure\")\n",
    "\n",
    "    # Viewer filled by the \"Visualize Structure\" button; one white cartoon representation\n",
    "    molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n",
    "        {\n",
    "            \"model\": 0,\n",
    "            \"style\": \"cartoon\",\n",
    "            \"color\": \"whiteCarbon\",\n",
    "            \"residue_range\": \"\",\n",
    "            \"around\": 0,\n",
    "            \"byres\": False,\n",
    "        }\n",
    "    ])\n",
    "\n",
    "    with gr.Row():\n",
    "        segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
    "        prediction_btn = gr.Button(\"Predict Binding Site\")\n",
    "\n",
    "\n",
    "    # Targets for the three values returned by process_pdb\n",
    "    molecule_output = gr.HTML(label=\"Protein Structure\")\n",
    "    predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
    "    download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n",
    "    \n",
    "    # process_pdb returns (report text, viewer HTML, file list), matching the order of `outputs`\n",
    "    prediction_btn.click(\n",
    "        process_pdb, \n",
    "        inputs=[\n",
    "            pdb_input, \n",
    "            segment_input\n",
    "        ], \n",
    "        outputs=[predictions_output, molecule_output, download_output]\n",
    "    )\n",
    "\n",
    "    # fetch_pdb returns a PDB file path (or None) that is handed to the Molecule3D viewer\n",
    "    visualize_btn.click(\n",
    "        fetch_pdb, \n",
    "        inputs=[pdb_input], \n",
    "        outputs=molecule_output2\n",
    "    )\n",
    "\n",
    "    gr.Markdown(\"## Examples\")\n",
    "    # Each example row is a (PDB ID, chain ID) pair for the two input boxes.\n",
    "    # NOTE(review): no `fn` is passed here, so `outputs` is presumably unused - confirm.\n",
    "    gr.Examples(\n",
    "        examples=[\n",
    "            [\"7RPZ\", \"A\"],\n",
    "            [\"2IWI\", \"B\"],\n",
    "            [\"2F6V\", \"A\"]\n",
    "        ],\n",
    "        inputs=[pdb_input, segment_input],\n",
    "        outputs=[predictions_output, molecule_output, download_output]\n",
    "    )\n",
    "\n",
    "# NOTE(review): share=True publishes the app on a public gradio.live URL, and\n",
    "# process_pdb accepts local paths ending in .pdb - confirm this exposure is intended.\n",
    "demo.launch(share=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "db0b4763-5368-4d73-b5f6-d1c168f7fcd8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Running on local URL:  http://127.0.0.1:7864\n",
      "* Running on public URL: https://060e61e5b829d9fb6e.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://060e61e5b829d9fb6e.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datetime import datetime\n",
    "import gradio as gr\n",
    "import requests\n",
    "from Bio.PDB import PDBParser, MMCIFParser, PDBIO\n",
    "from Bio.PDB.Polypeptide import is_aa\n",
    "from Bio.SeqUtils import seq1\n",
    "from Bio.PDB import Select\n",
    "from typing import Optional, Tuple\n",
    "import numpy as np\n",
    "import os\n",
    "from gradio_molecule3d import Molecule3D\n",
    "\n",
    "import re\n",
    "import pandas as pd\n",
    "import copy\n",
    "\n",
    "from scipy.special import expit\n",
    "\n",
    "def normalize_scores(scores):\n",
    "    \"\"\"\n",
    "    Rescale `scores` linearly so the minimum becomes 0 and the maximum becomes 1.\n",
    "\n",
    "    Any array-like is accepted. An empty input yields an empty array, and a\n",
    "    constant input is returned as-is because its range is zero.\n",
    "    \"\"\"\n",
    "    scores = np.asarray(scores)\n",
    "    if scores.size == 0:\n",
    "        # Reductions such as np.min raise ValueError when there are no elements\n",
    "        return scores\n",
    "    min_score = np.min(scores)\n",
    "    max_score = np.max(scores)\n",
    "    if max_score > min_score:\n",
    "        return (scores - min_score) / (max_score - min_score)\n",
    "    return scores\n",
    "\n",
    "def read_mol(pdb_path):\n",
    "    \"\"\"Load the file at `pdb_path` in text mode and return everything in it.\"\"\"\n",
    "    with open(pdb_path, mode='r') as source:\n",
    "        text = source.read()\n",
    "    return text\n",
    "\n",
    "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Look up the structure file for `pdb_id` and return its path, or None.\n",
    "\n",
    "    This is a thin wrapper over `download_structure`, which prefers CIF over\n",
    "    PDB and returns a file already present in `output_dir` without downloading.\n",
    "    \"\"\"\n",
    "    located = download_structure(pdb_id, output_dir)\n",
    "    return located if located else None\n",
    "\n",
    "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Attempt to download the structure file in CIF or PDB format.\n",
    "\n",
    "    Because the ID becomes part of a local file name and of the download URL,\n",
    "    it must consist only of letters, digits, underscores and hyphens; anything\n",
    "    else (path separators, \"..\", empty text) is rejected up front.\n",
    "\n",
    "    Args:\n",
    "        pdb_id: Structure identifier; surrounding whitespace is ignored.\n",
    "        output_dir: Directory used to look up and store structure files.\n",
    "\n",
    "    Returns:\n",
    "        Path to an existing or freshly downloaded file, or None if the ID is\n",
    "        invalid or every download attempt fails.\n",
    "    \"\"\"\n",
    "    pdb_id = pdb_id.strip() if isinstance(pdb_id, str) else \"\"\n",
    "    if not re.fullmatch(r\"[A-Za-z0-9_-]+\", pdb_id):\n",
    "        print(f\"Invalid PDB ID: {pdb_id!r}\")\n",
    "        return None\n",
    "\n",
    "    for ext in ['.cif', '.pdb']:\n",
    "        file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n",
    "        if os.path.exists(file_path):\n",
    "            # A local copy short-circuits the download\n",
    "            return file_path\n",
    "        url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n",
    "        try:\n",
    "            response = requests.get(url, timeout=10)\n",
    "            if response.status_code == 200:\n",
    "                with open(file_path, 'wb') as f:\n",
    "                    f.write(response.content)\n",
    "                return file_path\n",
    "        except Exception as e:\n",
    "            # Best effort: log and try the next format rather than abort\n",
    "            print(f\"Download error for {pdb_id}{ext}: {e}\")\n",
    "    return None\n",
    "\n",
    "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n",
    "    \"\"\"\n",
    "    Convert a CIF file to PDB format using BioPython and return the PDB file path.\n",
    "\n",
    "    The result is saved in `output_dir`, named after the input file with its\n",
    "    last extension replaced by `.pdb`.\n",
    "    \"\"\"\n",
    "    # Only the final extension is swapped; a blanket str.replace would also\n",
    "    # alter any \".cif\" that occurs earlier in the file name.\n",
    "    stem = os.path.splitext(os.path.basename(cif_path))[0]\n",
    "    pdb_path = os.path.join(output_dir, f\"{stem}.pdb\")\n",
    "    parser = MMCIFParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', cif_path)\n",
    "    io = PDBIO()\n",
    "    io.set_structure(structure)\n",
    "    io.save(pdb_path)\n",
    "    return pdb_path\n",
    "\n",
    "def fetch_pdb(pdb_id):\n",
    "    \"\"\"\n",
    "    Obtain a PDB-format file for `pdb_id` and return its path (None on failure).\n",
    "\n",
    "    When `fetch_structure` hands back a `.cif` file it is converted to PDB\n",
    "    through `convert_cif_to_pdb`; any other file is returned untouched.\n",
    "    \"\"\"\n",
    "    located = fetch_structure(pdb_id)\n",
    "    if not located:\n",
    "        return None\n",
    "    extension = os.path.splitext(located)[1]\n",
    "    return convert_cif_to_pdb(located) if extension == '.cif' else located\n",
    "\n",
    "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n",
    "    \"\"\"\n",
    "    Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores.\n",
    "\n",
    "    Args:\n",
    "        input_pdb: Path of the PDB file to read.\n",
    "        chain_id: ID of the chain to keep.\n",
    "        residue_scores: (residue number, score) pairs; each score is multiplied\n",
    "            by 100 before being stored as the B-factor.\n",
    "        protein_residues: Residues to keep, matched by residue number.\n",
    "\n",
    "    Returns:\n",
    "        Path of the written `<input stem>_<chain_id>_scored.pdb` file. Only the\n",
    "        first model of the structure is written.\n",
    "    \"\"\"\n",
    "    # Read the original PDB file\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', input_pdb)\n",
    "\n",
    "    output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_scored.pdb\"\n",
    "\n",
    "    # Create scores dictionary for easy lookup\n",
    "    scores_dict = dict(residue_scores)\n",
    "\n",
    "    # A set makes the per-residue membership test constant time; a list would be\n",
    "    # rescanned for every residue that PDBIO offers to the selector.\n",
    "    selected_numbers = {res.id[1] for res in protein_residues}\n",
    "\n",
    "    class ResidueSelector(Select):\n",
    "        \"\"\"PDBIO filter that keeps one chain's selected residues and stamps scores into B-factors.\"\"\"\n",
    "\n",
    "        def __init__(self, chain_id, selected_residues, scores_dict):\n",
    "            self.chain_id = chain_id\n",
    "            self.selected_residues = selected_residues\n",
    "            self.scores_dict = scores_dict\n",
    "\n",
    "        def accept_chain(self, chain):\n",
    "            return chain.id == self.chain_id\n",
    "\n",
    "        def accept_residue(self, residue):\n",
    "            return residue.id[1] in self.selected_residues\n",
    "\n",
    "        def accept_atom(self, atom):\n",
    "            # Overwrite the B-factor here so the score is in place when the atom is written\n",
    "            resnum = atom.parent.id[1]\n",
    "            if resnum in self.scores_dict:\n",
    "                atom.bfactor = self.scores_dict[resnum] * 100\n",
    "            return True\n",
    "\n",
    "    # Prepare output PDB with selected chain and residues, modified B-factors\n",
    "    io = PDBIO()\n",
    "    selector = ResidueSelector(chain_id, selected_numbers, scores_dict)\n",
    "\n",
    "    io.set_structure(structure[0])\n",
    "    io.save(output_pdb, selector)\n",
    "\n",
    "    return output_pdb\n",
    "\n",
    "def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):\n",
    "    \"\"\"\n",
    "    Calculate the geometric center of high-scoring residues.\n",
    "\n",
    "    The center is the mean of the CA-atom coordinates of residues in chain\n",
    "    `chain_id` whose residue number is listed in `high_score_residues`,\n",
    "    taken over every model in the file.\n",
    "\n",
    "    Returns:\n",
    "        The mean coordinate as a NumPy array, or None if no matching CA atom\n",
    "        was found.\n",
    "    \"\"\"\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', pdb_path)\n",
    "\n",
    "    # Set lookup avoids rescanning the residue list once per residue in the chain\n",
    "    wanted = set(high_score_residues)\n",
    "\n",
    "    # Collect coordinates of CA atoms (alpha carbon as residue representative)\n",
    "    coords = []\n",
    "    for model in structure:\n",
    "        for chain in model:\n",
    "            if chain.id != chain_id:\n",
    "                continue\n",
    "            coords.extend(\n",
    "                residue['CA'].coord\n",
    "                for residue in chain\n",
    "                if residue.id[1] in wanted and 'CA' in residue\n",
    "            )\n",
    "\n",
    "    if coords:\n",
    "        return np.mean(coords, axis=0)\n",
    "    return None\n",
    "\n",
    "def process_pdb(pdb_id_or_file, segment):\n",
    "    \"\"\"\n",
    "    Score the amino-acid residues of one chain and build the outputs for the UI.\n",
    "\n",
    "    Args:\n",
    "        pdb_id_or_file: A path ending in `.pdb` (used as-is) or a PDB ID that\n",
    "            is fetched through `fetch_pdb`.\n",
    "        segment: ID of the chain to analyse.\n",
    "\n",
    "    Returns:\n",
    "        A 3-tuple `(text, html, files)`: the list of high-scoring residues\n",
    "        followed by PyMOL commands, the HTML produced by `molecule`, and a\n",
    "        one-element list with the scored PDB file. On failure the tuple is\n",
    "        `(error message, None, None)`.\n",
    "    \"\"\"\n",
    "    # Determine if input is a PDB ID or file path\n",
    "    if pdb_id_or_file.endswith('.pdb'):\n",
    "        pdb_path = pdb_id_or_file\n",
    "        pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n",
    "    else:\n",
    "        pdb_id = pdb_id_or_file\n",
    "        pdb_path = fetch_pdb(pdb_id)\n",
    "\n",
    "    if not pdb_path:\n",
    "        return \"Failed to fetch PDB file\", None, None\n",
    "\n",
    "    # Determine the file format and choose the appropriate parser\n",
    "    _, ext = os.path.splitext(pdb_path)\n",
    "    parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n",
    "\n",
    "    try:\n",
    "        # Parse the structure file\n",
    "        structure = parser.get_structure('protein', pdb_path)\n",
    "    except Exception as e:\n",
    "        return f\"Error parsing structure file: {e}\", None, None\n",
    "\n",
    "    # Extract the specified chain from the first model\n",
    "    try:\n",
    "        chain = structure[0][segment]\n",
    "    except KeyError:\n",
    "        return \"Invalid Chain ID\", None, None\n",
    "\n",
    "    protein_residues = [res for res in chain if is_aa(res)]\n",
    "    if not protein_residues:\n",
    "        # Nothing to score; stop before normalize_scores is handed an empty array\n",
    "        return f\"No amino acid residues found in chain {segment}\", None, None\n",
    "\n",
    "    sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n",
    "    sequence_id = [res.id[1] for res in protein_residues]\n",
    "\n",
    "    # NOTE(review): the scores are random numbers, not model output - presumably a\n",
    "    # placeholder for a real predictor; confirm before relying on the results.\n",
    "    scores = np.random.rand(len(sequence))\n",
    "    normalized_scores = normalize_scores(scores)\n",
    "\n",
    "    # Zip residues with scores to track the residue ID and score\n",
    "    residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n",
    "\n",
    "    # Identify high scoring residues (> 0.7) and the band just below them.\n",
    "    # NOTE(review): molecule() applies its own 0.75 cut-off for the red sticks -\n",
    "    # confirm whether the two thresholds are meant to differ.\n",
    "    high_score_residues = [resi for resi, score in residue_scores if score > 0.7]\n",
    "    mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.7]\n",
    "\n",
    "    # Preparing the result: only print high scoring residues. Filtering on the\n",
    "    # residue's own score keeps residues that merely share a number with a\n",
    "    # high scorer out of the list.\n",
    "    result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\n\"\n",
    "    result_str += \"High-scoring Residues (Score > 0.7):\\n\"\n",
    "    result_str += \"\\n\".join([\n",
    "        f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\"\n",
    "        for i, res in enumerate(protein_residues) if normalized_scores[i] > 0.7\n",
    "    ])\n",
    "\n",
    "    # Create chain-specific PDB with scores in B-factor\n",
    "    scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n",
    "\n",
    "    # Molecule visualization\n",
    "    mol_vis = molecule(pdb_path, residue_scores, segment)\n",
    "\n",
    "    # PyMOL command suggestions\n",
    "    pymol_commands = f\"\"\"\n",
    "# PyMOL Visualization Commands\n",
    "load {os.path.abspath(pdb_path)}, protein\n",
    "hide everything, all\n",
    "show cartoon, chain {segment}\n",
    "color white, chain {segment}\n",
    "\"\"\"\n",
    "\n",
    "    # Each band gets its own selection name so the second `select` does not\n",
    "    # overwrite the first one inside PyMOL.\n",
    "    for selection_name, residues, color in [\n",
    "        (\"high_score_residues\", high_score_residues, \"red\"),\n",
    "        (\"mid_score_residues\", mid_score_residues, \"orange\"),\n",
    "    ]:\n",
    "        if residues:\n",
    "            resi_list = '+'.join(map(str, residues))\n",
    "            pymol_commands += f\"\"\"\n",
    "select {selection_name}, resi {resi_list} and chain {segment}\n",
    "show sticks, {selection_name}\n",
    "color {color}, {selection_name}\n",
    "\"\"\"\n",
    "\n",
    "    # The PyMOL commands are part of the report shown to the user\n",
    "    return result_str + \"\\n\\n\" + pymol_commands, mol_vis, [scored_pdb]\n",
    "\n",
    "def molecule(input_pdb, residue_scores=None, segment='A'):\n",
    "    mol = read_mol(input_pdb)  # Read PDB file content\n",
    "\n",
    "    # Prepare high-scoring residues script if scores are provided\n",
    "    high_score_script = \"\"\n",
    "    if residue_scores is not None:\n",
    "        # Filter residues based on their scores\n",
    "        high_score_residues = [resi for resi, score in residue_scores if score > 0.75]\n",
    "        mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]\n",
    "        \n",
    "        high_score_script = \"\"\"\n",
    "        // Load the original model and apply white cartoon style\n",
    "        let chainModel = viewer.addModel(pdb, \"pdb\");\n",
    "        chainModel.setStyle({}, {});\n",
    "        chainModel.setStyle(\n",
    "            {\"chain\": \"%s\"}, \n",
    "            {\"cartoon\": {\"color\": \"white\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let highScoreModel = viewer.addModel(pdb, \"pdb\");\n",
    "        highScoreModel.setStyle({}, {});\n",
    "        highScoreModel.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"red\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for medium-scoring residues and apply orange sticks style\n",
    "        let midScoreModel = viewer.addModel(pdb, \"pdb\");\n",
    "        midScoreModel.setStyle({}, {});\n",
    "        midScoreModel.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"orange\"}}\n",
    "        );\n",
    "        \"\"\" % (\n",
    "            segment,\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in high_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in mid_score_residues)\n",
    "        )\n",
    "    \n",
    "    # Generate the full HTML content\n",
    "    html_content = f\"\"\"\n",
    "    <!DOCTYPE html>\n",
    "    <html>\n",
    "    <head>    \n",
    "        <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
    "        <style>\n",
    "        .mol-container {{\n",
    "            width: 100%;\n",
    "            height: 700px;\n",
    "            position: relative;\n",
    "        }}\n",
    "        </style>\n",
    "        <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
    "        <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
    "    </head>\n",
    "    <body>\n",
    "        <div id=\"container\" class=\"mol-container\"></div>\n",
    "        <script>\n",
    "            let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
    "            $(document).ready(function () {{\n",
    "                let element = $(\"#container\");\n",
    "                let config = {{ backgroundColor: \"white\" }};\n",
    "                let viewer = $3Dmol.createViewer(element, config);\n",
    "                \n",
    "                {high_score_script}\n",
    "                \n",
    "                // Add hover functionality\n",
    "                viewer.setHoverable(\n",
    "                    {{}}, \n",
    "                    true, \n",
    "                    function(atom, viewer, event, container) {{\n",
    "                        if (!atom.label) {{\n",
    "                            atom.label = viewer.addLabel(\n",
    "                                atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
    "                                {{\n",
    "                                    position: atom, \n",
    "                                    backgroundColor: 'mintcream', \n",
    "                                    fontColor: 'black',\n",
    "                                    fontSize: 12,\n",
    "                                    padding: 2\n",
    "                                }}\n",
    "                            );\n",
    "                        }}\n",
    "                    }},\n",
    "                    function(atom, viewer) {{\n",
    "                        if (atom.label) {{\n",
    "                            viewer.removeLabel(atom.label);\n",
    "                            delete atom.label;\n",
    "                        }}\n",
    "                    }}\n",
    "                );\n",
    "                \n",
    "                viewer.zoomTo();\n",
    "                viewer.render();\n",
    "                viewer.zoom(0.8, 2000);\n",
    "            }});\n",
    "        </script>\n",
    "    </body>\n",
    "    </html>\n",
    "    \"\"\"\n",
    "    \n",
    "    # Return the HTML content within an iframe safely encoded for special characters\n",
    "    return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
    "\n",
    "# Gradio UI\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# Protein Binding Site Prediction\")\n",
    "    \n",
    "    with gr.Row():\n",
    "        pdb_input = gr.Textbox(value=\"4BDU\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
    "        visualize_btn = gr.Button(\"Visualize Structure\")\n",
    "\n",
    "    molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n",
    "        {\n",
    "            \"model\": 0,\n",
    "            \"style\": \"cartoon\",\n",
    "            \"color\": \"whiteCarbon\",\n",
    "            \"residue_range\": \"\",\n",
    "            \"around\": 0,\n",
    "            \"byres\": False,\n",
    "        }\n",
    "    ])\n",
    "\n",
    "    with gr.Row():\n",
    "        segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
    "        prediction_btn = gr.Button(\"Predict Binding Site\")\n",
    "\n",
    "    molecule_output = gr.HTML(label=\"Protein Structure\")\n",
    "    predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
    "    download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n",
    "    \n",
    "    prediction_btn.click(\n",
    "        process_pdb, \n",
    "        inputs=[\n",
    "            pdb_input, \n",
    "            segment_input\n",
    "        ], \n",
    "        outputs=[predictions_output, molecule_output, download_output]\n",
    "    )\n",
    "\n",
    "    visualize_btn.click(\n",
    "        fetch_pdb, \n",
    "        inputs=[pdb_input], \n",
    "        outputs=molecule_output2\n",
    "    )\n",
    "\n",
    "    gr.Markdown(\"## Examples\")\n",
    "    gr.Examples(\n",
    "        examples=[\n",
    "            [\"7RPZ\", \"A\"],\n",
    "            [\"2IWI\", \"B\"],\n",
    "            [\"2F6V\", \"A\"]\n",
    "        ],\n",
    "        inputs=[pdb_input, segment_input],\n",
    "        outputs=[predictions_output, molecule_output, download_output]\n",
    "    )\n",
    "\n",
    "demo.launch(share=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d0d50415-1304-462d-a176-b58f394e79b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Running on local URL:  http://127.0.0.1:7866\n",
      "* Running on public URL: https://a9ff499df0a5f7be8c.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://a9ff499df0a5f7be8c.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datetime import datetime\n",
    "import gradio as gr\n",
    "import requests\n",
    "from Bio.PDB import PDBParser, MMCIFParser, PDBIO\n",
    "from Bio.PDB.Polypeptide import is_aa\n",
    "from Bio.SeqUtils import seq1\n",
    "from Bio.PDB import Select\n",
    "from typing import Optional, Tuple\n",
    "import numpy as np\n",
    "import os\n",
    "from gradio_molecule3d import Molecule3D\n",
    "\n",
    "import re\n",
    "import pandas as pd\n",
    "import copy\n",
    "\n",
    "from scipy.special import expit\n",
    "\n",
    "def normalize_scores(scores):\n",
    "    min_score = np.min(scores)\n",
    "    max_score = np.max(scores)\n",
    "    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
    "\n",
    "def read_mol(pdb_path):\n",
    "    \"\"\"Read PDB file and return its content as a string\"\"\"\n",
    "    with open(pdb_path, 'r') as f:\n",
    "        return f.read()\n",
    "\n",
    "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n",
    "    If a structure file already exists locally, it uses that.\n",
    "    \"\"\"\n",
    "    file_path = download_structure(pdb_id, output_dir)\n",
    "    if file_path:\n",
    "        return file_path\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Attempt to download the structure file in CIF or PDB format.\n",
    "    Returns the path to the downloaded file, or None if download fails.\n",
    "    \"\"\"\n",
    "    for ext in ['.cif', '.pdb']:\n",
    "        file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n",
    "        if os.path.exists(file_path):\n",
    "            return file_path\n",
    "        url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n",
    "        try:\n",
    "            response = requests.get(url, timeout=10)\n",
    "            if response.status_code == 200:\n",
    "                with open(file_path, 'wb') as f:\n",
    "                    f.write(response.content)\n",
    "                return file_path\n",
    "        except Exception as e:\n",
    "            print(f\"Download error for {pdb_id}{ext}: {e}\")\n",
    "    return None\n",
    "\n",
    "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n",
    "    \"\"\"\n",
    "    Convert a CIF file to PDB format using BioPython and return the PDB file path.\n",
    "    \"\"\"\n",
    "    pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n",
    "    parser = MMCIFParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', cif_path)\n",
    "    io = PDBIO()\n",
    "    io.set_structure(structure)\n",
    "    io.save(pdb_path)\n",
    "    return pdb_path\n",
    "\n",
    "def fetch_pdb(pdb_id):\n",
    "    pdb_path = fetch_structure(pdb_id)\n",
    "    if not pdb_path:\n",
    "        return None\n",
    "    _, ext = os.path.splitext(pdb_path)\n",
    "    if ext == '.cif':\n",
    "        pdb_path = convert_cif_to_pdb(pdb_path)\n",
    "    return pdb_path\n",
    "\n",
    "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n",
    "    \"\"\"\n",
    "    Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores\n",
    "    \"\"\"\n",
    "    # Read the original PDB file\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', input_pdb)\n",
    "    \n",
    "    # Prepare a new structure with only the specified chain and selected residues\n",
    "    output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_scored.pdb\"\n",
    "    \n",
    "    # Create scores dictionary for easy lookup\n",
    "    scores_dict = {resi: score for resi, score in residue_scores}\n",
    "\n",
    "    # Create a custom Select class\n",
    "    class ResidueSelector(Select):\n",
    "        def __init__(self, chain_id, selected_residues, scores_dict):\n",
    "            self.chain_id = chain_id\n",
    "            self.selected_residues = selected_residues\n",
    "            self.scores_dict = scores_dict\n",
    "        \n",
    "        def accept_chain(self, chain):\n",
    "            return chain.id == self.chain_id\n",
    "        \n",
    "        def accept_residue(self, residue):\n",
    "            return residue.id[1] in self.selected_residues\n",
    "\n",
    "        def accept_atom(self, atom):\n",
    "            if atom.parent.id[1] in self.scores_dict:\n",
    "                atom.bfactor = self.scores_dict[atom.parent.id[1]] * 100\n",
    "            return True\n",
    "\n",
    "    # Prepare output PDB with selected chain and residues, modified B-factors\n",
    "    io = PDBIO()\n",
    "    selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n",
    "    \n",
    "    io.set_structure(structure[0])\n",
    "    io.save(output_pdb, selector)\n",
    "    \n",
    "    return output_pdb\n",
    "\n",
    "def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):\n",
    "    \"\"\"\n",
    "    Calculate the geometric center of high-scoring residues\n",
    "    \"\"\"\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', pdb_path)\n",
    "    \n",
    "    # Collect coordinates of CA atoms from high-scoring residues\n",
    "    coords = []\n",
    "    for model in structure:\n",
    "        for chain in model:\n",
    "            if chain.id == chain_id:\n",
    "                for residue in chain:\n",
    "                    if residue.id[1] in high_score_residues:\n",
    "                        if 'CA' in residue:  # Use alpha carbon as representative\n",
    "                            ca_atom = residue['CA']\n",
    "                            coords.append(ca_atom.coord)\n",
    "    \n",
    "    # Calculate geometric center\n",
    "    if coords:\n",
    "        center = np.mean(coords, axis=0)\n",
    "        return center\n",
    "    return None\n",
    "\n",
    "def process_pdb(pdb_id_or_file, segment):\n",
    "    # Determine if input is a PDB ID or file path\n",
    "    if pdb_id_or_file.endswith('.pdb'):\n",
    "        pdb_path = pdb_id_or_file\n",
    "        pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n",
    "    else:\n",
    "        pdb_id = pdb_id_or_file\n",
    "        pdb_path = fetch_pdb(pdb_id)\n",
    "    \n",
    "    if not pdb_path:\n",
    "        return \"Failed to fetch PDB file\", None, None\n",
    "    \n",
    "    # Determine the file format and choose the appropriate parser\n",
    "    _, ext = os.path.splitext(pdb_path)\n",
    "    parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n",
    "    \n",
    "    try:\n",
    "        # Parse the structure file\n",
    "        structure = parser.get_structure('protein', pdb_path)\n",
    "    except Exception as e:\n",
    "        return f\"Error parsing structure file: {e}\", None, None\n",
    "    \n",
    "    # Extract the specified chain\n",
    "    try:\n",
    "        chain = structure[0][segment]\n",
    "    except KeyError:\n",
    "        return \"Invalid Chain ID\", None, None\n",
    "    \n",
    "    protein_residues = [res for res in chain if is_aa(res)]\n",
    "    sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n",
    "    sequence_id = [res.id[1] for res in protein_residues]\n",
    "    \n",
    "    scores = np.random.rand(len(sequence))\n",
    "    normalized_scores = normalize_scores(scores)\n",
    "    \n",
    "    # Zip residues with scores to track the residue ID and score\n",
    "    residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n",
    "\n",
    "    # More granular scoring for visualization\n",
    "    def score_to_color(score):\n",
    "        if score <= 0.6:\n",
    "            return \"blue\"\n",
    "        elif score <= 0.7:\n",
    "            return \"lightblue\"\n",
    "        elif score <= 0.8:\n",
    "            return \"white\"\n",
    "        elif score <= 0.9:\n",
    "            return \"orange\"\n",
    "        elif score > 0.9:\n",
    "            return \"red\"\n",
    "\n",
    "    color_map = {resi: score_to_color(score) for resi, score in residue_scores}\n",
    "    \n",
    "    # Identify high scoring residues (> 0.7)\n",
    "    high_score_residues = [resi for resi, score in residue_scores if score > 0.7]\n",
    "    mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.7]\n",
    "\n",
    "    # Calculate geometric center of high-scoring residues\n",
    "    geo_center = calculate_geometric_center(pdb_path, high_score_residues, segment)\n",
    "\n",
    "    # Preparing the result: only print high scoring residues\n",
    "    result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\n\"\n",
    "    result_str += \"High-scoring Residues (Score > 0.7):\\n\"\n",
    "    result_str += \"\\n\".join([\n",
    "        f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
    "        for i, res in enumerate(protein_residues) if res.id[1] in high_score_residues\n",
    "    ])\n",
    "\n",
    "    # Create prediction and scored PDB files\n",
    "    prediction_file = f\"{pdb_id}_predictions.txt\"\n",
    "    with open(prediction_file, \"w\") as f:\n",
    "        f.write(result_str)\n",
    "    \n",
    "    # Create chain-specific PDB with scores in B-factor\n",
    "    scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n",
    "\n",
    "    # Molecule visualization with updated script with color mapping\n",
    "    mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)\n",
    "\n",
    "    # Improved PyMOL command suggestions\n",
    "    pymol_commands = f\"\"\"\n",
    "# PyMOL Visualization Commands\n",
    "load {os.path.abspath(pdb_path)}, protein\n",
    "hide everything, all\n",
    "show cartoon, chain {segment}\n",
    "color white, chain {segment}\n",
    "\"\"\"\n",
    "    \n",
    "    # Color specific residues\n",
    "    for score_range, color in [\n",
    "        (high_score_residues, \"red\"), \n",
    "        (mid_score_residues, \"orange\")\n",
    "    ]:\n",
    "        if score_range:\n",
    "            resi_list = '+'.join(map(str, score_range))\n",
    "            pymol_commands += f\"\"\"\n",
    "select high_score_residues, resi {resi_list} and chain {segment}\n",
    "show sticks, high_score_residues\n",
    "color {color}, high_score_residues\n",
    "\"\"\"\n",
    "    \n",
    "    return result_str, mol_vis, [prediction_file,scored_pdb]\n",
    "\n",
    "def molecule(input_pdb, residue_scores=None, segment='A'):\n",
    "    mol = read_mol(input_pdb)  # Read PDB file content\n",
    "\n",
    "    # Prepare high-scoring residues script if scores are provided\n",
    "    high_score_script = \"\"\n",
    "    if residue_scores is not None:\n",
    "        # Filter residues based on their scores\n",
    "        high_score_residues = [resi for resi, score in residue_scores if score > 0.75]\n",
    "        mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]\n",
    "        \n",
    "        high_score_script = \"\"\"\n",
    "        // Load the original model and apply white cartoon style\n",
    "        let chainModel = viewer.addModel(pdb, \"pdb\");\n",
    "        chainModel.setStyle({}, {});\n",
    "        chainModel.setStyle(\n",
    "            {\"chain\": \"%s\"}, \n",
    "            {\"cartoon\": {\"color\": \"white\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let highScoreModel = viewer.addModel(pdb, \"pdb\");\n",
    "        highScoreModel.setStyle({}, {});\n",
    "        highScoreModel.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"red\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for medium-scoring residues and apply orange sticks style\n",
    "        let midScoreModel = viewer.addModel(pdb, \"pdb\");\n",
    "        midScoreModel.setStyle({}, {});\n",
    "        midScoreModel.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"orange\"}}\n",
    "        );\n",
    "        \"\"\" % (\n",
    "            segment,\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in high_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in mid_score_residues)\n",
    "        )\n",
    "    \n",
    "    # Generate the full HTML content\n",
    "    html_content = f\"\"\"\n",
    "    <!DOCTYPE html>\n",
    "    <html>\n",
    "    <head>    \n",
    "        <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
    "        <style>\n",
    "        .mol-container {{\n",
    "            width: 100%;\n",
    "            height: 700px;\n",
    "            position: relative;\n",
    "        }}\n",
    "        </style>\n",
    "        <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
    "        <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
    "    </head>\n",
    "    <body>\n",
    "        <div id=\"container\" class=\"mol-container\"></div>\n",
    "        <script>\n",
    "            let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
    "            $(document).ready(function () {{\n",
    "                let element = $(\"#container\");\n",
    "                let config = {{ backgroundColor: \"white\" }};\n",
    "                let viewer = $3Dmol.createViewer(element, config);\n",
    "                \n",
    "                {high_score_script}\n",
    "                \n",
    "                // Add hover functionality\n",
    "                viewer.setHoverable(\n",
    "                    {{}}, \n",
    "                    true, \n",
    "                    function(atom, viewer, event, container) {{\n",
    "                        if (!atom.label) {{\n",
    "                            atom.label = viewer.addLabel(\n",
    "                                atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
    "                                {{\n",
    "                                    position: atom, \n",
    "                                    backgroundColor: 'mintcream', \n",
    "                                    fontColor: 'black',\n",
    "                                    fontSize: 12,\n",
    "                                    padding: 2\n",
    "                                }}\n",
    "                            );\n",
    "                        }}\n",
    "                    }},\n",
    "                    function(atom, viewer) {{\n",
    "                        if (atom.label) {{\n",
    "                            viewer.removeLabel(atom.label);\n",
    "                            delete atom.label;\n",
    "                        }}\n",
    "                    }}\n",
    "                );\n",
    "                \n",
    "                viewer.zoomTo();\n",
    "                viewer.render();\n",
    "                viewer.zoom(0.8, 2000);\n",
    "            }});\n",
    "        </script>\n",
    "    </body>\n",
    "    </html>\n",
    "    \"\"\"\n",
    "    \n",
    "    # Return the HTML content within an iframe safely encoded for special characters\n",
    "    return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
    "\n",
    "# Gradio UI\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# Protein Binding Site Prediction\")\n",
    "    \n",
    "    with gr.Row():\n",
    "        pdb_input = gr.Textbox(value=\"4BDU\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
    "        visualize_btn = gr.Button(\"Visualize Structure\")\n",
    "\n",
    "    molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n",
    "        {\n",
    "            \"model\": 0,\n",
    "            \"style\": \"cartoon\",\n",
    "            \"color\": \"whiteCarbon\",\n",
    "            \"residue_range\": \"\",\n",
    "            \"around\": 0,\n",
    "            \"byres\": False,\n",
    "        }\n",
    "    ])\n",
    "\n",
    "    with gr.Row():\n",
    "        segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
    "        prediction_btn = gr.Button(\"Predict Binding Site\")\n",
    "\n",
    "    molecule_output = gr.HTML(label=\"Protein Structure\")\n",
    "    predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
    "    download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n",
    "    \n",
    "    prediction_btn.click(\n",
    "        process_pdb, \n",
    "        inputs=[\n",
    "            pdb_input, \n",
    "            segment_input\n",
    "        ], \n",
    "        outputs=[predictions_output, molecule_output, download_output]\n",
    "    )\n",
    "\n",
    "    visualize_btn.click(\n",
    "        fetch_pdb, \n",
    "        inputs=[pdb_input], \n",
    "        outputs=molecule_output2\n",
    "    )\n",
    "\n",
    "    gr.Markdown(\"## Examples\")\n",
    "    gr.Examples(\n",
    "        examples=[\n",
    "            [\"7RPZ\", \"A\"],\n",
    "            [\"2IWI\", \"B\"],\n",
    "            [\"2F6V\", \"A\"]\n",
    "        ],\n",
    "        inputs=[pdb_input, segment_input],\n",
    "        outputs=[predictions_output, molecule_output, download_output]\n",
    "    )\n",
    "\n",
    "demo.launch(share=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "004ab20c-5273-44b9-bc69-d41f236296e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Running on local URL:  http://127.0.0.1:7890\n",
      "* Running on public URL: https://a7f63d297aa65a70de.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://a7f63d297aa65a70de.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datetime import datetime\n",
    "import gradio as gr\n",
    "import requests\n",
    "from Bio.PDB import PDBParser, MMCIFParser, PDBIO\n",
    "from Bio.PDB.Polypeptide import is_aa\n",
    "from Bio.SeqUtils import seq1\n",
    "from Bio.PDB import Select\n",
    "from typing import Optional, Tuple\n",
    "import numpy as np\n",
    "import os\n",
    "from gradio_molecule3d import Molecule3D\n",
    "\n",
    "import re\n",
    "import pandas as pd\n",
    "import copy\n",
    "\n",
    "from scipy.special import expit\n",
    "\n",
    "def normalize_scores(scores):\n",
    "    min_score = np.min(scores)\n",
    "    max_score = np.max(scores)\n",
    "    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
    "\n",
    "def read_mol(pdb_path):\n",
    "    \"\"\"Read PDB file and return its content as a string\"\"\"\n",
    "    with open(pdb_path, 'r') as f:\n",
    "        return f.read()\n",
    "\n",
    "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n",
    "    If a structure file already exists locally, it uses that.\n",
    "    \"\"\"\n",
    "    file_path = download_structure(pdb_id, output_dir)\n",
    "    if file_path:\n",
    "        return file_path\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Attempt to download the structure file in CIF or PDB format.\n",
    "    Returns the path to the downloaded file, or None if download fails.\n",
    "    \"\"\"\n",
    "    for ext in ['.cif', '.pdb']:\n",
    "        file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n",
    "        if os.path.exists(file_path):\n",
    "            return file_path\n",
    "        url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n",
    "        try:\n",
    "            response = requests.get(url, timeout=10)\n",
    "            if response.status_code == 200:\n",
    "                with open(file_path, 'wb') as f:\n",
    "                    f.write(response.content)\n",
    "                return file_path\n",
    "        except Exception as e:\n",
    "            print(f\"Download error for {pdb_id}{ext}: {e}\")\n",
    "    return None\n",
    "\n",
    "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n",
    "    \"\"\"\n",
    "    Convert a CIF file to PDB format using BioPython and return the PDB file path.\n",
    "    \"\"\"\n",
    "    pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n",
    "    parser = MMCIFParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', cif_path)\n",
    "    io = PDBIO()\n",
    "    io.set_structure(structure)\n",
    "    io.save(pdb_path)\n",
    "    return pdb_path\n",
    "\n",
    "def fetch_pdb(pdb_id):\n",
    "    pdb_path = fetch_structure(pdb_id)\n",
    "    if not pdb_path:\n",
    "        return None\n",
    "    _, ext = os.path.splitext(pdb_path)\n",
    "    if ext == '.cif':\n",
    "        pdb_path = convert_cif_to_pdb(pdb_path)\n",
    "    return pdb_path\n",
    "\n",
    "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n",
    "    \"\"\"\n",
    "    Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores\n",
    "    \"\"\"\n",
    "    # Read the original PDB file\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', input_pdb)\n",
    "    \n",
    "    # Prepare a new structure with only the specified chain and selected residues\n",
    "    output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_scored.pdb\"\n",
    "    \n",
    "    # Create scores dictionary for easy lookup\n",
    "    scores_dict = {resi: score for resi, score in residue_scores}\n",
    "\n",
    "    # Create a custom Select class\n",
    "    class ResidueSelector(Select):\n",
    "        def __init__(self, chain_id, selected_residues, scores_dict):\n",
    "            self.chain_id = chain_id\n",
    "            self.selected_residues = selected_residues\n",
    "            self.scores_dict = scores_dict\n",
    "        \n",
    "        def accept_chain(self, chain):\n",
    "            return chain.id == self.chain_id\n",
    "        \n",
    "        def accept_residue(self, residue):\n",
    "            return residue.id[1] in self.selected_residues\n",
    "\n",
    "        def accept_atom(self, atom):\n",
    "            if atom.parent.id[1] in self.scores_dict:\n",
    "                atom.bfactor = self.scores_dict[atom.parent.id[1]] * 100\n",
    "            return True\n",
    "\n",
    "    # Prepare output PDB with selected chain and residues, modified B-factors\n",
    "    io = PDBIO()\n",
    "    selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n",
    "    \n",
    "    io.set_structure(structure[0])\n",
    "    io.save(output_pdb, selector)\n",
    "    \n",
    "    return output_pdb\n",
    "\n",
    "def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):\n",
    "    \"\"\"\n",
    "    Calculate the geometric center of high-scoring residues\n",
    "    \"\"\"\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', pdb_path)\n",
    "    \n",
    "    # Collect coordinates of CA atoms from high-scoring residues\n",
    "    coords = []\n",
    "    for model in structure:\n",
    "        for chain in model:\n",
    "            if chain.id == chain_id:\n",
    "                for residue in chain:\n",
    "                    if residue.id[1] in high_score_residues:\n",
    "                        if 'CA' in residue:  # Use alpha carbon as representative\n",
    "                            ca_atom = residue['CA']\n",
    "                            coords.append(ca_atom.coord)\n",
    "    \n",
    "    # Calculate geometric center\n",
    "    if coords:\n",
    "        center = np.mean(coords, axis=0)\n",
    "        return center\n",
    "    return None\n",
    "\n",
    "def process_pdb(pdb_id_or_file, segment):\n",
    "    # Determine if input is a PDB ID or file path\n",
    "    if pdb_id_or_file.endswith('.pdb'):\n",
    "        pdb_path = pdb_id_or_file\n",
    "        pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n",
    "    else:\n",
    "        pdb_id = pdb_id_or_file\n",
    "        pdb_path = fetch_pdb(pdb_id)\n",
    "    \n",
    "    if not pdb_path:\n",
    "        return \"Failed to fetch PDB file\", None, None\n",
    "    \n",
    "    # Determine the file format and choose the appropriate parser\n",
    "    _, ext = os.path.splitext(pdb_path)\n",
    "    parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n",
    "    \n",
    "    try:\n",
    "        # Parse the structure file\n",
    "        structure = parser.get_structure('protein', pdb_path)\n",
    "    except Exception as e:\n",
    "        return f\"Error parsing structure file: {e}\", None, None\n",
    "    \n",
    "    # Extract the specified chain\n",
    "    try:\n",
    "        chain = structure[0][segment]\n",
    "    except KeyError:\n",
    "        return \"Invalid Chain ID\", None, None\n",
    "    \n",
    "    protein_residues = [res for res in chain if is_aa(res)]\n",
    "    sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n",
    "    sequence_id = [res.id[1] for res in protein_residues]\n",
    "    \n",
    "    scores = np.random.rand(len(sequence))\n",
    "    normalized_scores = normalize_scores(scores)\n",
    "    \n",
    "    # Zip residues with scores to track the residue ID and score\n",
    "    residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n",
    "\n",
    "    \n",
    "    # Identify high scoring residues (> 0.5)\n",
    "    high_score_residues = [resi for resi, score in residue_scores if score > 0.5]\n",
    "    \n",
    "    # Preparing the result: only print high scoring residues\n",
    "    current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
    "    result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n",
    "    result_str += \"High-scoring Residues (Score > 0.5):\\n\"\n",
    "    result_str += \"Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\\n\\n\"\n",
    "    result_str += \"\\n\".join([\n",
    "        f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
    "        for i, res in enumerate(protein_residues) if res.id[1] in high_score_residues\n",
    "    ])\n",
    "\n",
    "    # Create chain-specific PDB with scores in B-factor\n",
    "    scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n",
    "\n",
    "    # Molecule visualization with updated script with color mapping\n",
    "    mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)\n",
    "\n",
    "    # Improved PyMOL command suggestions\n",
    "    current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
    "    pymol_commands = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n",
    "    \n",
    "    pymol_commands += f\"\"\"\n",
    "# PyMOL Visualization Commands\n",
    "load {os.path.abspath(pdb_path)}, protein\n",
    "hide everything, all\n",
    "show cartoon, chain {segment}\n",
    "color white, chain {segment}\n",
    "\"\"\"\n",
    "    \n",
    "    # Color specific residues\n",
    "    for score_range, color in [\n",
    "        (high_score_residues, \"red\")\n",
    "    ]:\n",
    "        if score_range:\n",
    "            resi_list = '+'.join(map(str, score_range))\n",
    "            pymol_commands += f\"\"\"\n",
    "select high_score_residues, resi {resi_list} and chain {segment}\n",
    "show sticks, high_score_residues\n",
    "color {color}, high_score_residues\n",
    "\"\"\"\n",
    "    # Create prediction and scored PDB files\n",
    "    prediction_file = f\"{pdb_id}_predictions.txt\"\n",
    "    with open(prediction_file, \"w\") as f:\n",
    "        f.write(result_str)\n",
    "    \n",
    "    return pymol_commands, mol_vis, [prediction_file,scored_pdb]\n",
    "\n",
    "def molecule(input_pdb, residue_scores=None, segment='A'):\n",
    "    # More granular scoring for visualization\n",
    "    mol = read_mol(input_pdb)  # Read PDB file content\n",
    "\n",
    "    # Prepare high-scoring residues script if scores are provided\n",
    "    high_score_script = \"\"\n",
    "    if residue_scores is not None:\n",
    "        # Filter residues based on their scores\n",
    "        class1_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.6]\n",
    "        class2_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.7]\n",
    "        class3_score_residues = [resi for resi, score in residue_scores if 0.7 < score <= 0.8]\n",
    "        class4_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 0.9]\n",
    "        class5_score_residues = [resi for resi, score in residue_scores if 0.9 < score <= 1.0]\n",
    "        \n",
    "        high_score_script = \"\"\"\n",
    "        // Load the original model and apply white cartoon style\n",
    "        let chainModel = viewer.addModel(pdb, \"pdb\");\n",
    "        chainModel.setStyle({}, {});\n",
    "        chainModel.setStyle(\n",
    "            {\"chain\": \"%s\"}, \n",
    "            {\"cartoon\": {\"color\": \"white\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class1Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class1Model.setStyle({}, {});\n",
    "        class1Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"blue\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class2Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class2Model.setStyle({}, {});\n",
    "        class2Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"lightblue\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class3Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class3Model.setStyle({}, {});\n",
    "        class3Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"white\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class4Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class4Model.setStyle({}, {});\n",
    "        class4Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"orange\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class5Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class5Model.setStyle({}, {});\n",
    "        class5Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"red\"}}\n",
    "        );\n",
    "\n",
    "        \"\"\" % (\n",
    "            segment,\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class1_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class2_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class3_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class4_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class5_score_residues)\n",
    "        )\n",
    "    \n",
    "    # Generate the full HTML content\n",
    "    html_content = f\"\"\"\n",
    "    <!DOCTYPE html>\n",
    "    <html>\n",
    "    <head>    \n",
    "        <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
    "        <style>\n",
    "        .mol-container {{\n",
    "            width: 100%;\n",
    "            height: 700px;\n",
    "            position: relative;\n",
    "        }}\n",
    "        </style>\n",
    "        <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
    "        <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
    "    </head>\n",
    "    <body>\n",
    "        <div id=\"container\" class=\"mol-container\"></div>\n",
    "        <script>\n",
    "            let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
    "            $(document).ready(function () {{\n",
    "                let element = $(\"#container\");\n",
    "                let config = {{ backgroundColor: \"white\" }};\n",
    "                let viewer = $3Dmol.createViewer(element, config);\n",
    "                \n",
    "                {high_score_script}\n",
    "                \n",
    "                // Add hover functionality\n",
    "                viewer.setHoverable(\n",
    "                    {{}}, \n",
    "                    true, \n",
    "                    function(atom, viewer, event, container) {{\n",
    "                        if (!atom.label) {{\n",
    "                            atom.label = viewer.addLabel(\n",
    "                                atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
    "                                {{\n",
    "                                    position: atom, \n",
    "                                    backgroundColor: 'mintcream', \n",
    "                                    fontColor: 'black',\n",
    "                                    fontSize: 12,\n",
    "                                    padding: 2\n",
    "                                }}\n",
    "                            );\n",
    "                        }}\n",
    "                    }},\n",
    "                    function(atom, viewer) {{\n",
    "                        if (atom.label) {{\n",
    "                            viewer.removeLabel(atom.label);\n",
    "                            delete atom.label;\n",
    "                        }}\n",
    "                    }}\n",
    "                );\n",
    "                \n",
    "                viewer.zoomTo();\n",
    "                viewer.render();\n",
    "                viewer.zoom(0.8, 2000);\n",
    "            }});\n",
    "        </script>\n",
    "    </body>\n",
    "    </html>\n",
    "    \"\"\"\n",
    "    \n",
    "    # Return the HTML content within an iframe safely encoded for special characters\n",
    "    return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
    "\n",
    "# Gradio UI\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# Protein Binding Site Prediction\")\n",
    "    \n",
    "    with gr.Row():\n",
    "        pdb_input = gr.Textbox(value=\"4BDU\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
    "        visualize_btn = gr.Button(\"Visualize Structure\")\n",
    "\n",
    "    molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n",
    "        {\n",
    "            \"model\": 0,\n",
    "            \"style\": \"cartoon\",\n",
    "            \"color\": \"whiteCarbon\",\n",
    "            \"residue_range\": \"\",\n",
    "            \"around\": 0,\n",
    "            \"byres\": False,\n",
    "        }\n",
    "    ])\n",
    "\n",
    "    with gr.Row():\n",
    "        segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
    "        prediction_btn = gr.Button(\"Predict Binding Site\")\n",
    "\n",
    "    molecule_output = gr.HTML(label=\"Protein Structure\")\n",
    "    predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
    "    download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n",
    "    \n",
    "    prediction_btn.click(\n",
    "        process_pdb, \n",
    "        inputs=[\n",
    "            pdb_input, \n",
    "            segment_input\n",
    "        ], \n",
    "        outputs=[predictions_output, molecule_output, download_output]\n",
    "    )\n",
    "\n",
    "    visualize_btn.click(\n",
    "        fetch_pdb, \n",
    "        inputs=[pdb_input], \n",
    "        outputs=molecule_output2\n",
    "    )\n",
    "\n",
    "    gr.Markdown(\"## Examples\")\n",
    "    gr.Examples(\n",
    "        examples=[\n",
    "            [\"7RPZ\", \"A\"],\n",
    "            [\"2IWI\", \"B\"],\n",
    "            [\"2F6V\", \"A\"]\n",
    "        ],\n",
    "        inputs=[pdb_input, segment_input],\n",
    "        outputs=[predictions_output, molecule_output, download_output]\n",
    "    )\n",
    "\n",
    "demo.launch(share=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "a492c5c5-e0aa-4445-9375-64cfdb963e04",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Running on local URL:  http://127.0.0.1:7891\n",
      "* Running on public URL: https://339346b4ad32f608d0.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://339346b4ad32f608d0.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datetime import datetime\n",
    "import gradio as gr\n",
    "import requests\n",
    "from Bio.PDB import PDBParser, MMCIFParser, PDBIO\n",
    "from Bio.PDB.Polypeptide import is_aa\n",
    "from Bio.SeqUtils import seq1\n",
    "from Bio.PDB import Select\n",
    "from typing import Optional, Tuple\n",
    "import numpy as np\n",
    "import os\n",
    "from gradio_molecule3d import Molecule3D\n",
    "\n",
    "import re\n",
    "import pandas as pd\n",
    "import copy\n",
    "\n",
    "from scipy.special import expit\n",
    "\n",
    "def normalize_scores(scores):\n",
    "    min_score = np.min(scores)\n",
    "    max_score = np.max(scores)\n",
    "    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
    "\n",
    "def read_mol(pdb_path):\n",
    "    \"\"\"Read PDB file and return its content as a string\"\"\"\n",
    "    with open(pdb_path, 'r') as f:\n",
    "        return f.read()\n",
    "\n",
    "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n",
    "    If a structure file already exists locally, it uses that.\n",
    "    \"\"\"\n",
    "    file_path = download_structure(pdb_id, output_dir)\n",
    "    if file_path:\n",
    "        return file_path\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Attempt to download the structure file in CIF or PDB format.\n",
    "    Returns the path to the downloaded file, or None if download fails.\n",
    "    \"\"\"\n",
    "    for ext in ['.cif', '.pdb']:\n",
    "        file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n",
    "        if os.path.exists(file_path):\n",
    "            return file_path\n",
    "        url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n",
    "        try:\n",
    "            response = requests.get(url, timeout=10)\n",
    "            if response.status_code == 200:\n",
    "                with open(file_path, 'wb') as f:\n",
    "                    f.write(response.content)\n",
    "                return file_path\n",
    "        except Exception as e:\n",
    "            print(f\"Download error for {pdb_id}{ext}: {e}\")\n",
    "    return None\n",
    "\n",
    "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n",
    "    \"\"\"\n",
    "    Convert a CIF file to PDB format using BioPython and return the PDB file path.\n",
    "    \"\"\"\n",
    "    pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n",
    "    parser = MMCIFParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', cif_path)\n",
    "    io = PDBIO()\n",
    "    io.set_structure(structure)\n",
    "    io.save(pdb_path)\n",
    "    return pdb_path\n",
    "\n",
    "def fetch_pdb(pdb_id):\n",
    "    pdb_path = fetch_structure(pdb_id)\n",
    "    if not pdb_path:\n",
    "        return None\n",
    "    _, ext = os.path.splitext(pdb_path)\n",
    "    if ext == '.cif':\n",
    "        pdb_path = convert_cif_to_pdb(pdb_path)\n",
    "    return pdb_path\n",
    "\n",
    "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n",
    "    \"\"\"\n",
    "    Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores\n",
    "    \"\"\"\n",
    "    # Read the original PDB file\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', input_pdb)\n",
    "    \n",
    "    # Prepare a new structure with only the specified chain and selected residues\n",
    "    output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb\"\n",
    "    \n",
    "    # Create scores dictionary for easy lookup\n",
    "    scores_dict = {resi: score for resi, score in residue_scores}\n",
    "\n",
    "    # Create a custom Select class\n",
    "    class ResidueSelector(Select):\n",
    "        def __init__(self, chain_id, selected_residues, scores_dict):\n",
    "            self.chain_id = chain_id\n",
    "            self.selected_residues = selected_residues\n",
    "            self.scores_dict = scores_dict\n",
    "        \n",
    "        def accept_chain(self, chain):\n",
    "            return chain.id == self.chain_id\n",
    "        \n",
    "        def accept_residue(self, residue):\n",
    "            return residue.id[1] in self.selected_residues\n",
    "\n",
    "        def accept_atom(self, atom):\n",
    "            if atom.parent.id[1] in self.scores_dict:\n",
    "                atom.bfactor = self.scores_dict[atom.parent.id[1]] * 100\n",
    "            return True\n",
    "\n",
    "    # Prepare output PDB with selected chain and residues, modified B-factors\n",
    "    io = PDBIO()\n",
    "    selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n",
    "    \n",
    "    io.set_structure(structure[0])\n",
    "    io.save(output_pdb, selector)\n",
    "    \n",
    "    return output_pdb\n",
    "\n",
    "def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):\n",
    "    \"\"\"\n",
    "    Calculate the geometric center of high-scoring residues\n",
    "    \"\"\"\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', pdb_path)\n",
    "    \n",
    "    # Collect coordinates of CA atoms from high-scoring residues\n",
    "    coords = []\n",
    "    for model in structure:\n",
    "        for chain in model:\n",
    "            if chain.id == chain_id:\n",
    "                for residue in chain:\n",
    "                    if residue.id[1] in high_score_residues:\n",
    "                        if 'CA' in residue:  # Use alpha carbon as representative\n",
    "                            ca_atom = residue['CA']\n",
    "                            coords.append(ca_atom.coord)\n",
    "    \n",
    "    # Calculate geometric center\n",
    "    if coords:\n",
    "        center = np.mean(coords, axis=0)\n",
    "        return center\n",
    "    return None\n",
    "\n",
    "def process_pdb(pdb_id_or_file, segment):\n",
    "    # Determine if input is a PDB ID or file path\n",
    "    if pdb_id_or_file.endswith('.pdb'):\n",
    "        pdb_path = pdb_id_or_file\n",
    "        pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n",
    "    else:\n",
    "        pdb_id = pdb_id_or_file\n",
    "        pdb_path = fetch_pdb(pdb_id)\n",
    "    \n",
    "    if not pdb_path:\n",
    "        return \"Failed to fetch PDB file\", None, None\n",
    "    \n",
    "    # Determine the file format and choose the appropriate parser\n",
    "    _, ext = os.path.splitext(pdb_path)\n",
    "    parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n",
    "    \n",
    "    try:\n",
    "        # Parse the structure file\n",
    "        structure = parser.get_structure('protein', pdb_path)\n",
    "    except Exception as e:\n",
    "        return f\"Error parsing structure file: {e}\", None, None\n",
    "    \n",
    "    # Extract the specified chain\n",
    "    try:\n",
    "        chain = structure[0][segment]\n",
    "    except KeyError:\n",
    "        return \"Invalid Chain ID\", None, None\n",
    "    \n",
    "    protein_residues = [res for res in chain if is_aa(res)]\n",
    "    sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n",
    "    sequence_id = [res.id[1] for res in protein_residues]\n",
    "    \n",
    "    scores = np.random.rand(len(sequence))\n",
    "    normalized_scores = normalize_scores(scores)\n",
    "    \n",
    "    # Zip residues with scores to track the residue ID and score\n",
    "    residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n",
    "\n",
    "    \n",
    "    # Identify high scoring residues (> 0.5)\n",
    "    high_score_residues = [resi for resi, score in residue_scores if score > 0.5]\n",
    "    \n",
    "    # Preparing the result: only print high scoring residues\n",
    "    current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
    "    result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n",
    "    result_str += \"High-scoring Residues (Score > 0.5):\\n\"\n",
    "    result_str += \"Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\\n\\n\"\n",
    "    result_str += \"\\n\".join([\n",
    "        f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
    "        for i, res in enumerate(protein_residues) if res.id[1] in high_score_residues\n",
    "    ])\n",
    "\n",
    "    # Create chain-specific PDB with scores in B-factor\n",
    "    scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n",
    "\n",
    "    # Molecule visualization with updated script with color mapping\n",
    "    mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)\n",
    "\n",
    "    # Improved PyMOL command suggestions\n",
    "    current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
    "    pymol_commands = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n",
    "    \n",
    "    pymol_commands += f\"\"\"\n",
    "# PyMOL Visualization Commands\n",
    "load {os.path.abspath(pdb_path)}, protein\n",
    "hide everything, all\n",
    "show cartoon, chain {segment}\n",
    "color white, chain {segment}\n",
    "\"\"\"\n",
    "    \n",
    "    # Color specific residues\n",
    "    for score_range, color in [\n",
    "        (high_score_residues, \"red\")\n",
    "    ]:\n",
    "        if score_range:\n",
    "            resi_list = '+'.join(map(str, score_range))\n",
    "            pymol_commands += f\"\"\"\n",
    "select high_score_residues, resi {resi_list} and chain {segment}\n",
    "show sticks, high_score_residues\n",
    "color {color}, high_score_residues\n",
    "\"\"\"\n",
    "    # Create prediction and scored PDB files\n",
    "    prediction_file = f\"{pdb_id}_binding_site_residues.txt\"\n",
    "    with open(prediction_file, \"w\") as f:\n",
    "        f.write(result_str)\n",
    "    \n",
    "    return pymol_commands, mol_vis, [prediction_file,scored_pdb]\n",
    "\n",
    "def molecule(input_pdb, residue_scores=None, segment='A'):\n",
    "    # More granular scoring for visualization\n",
    "    mol = read_mol(input_pdb)  # Read PDB file content\n",
    "\n",
    "    # Prepare high-scoring residues script if scores are provided\n",
    "    high_score_script = \"\"\n",
    "    if residue_scores is not None:\n",
    "        # Filter residues based on their scores\n",
    "        class1_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.6]\n",
    "        class2_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.7]\n",
    "        class3_score_residues = [resi for resi, score in residue_scores if 0.7 < score <= 0.8]\n",
    "        class4_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 0.9]\n",
    "        class5_score_residues = [resi for resi, score in residue_scores if 0.9 < score <= 1.0]\n",
    "        \n",
    "        high_score_script = \"\"\"\n",
    "        // Load the original model and apply white cartoon style\n",
    "        let chainModel = viewer.addModel(pdb, \"pdb\");\n",
    "        chainModel.setStyle({}, {});\n",
    "        chainModel.setStyle(\n",
    "            {\"chain\": \"%s\"}, \n",
    "            {\"cartoon\": {\"color\": \"white\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class1Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class1Model.setStyle({}, {});\n",
    "        class1Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"blue\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class2Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class2Model.setStyle({}, {});\n",
    "        class2Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"lightblue\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class3Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class3Model.setStyle({}, {});\n",
    "        class3Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"white\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class4Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class4Model.setStyle({}, {});\n",
    "        class4Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"orange\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class5Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class5Model.setStyle({}, {});\n",
    "        class5Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"red\"}}\n",
    "        );\n",
    "\n",
    "        \"\"\" % (\n",
    "            segment,\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class1_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class2_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class3_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class4_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class5_score_residues)\n",
    "        )\n",
    "    \n",
    "    # Generate the full HTML content\n",
    "    html_content = f\"\"\"\n",
    "    <!DOCTYPE html>\n",
    "    <html>\n",
    "    <head>    \n",
    "        <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
    "        <style>\n",
    "        .mol-container {{\n",
    "            width: 100%;\n",
    "            height: 700px;\n",
    "            position: relative;\n",
    "        }}\n",
    "        </style>\n",
    "        <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
    "        <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
    "    </head>\n",
    "    <body>\n",
    "        <div id=\"container\" class=\"mol-container\"></div>\n",
    "        <script>\n",
    "            let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
    "            $(document).ready(function () {{\n",
    "                let element = $(\"#container\");\n",
    "                let config = {{ backgroundColor: \"white\" }};\n",
    "                let viewer = $3Dmol.createViewer(element, config);\n",
    "                \n",
    "                {high_score_script}\n",
    "                \n",
    "                // Add hover functionality\n",
    "                viewer.setHoverable(\n",
    "                    {{}}, \n",
    "                    true, \n",
    "                    function(atom, viewer, event, container) {{\n",
    "                        if (!atom.label) {{\n",
    "                            atom.label = viewer.addLabel(\n",
    "                                atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
    "                                {{\n",
    "                                    position: atom, \n",
    "                                    backgroundColor: 'mintcream', \n",
    "                                    fontColor: 'black',\n",
    "                                    fontSize: 12,\n",
    "                                    padding: 2\n",
    "                                }}\n",
    "                            );\n",
    "                        }}\n",
    "                    }},\n",
    "                    function(atom, viewer) {{\n",
    "                        if (atom.label) {{\n",
    "                            viewer.removeLabel(atom.label);\n",
    "                            delete atom.label;\n",
    "                        }}\n",
    "                    }}\n",
    "                );\n",
    "                \n",
    "                viewer.zoomTo();\n",
    "                viewer.render();\n",
    "                viewer.zoom(0.8, 2000);\n",
    "            }});\n",
    "        </script>\n",
    "    </body>\n",
    "    </html>\n",
    "    \"\"\"\n",
    "    \n",
    "    # Return the HTML content within an iframe safely encoded for special characters\n",
    "    return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
    "\n",
    "# Gradio UI\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# Protein Binding Site Prediction\")\n",
    "    \n",
    "    with gr.Row():\n",
    "        pdb_input = gr.Textbox(value=\"4BDU\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
    "        visualize_btn = gr.Button(\"Visualize Structure\")\n",
    "\n",
    "    molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n",
    "        {\n",
    "            \"model\": 0,\n",
    "            \"style\": \"cartoon\",\n",
    "            \"color\": \"whiteCarbon\",\n",
    "            \"residue_range\": \"\",\n",
    "            \"around\": 0,\n",
    "            \"byres\": False,\n",
    "        }\n",
    "    ])\n",
    "\n",
    "    with gr.Row():\n",
    "        segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
    "        prediction_btn = gr.Button(\"Predict Binding Site\")\n",
    "\n",
    "    molecule_output = gr.HTML(label=\"Protein Structure\")\n",
    "    explanation_vis = gr.Markdown(\"\"\"\n",
    "    Residues with a score > 0.5 are considered binding sites and represented as sticks with the score dependent colorcoding:\n",
    "    - 0.5-0.6: blue  \n",
    "    - 0.6–0.7: light blue  \n",
    "    - 0.7–0.8: white\n",
    "    - 0.8–0.9: orange\n",
    "    - 0.9–1.0: red\n",
    "    \"\"\")\n",
    "    predictions_output = gr.Textbox(label=\"Visualize Prediction with PyMol\")\n",
    "    gr.Markdown(\"### Download:\\n- List of predicted binding site residues\\n- PDB with score in beta factor column\")\n",
    "    download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n",
    "    \n",
    "    prediction_btn.click(\n",
    "        process_pdb, \n",
    "        inputs=[\n",
    "            pdb_input, \n",
    "            segment_input\n",
    "        ], \n",
    "        outputs=[predictions_output, molecule_output, download_output]\n",
    "    )\n",
    "\n",
    "    visualize_btn.click(\n",
    "        fetch_pdb, \n",
    "        inputs=[pdb_input], \n",
    "        outputs=molecule_output2\n",
    "    )\n",
    "\n",
    "    gr.Markdown(\"## Examples\")\n",
    "    gr.Examples(\n",
    "        examples=[\n",
    "            [\"7RPZ\", \"A\"],\n",
    "            [\"2IWI\", \"B\"],\n",
    "            [\"2F6V\", \"A\"]\n",
    "        ],\n",
    "        inputs=[pdb_input, segment_input],\n",
    "        outputs=[predictions_output, molecule_output, download_output]\n",
    "    )\n",
    "\n",
    "demo.launch(share=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "99d18e7c-3ec1-48f2-b368-958b66bb1782",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Running on local URL:  http://127.0.0.1:7923\n",
      "* Running on public URL: https://ad5916147a5fd9c4b5.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://ad5916147a5fd9c4b5.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".cif\n",
      "./7RPZ.pdb\n",
      ".pdb\n",
      "/private/var/folders/tm/ym2tckv54b96ws82y3b7cqhh0000gn/T/gradio/6b7bd3b706f978096c02bacdbf7b38529f0a5233f7570f758063b6e78f62771d/2F6V.pdb\n"
     ]
    }
   ],
   "source": [
    "from datetime import datetime\n",
    "import gradio as gr\n",
    "import requests\n",
    "from Bio.PDB import PDBParser, MMCIFParser, PDBIO\n",
    "from Bio.PDB.Polypeptide import is_aa\n",
    "from Bio.SeqUtils import seq1\n",
    "from Bio.PDB import Select\n",
    "from typing import Optional, Tuple\n",
    "import numpy as np\n",
    "import os\n",
    "from gradio_molecule3d import Molecule3D\n",
    "\n",
    "import re\n",
    "import pandas as pd\n",
    "import copy\n",
    "\n",
    "from scipy.special import expit\n",
    "\n",
    "def normalize_scores(scores):\n",
    "    min_score = np.min(scores)\n",
    "    max_score = np.max(scores)\n",
    "    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
    "\n",
    "def read_mol(pdb_path):\n",
    "    \"\"\"Read PDB file and return its content as a string\"\"\"\n",
    "    with open(pdb_path, 'r') as f:\n",
    "        return f.read()\n",
    "\n",
    "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n",
    "    If a structure file already exists locally, it uses that.\n",
    "    \"\"\"\n",
    "    file_path = download_structure(pdb_id, output_dir)\n",
    "    if file_path:\n",
    "        return file_path\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Attempt to download the structure file in CIF or PDB format.\n",
    "    Returns the path to the downloaded file, or None if download fails.\n",
    "    \"\"\"\n",
    "    for ext in ['.cif', '.pdb']:\n",
    "        file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n",
    "        if os.path.exists(file_path):\n",
    "            return file_path\n",
    "        url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n",
    "        try:\n",
    "            response = requests.get(url, timeout=10)\n",
    "            if response.status_code == 200:\n",
    "                with open(file_path, 'wb') as f:\n",
    "                    f.write(response.content)\n",
    "                return file_path\n",
    "        except Exception as e:\n",
    "            print(f\"Download error for {pdb_id}{ext}: {e}\")\n",
    "    return None\n",
    "\n",
    "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n",
    "    \"\"\"\n",
    "    Convert a CIF file to PDB format using BioPython and return the PDB file path.\n",
    "    \"\"\"\n",
    "    pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n",
    "    parser = MMCIFParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', cif_path)\n",
    "    io = PDBIO()\n",
    "    io.set_structure(structure)\n",
    "    io.save(pdb_path)\n",
    "    return pdb_path\n",
    "\n",
    "def fetch_pdb(pdb_id):\n",
    "    pdb_path = fetch_structure(pdb_id)\n",
    "    if not pdb_path:\n",
    "        return None\n",
    "    _, ext = os.path.splitext(pdb_path)\n",
    "    if ext == '.cif':\n",
    "        pdb_path = convert_cif_to_pdb(pdb_path)\n",
    "    return pdb_path\n",
    "\n",
    "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n",
    "    \"\"\"\n",
    "    Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores\n",
    "    \"\"\"\n",
    "    # Read the original PDB file\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', input_pdb)\n",
    "    \n",
    "    # Prepare a new structure with only the specified chain and selected residues\n",
    "    output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb\"\n",
    "    \n",
    "    # Create scores dictionary for easy lookup\n",
    "    scores_dict = {resi: score for resi, score in residue_scores}\n",
    "\n",
    "    # Create a custom Select class\n",
    "    class ResidueSelector(Select):\n",
    "        def __init__(self, chain_id, selected_residues, scores_dict):\n",
    "            self.chain_id = chain_id\n",
    "            self.selected_residues = selected_residues\n",
    "            self.scores_dict = scores_dict\n",
    "        \n",
    "        def accept_chain(self, chain):\n",
    "            return chain.id == self.chain_id\n",
    "        \n",
    "        def accept_residue(self, residue):\n",
    "            return residue.id[1] in self.selected_residues\n",
    "\n",
    "        def accept_atom(self, atom):\n",
    "            if atom.parent.id[1] in self.scores_dict:\n",
    "                atom.bfactor = self.scores_dict[atom.parent.id[1]] * 100\n",
    "            return True\n",
    "\n",
    "    # Prepare output PDB with selected chain and residues, modified B-factors\n",
    "    io = PDBIO()\n",
    "    selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n",
    "    \n",
    "    io.set_structure(structure[0])\n",
    "    io.save(output_pdb, selector)\n",
    "    \n",
    "    return output_pdb\n",
    "\n",
    "def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):\n",
    "    \"\"\"\n",
    "    Calculate the geometric center of high-scoring residues\n",
    "    \"\"\"\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    structure = parser.get_structure('protein', pdb_path)\n",
    "    \n",
    "    # Collect coordinates of CA atoms from high-scoring residues\n",
    "    coords = []\n",
    "    for model in structure:\n",
    "        for chain in model:\n",
    "            if chain.id == chain_id:\n",
    "                for residue in chain:\n",
    "                    if residue.id[1] in high_score_residues:\n",
    "                        if 'CA' in residue:  # Use alpha carbon as representative\n",
    "                            ca_atom = residue['CA']\n",
    "                            coords.append(ca_atom.coord)\n",
    "    \n",
    "    # Calculate geometric center\n",
    "    if coords:\n",
    "        center = np.mean(coords, axis=0)\n",
    "        return center\n",
    "    return None\n",
    "\n",
    "def process_pdb(pdb_id_or_file, segment):\n",
    "    # Determine if input is a PDB ID or file path\n",
    "    if pdb_id_or_file.endswith('.pdb'):\n",
    "        pdb_path = pdb_id_or_file\n",
    "        pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n",
    "    else:\n",
    "        pdb_id = pdb_id_or_file\n",
    "        pdb_path = fetch_pdb(pdb_id)\n",
    "    \n",
    "    if not pdb_path:\n",
    "        return \"Failed to fetch PDB file\", None, None\n",
    "    \n",
    "    # Determine the file format and choose the appropriate parser\n",
    "    _, ext = os.path.splitext(pdb_path)\n",
    "    parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n",
    "    \n",
    "    try:\n",
    "        # Parse the structure file\n",
    "        structure = parser.get_structure('protein', pdb_path)\n",
    "    except Exception as e:\n",
    "        return f\"Error parsing structure file: {e}\", None, None\n",
    "    \n",
    "    # Extract the specified chain\n",
    "    try:\n",
    "        chain = structure[0][segment]\n",
    "    except KeyError:\n",
    "        return \"Invalid Chain ID\", None, None\n",
    "    \n",
    "    protein_residues = [res for res in chain if is_aa(res)]\n",
    "    sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n",
    "    sequence_id = [res.id[1] for res in protein_residues]\n",
    "    \n",
    "    scores = np.random.rand(len(sequence))\n",
    "    normalized_scores = normalize_scores(scores)\n",
    "    \n",
    "    # Zip residues with scores to track the residue ID and score\n",
    "    residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n",
    "\n",
    "    \n",
    "    # Identify high scoring residues (> 0.5)\n",
    "    high_score_residues = [resi for resi, score in residue_scores if score > 0.5]\n",
    "    \n",
    "    # Preparing the result: only print high scoring residues\n",
    "    current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
    "    result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n",
    "    result_str += \"High-scoring Residues (Score > 0.5):\\n\"\n",
    "    result_str += \"Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\\n\\n\"\n",
    "    result_str += \"\\n\".join([\n",
    "        f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
    "        for i, res in enumerate(protein_residues) if res.id[1] in high_score_residues\n",
    "    ])\n",
    "\n",
    "    # Create chain-specific PDB with scores in B-factor\n",
    "    scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n",
    "\n",
    "    # Molecule visualization with updated script with color mapping\n",
    "    mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)\n",
    "\n",
    "    # Improved PyMOL command suggestions\n",
    "    current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
    "    pymol_commands = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n",
    "    \n",
    "    pymol_commands += f\"\"\"\n",
    "# PyMOL Visualization Commands\n",
    "load {os.path.abspath(pdb_path)}, protein\n",
    "hide everything, all\n",
    "show cartoon, chain {segment}\n",
    "color white, chain {segment}\n",
    "\"\"\"\n",
    "    \n",
    "    # Color specific residues\n",
    "    for score_range, color in [\n",
    "        (high_score_residues, \"red\")\n",
    "    ]:\n",
    "        if score_range:\n",
    "            resi_list = '+'.join(map(str, score_range))\n",
    "            pymol_commands += f\"\"\"\n",
    "select high_score_residues, resi {resi_list} and chain {segment}\n",
    "show sticks, high_score_residues\n",
    "color {color}, high_score_residues\n",
    "\"\"\"\n",
    "    # Create prediction and scored PDB files\n",
    "    prediction_file = f\"{pdb_id}_binding_site_residues.txt\"\n",
    "    with open(prediction_file, \"w\") as f:\n",
    "        f.write(result_str)\n",
    "    \n",
    "    return pymol_commands, mol_vis, [prediction_file,scored_pdb]\n",
    "\n",
    "def molecule(input_pdb, residue_scores=None, segment='A'):\n",
    "    # More granular scoring for visualization\n",
    "    mol = read_mol(input_pdb)  # Read PDB file content\n",
    "\n",
    "    # Prepare high-scoring residues script if scores are provided\n",
    "    high_score_script = \"\"\n",
    "    if residue_scores is not None:\n",
    "        # Filter residues based on their scores\n",
    "        class1_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.6]\n",
    "        class2_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.7]\n",
    "        class3_score_residues = [resi for resi, score in residue_scores if 0.7 < score <= 0.8]\n",
    "        class4_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 0.9]\n",
    "        class5_score_residues = [resi for resi, score in residue_scores if 0.9 < score <= 1.0]\n",
    "        \n",
    "        high_score_script = \"\"\"\n",
    "        // Load the original model and apply white cartoon style\n",
    "        let chainModel = viewer.addModel(pdb, \"pdb\");\n",
    "        chainModel.setStyle({}, {});\n",
    "        chainModel.setStyle(\n",
    "            {\"chain\": \"%s\"}, \n",
    "            {\"cartoon\": {\"color\": \"white\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class1Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class1Model.setStyle({}, {});\n",
    "        class1Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"blue\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class2Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class2Model.setStyle({}, {});\n",
    "        class2Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"lightblue\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class3Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class3Model.setStyle({}, {});\n",
    "        class3Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"white\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class4Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class4Model.setStyle({}, {});\n",
    "        class4Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"orange\"}}\n",
    "        );\n",
    "\n",
    "        // Create a new model for high-scoring residues and apply red sticks style\n",
    "        let class5Model = viewer.addModel(pdb, \"pdb\");\n",
    "        class5Model.setStyle({}, {});\n",
    "        class5Model.setStyle(\n",
    "            {\"chain\": \"%s\", \"resi\": [%s]}, \n",
    "            {\"stick\": {\"color\": \"red\"}}\n",
    "        );\n",
    "\n",
    "        \"\"\" % (\n",
    "            segment,\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class1_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class2_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class3_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class4_score_residues),\n",
    "            segment,\n",
    "            \", \".join(str(resi) for resi in class5_score_residues)\n",
    "        )\n",
    "    \n",
    "    # Generate the full HTML content\n",
    "    html_content = f\"\"\"\n",
    "    <!DOCTYPE html>\n",
    "    <html>\n",
    "    <head>    \n",
    "        <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
    "        <style>\n",
    "        .mol-container {{\n",
    "            width: 100%;\n",
    "            height: 700px;\n",
    "            position: relative;\n",
    "        }}\n",
    "        </style>\n",
    "        <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
    "        <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
    "    </head>\n",
    "    <body>\n",
    "        <div id=\"container\" class=\"mol-container\"></div>\n",
    "        <script>\n",
    "            let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
    "            $(document).ready(function () {{\n",
    "                let element = $(\"#container\");\n",
    "                let config = {{ backgroundColor: \"white\" }};\n",
    "                let viewer = $3Dmol.createViewer(element, config);\n",
    "                \n",
    "                {high_score_script}\n",
    "                \n",
    "                // Add hover functionality\n",
    "                viewer.setHoverable(\n",
    "                    {{}}, \n",
    "                    true, \n",
    "                    function(atom, viewer, event, container) {{\n",
    "                        if (!atom.label) {{\n",
    "                            atom.label = viewer.addLabel(\n",
    "                                atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
    "                                {{\n",
    "                                    position: atom, \n",
    "                                    backgroundColor: 'mintcream', \n",
    "                                    fontColor: 'black',\n",
    "                                    fontSize: 12,\n",
    "                                    padding: 2\n",
    "                                }}\n",
    "                            );\n",
    "                        }}\n",
    "                    }},\n",
    "                    function(atom, viewer) {{\n",
    "                        if (atom.label) {{\n",
    "                            viewer.removeLabel(atom.label);\n",
    "                            delete atom.label;\n",
    "                        }}\n",
    "                    }}\n",
    "                );\n",
    "                \n",
    "                viewer.zoomTo();\n",
    "                viewer.render();\n",
    "                viewer.zoom(0.8, 2000);\n",
    "            }});\n",
    "        </script>\n",
    "    </body>\n",
    "    </html>\n",
    "    \"\"\"\n",
    "    \n",
    "    # Return the HTML content within an iframe safely encoded for special characters\n",
    "    return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
    "\n",
    "# Gradio UI\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# Protein Binding Site Prediction\")\n",
    "    \n",
    "    # Mode selection\n",
    "    mode = gr.Radio(\n",
    "        choices=[\"PDB ID\", \"Upload File\"],\n",
    "        value=\"PDB ID\",\n",
    "        label=\"Input Mode\",\n",
    "        info=\"Choose whether to input a PDB ID or upload a PDB/CIF file.\"\n",
    "    )\n",
    "\n",
    "    # Input components based on mode\n",
    "    pdb_input = gr.Textbox(value=\"4BDU\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
    "    pdb_file = gr.File(label=\"Upload PDB/CIF File\", visible=False)\n",
    "    visualize_btn = gr.Button(\"Visualize Structure\")\n",
    "\n",
    "    molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n",
    "        {\n",
    "            \"model\": 0,\n",
    "            \"style\": \"cartoon\",\n",
    "            \"color\": \"whiteCarbon\",\n",
    "            \"residue_range\": \"\",\n",
    "            \"around\": 0,\n",
    "            \"byres\": False,\n",
    "        }\n",
    "    ])\n",
    "\n",
    "    with gr.Row():\n",
    "        segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
    "        prediction_btn = gr.Button(\"Predict Binding Site\")\n",
    "\n",
    "    molecule_output = gr.HTML(label=\"Protein Structure\")\n",
    "    explanation_vis = gr.Markdown(\"\"\"\n",
    "    Residues with a score > 0.5 are considered binding sites and represented as sticks with the score dependent colorcoding:\n",
    "    - 0.5-0.6: blue  \n",
    "    - 0.6–0.7: light blue  \n",
    "    - 0.7–0.8: white\n",
    "    - 0.8–0.9: orange\n",
    "    - 0.9–1.0: red\n",
    "    \"\"\")\n",
    "    predictions_output = gr.Textbox(label=\"Visualize Prediction with PyMol\")\n",
    "    gr.Markdown(\"### Download:\\n- List of predicted binding site residues\\n- PDB with score in beta factor column\")\n",
    "    download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n",
    "    \n",
    "    def process_interface(mode, pdb_id, pdb_file, chain_id):\n",
    "        if mode == \"PDB ID\":\n",
    "            return process_pdb(pdb_id, chain_id)\n",
    "        elif mode == \"Upload File\":\n",
    "            _, ext = os.path.splitext(pdb_file.name)\n",
    "            file_path = os.path.join('./', f\"{_}{ext}\")\n",
    "            if ext == '.cif':\n",
    "                pdb_path = convert_cif_to_pdb(file_path)\n",
    "            else:\n",
    "                pdb_path= file_path\n",
    "            return process_pdb(pdb_path, chain_id)\n",
    "        else:\n",
    "            return \"Error: Invalid mode selected\", None, None\n",
    "\n",
    "    def fetch_interface(mode, pdb_id, pdb_file):\n",
    "        if mode == \"PDB ID\":\n",
    "            return fetch_pdb(pdb_id)\n",
    "        elif mode == \"Upload File\":\n",
    "            _, ext = os.path.splitext(pdb_file.name)\n",
    "            file_path = os.path.join('./', f\"{_}{ext}\")\n",
    "            #print(ext)\n",
    "            if ext == '.cif':\n",
    "                pdb_path = convert_cif_to_pdb(file_path)\n",
    "            else:\n",
    "                pdb_path= file_path\n",
    "            #print(pdb_path)\n",
    "            return pdb_path\n",
    "        else:\n",
    "            return \"Error: Invalid mode selected\"\n",
    "\n",
    "    def toggle_mode(selected_mode):\n",
    "        if selected_mode == \"PDB ID\":\n",
    "            return gr.update(visible=True), gr.update(visible=False)\n",
    "        else:\n",
    "            return gr.update(visible=False), gr.update(visible=True)\n",
    "\n",
    "    mode.change(\n",
    "        toggle_mode,\n",
    "        inputs=[mode],\n",
    "        outputs=[pdb_input, pdb_file]\n",
    "    )\n",
    "\n",
    "    prediction_btn.click(\n",
    "        process_interface, \n",
    "        inputs=[mode, pdb_input, pdb_file, segment_input], \n",
    "        outputs=[predictions_output, molecule_output, download_output]\n",
    "    )\n",
    "\n",
    "    visualize_btn.click(\n",
    "        fetch_interface, \n",
    "        inputs=[mode, pdb_input, pdb_file], \n",
    "        outputs=molecule_output2\n",
    "    )\n",
    "\n",
    "    gr.Markdown(\"## Examples\")\n",
    "    gr.Examples(\n",
    "        examples=[\n",
    "            [\"7RPZ\", \"A\"],\n",
    "            [\"2IWI\", \"B\"],\n",
    "            [\"2F6V\", \"A\"]\n",
    "        ],\n",
    "        inputs=[pdb_input, segment_input],\n",
    "        outputs=[predictions_output, molecule_output, download_output]\n",
    "    )\n",
    "\n",
    "demo.launch(share=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9496cf4f-9e5f-4b0b-bb0d-7aebbb748ae6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (LLM)",
   "language": "python",
   "name": "llm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}