{ "cells": [ { "cell_type": "code", "execution_count": 29, "id": "e776d9d6-417e-46d4-8061-846c055e1f8a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "* Running on local URL: http://127.0.0.1:7873\n", "* Running on public URL: https://120000a6aa9d78e04c.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datetime import datetime\n", "import gradio as gr\n", "import requests\n", "from Bio.PDB import PDBParser, MMCIFParser, PDBIO, Select\n", "from Bio.PDB.Polypeptide import is_aa\n", "from Bio.SeqUtils import seq1\n", "from typing import Optional, Tuple\n", "import numpy as np\n", "import os\n", "from gradio_molecule3d import Molecule3D\n", "\n", "#from model_loader import load_model\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "\n", "import re\n", "import pandas as pd\n", "import copy\n", "\n", "#import transformers\n", "#from transformers import AutoTokenizer, DataCollatorForTokenClassification\n", "\n", "#from datasets import Dataset\n", "\n", "from scipy.special import expit\n", "\n", "def normalize_scores(scores):\n", " min_score = np.min(scores)\n", " max_score = np.max(scores)\n", " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n", "\n", "def read_mol(pdb_path):\n", " \"\"\"Read PDB file and return its content as a string\"\"\"\n", " with open(pdb_path, 'r') as f:\n", " return f.read()\n", "\n", "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n", " \"\"\"\n", " Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n", " If a structure file already exists locally, it uses that.\n", " \"\"\"\n", " file_path = download_structure(pdb_id, output_dir)\n", " if file_path:\n", " return file_path\n", " else:\n", " return None\n", "\n", "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n", " \"\"\"\n", " Attempt to download the structure file in CIF or PDB format.\n", " Returns the path to the downloaded file, or None if download fails.\n", " \"\"\"\n", " for ext in ['.cif', '.pdb']:\n", " file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n", " if os.path.exists(file_path):\n", " return file_path\n", " url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n", " try:\n", " response = requests.get(url, timeout=10)\n", " if response.status_code == 200:\n", " with open(file_path, 'wb') as f:\n", " f.write(response.content)\n", " return file_path\n", " except Exception as e:\n", " print(f\"Download error for {pdb_id}{ext}: {e}\")\n", " return None\n", "\n", "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n", " \"\"\"\n", " Convert a CIF file to PDB format using BioPython and return the PDB file path.\n", " \"\"\"\n", " pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n", " parser = MMCIFParser(QUIET=True)\n", " structure = parser.get_structure('protein', cif_path)\n", " io = PDBIO()\n", " io.set_structure(structure)\n", " io.save(pdb_path)\n", " return pdb_path\n", "\n", "def fetch_pdb(pdb_id):\n", " pdb_path = fetch_structure(pdb_id)\n", " if not pdb_path:\n", " return None\n", " _, ext = os.path.splitext(pdb_path)\n", " if ext == '.cif':\n", " pdb_path = convert_cif_to_pdb(pdb_path)\n", " return pdb_path\n", "\n", "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n", " \"\"\"\n", " Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores\n", " \"\"\"\n", " # Read the original PDB file\n", " parser = PDBParser(QUIET=True)\n", " structure = parser.get_structure('protein', input_pdb)\n", " \n", " # Prepare a new structure with only the specified chain and selected residues\n", " output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb\"\n", " \n", " # Create scores dictionary for easy lookup\n", " scores_dict = {resi: score for resi, score in residue_scores}\n", "\n", " # Create a custom Select class\n", " class ResidueSelector(Select):\n", " def __init__(self, chain_id, selected_residues, scores_dict):\n", " self.chain_id = chain_id\n", " self.selected_residues = selected_residues\n", " self.scores_dict = scores_dict\n", " \n", " def accept_chain(self, chain):\n", " return chain.id == self.chain_id\n", " \n", " def accept_residue(self, residue):\n", " return residue.id[1] in self.selected_residues\n", "\n", " def accept_atom(self, atom):\n", " if atom.parent.id[1] in self.scores_dict:\n", " atom.bfactor = np.absolute(1-self.scores_dict[atom.parent.id[1]]) * 100\n", " return True\n", "\n", " # Prepare output PDB with selected chain and residues, modified B-factors\n", " io = PDBIO()\n", " selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n", " \n", " io.set_structure(structure[0])\n", " io.save(output_pdb, selector)\n", " \n", " return output_pdb\n", "\n", "def process_pdb(pdb_id_or_file, segment):\n", " # Determine if input is a PDB ID or file path\n", " if pdb_id_or_file.endswith('.pdb'):\n", " pdb_path = pdb_id_or_file\n", " pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n", " else:\n", " pdb_id = pdb_id_or_file\n", " pdb_path = fetch_pdb(pdb_id)\n", " \n", " if not pdb_path:\n", " return \"Failed to fetch PDB file\", None, None\n", " \n", " # Determine the file format and choose the appropriate parser\n", " _, ext = os.path.splitext(pdb_path)\n", " parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n", " \n", " try:\n", " # Parse the structure file\n", " structure = parser.get_structure('protein', pdb_path)\n", " except Exception as e:\n", " return f\"Error parsing structure file: {e}\", None, None\n", " \n", " # Extract the specified chain\n", " try:\n", " chain = structure[0][segment]\n", " except KeyError:\n", " return \"Invalid Chain ID\", None, None\n", " \n", " protein_residues = [res for res in chain if is_aa(res)]\n", " sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n", " sequence_id = [res.id[1] for res in protein_residues]\n", "\n", " visualized_sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n", " if sequence != visualized_sequence:\n", " raise ValueError(\"The visualized sequence does not match the prediction sequence\")\n", " \n", " scores = np.random.rand(len(sequence))\n", " normalized_scores = normalize_scores(scores)\n", " \n", " # Zip residues with scores to track the residue ID and score\n", " residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n", "\n", " \n", " # Define the score brackets\n", " score_brackets = {\n", " \"0.0-0.2\": (0.0, 0.2),\n", " \"0.2-0.4\": (0.2, 0.4),\n", " \"0.4-0.6\": (0.4, 0.6),\n", " \"0.6-0.8\": (0.6, 0.8),\n", " \"0.8-1.0\": (0.8, 1.0)\n", " }\n", " \n", " # Initialize a dictionary to store residues by bracket\n", " residues_by_bracket = {bracket: [] for bracket in score_brackets}\n", " \n", " # Categorize residues into brackets\n", " for resi, score in residue_scores:\n", " for bracket, (lower, upper) in score_brackets.items():\n", " if lower <= score < upper:\n", " residues_by_bracket[bracket].append(resi)\n", " break\n", " \n", " # Preparing the result string\n", " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", " result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n", " result_str += \"Residues by Score Brackets:\\n\\n\"\n", " \n", " # Add residues for each bracket\n", " for bracket, residues in residues_by_bracket.items():\n", " result_str += f\"Bracket {bracket}:\\n\"\n", " result_str += \"Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\\n\"\n", " result_str += \"\\n\".join([\n", " f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n", " for i, res in enumerate(protein_residues) if res.id[1] in residues\n", " ])\n", " result_str += \"\\n\\n\"\n", "\n", " # Create chain-specific PDB with scores in B-factor\n", " scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n", "\n", " # Molecule visualization with updated script with color mapping\n", " mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)\n", "\n", " # Improved PyMOL command suggestions\n", " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", " pymol_commands = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n", " \n", " pymol_commands += f\"\"\"\n", " # PyMOL Visualization Commands\n", " load {os.path.abspath(pdb_path)}, protein\n", " hide everything, all\n", " show cartoon, chain {segment}\n", " color white, chain {segment}\n", " \"\"\"\n", " \n", " # Define colors for each score bracket\n", " bracket_colors = {\n", " \"0.0-0.2\": \"white\",\n", " \"0.2-0.4\": \"lightorange\",\n", " \"0.4-0.6\": \"orange\",\n", " \"0.6-0.8\": \"orangered\",\n", " \"0.8-1.0\": \"red\"\n", " }\n", " \n", " # Add PyMOL commands for each score bracket\n", " for bracket, residues in residues_by_bracket.items():\n", " if residues: # Only add commands if there are residues in this bracket\n", " color = bracket_colors[bracket]\n", " resi_list = '+'.join(map(str, residues))\n", " pymol_commands += f\"\"\"\n", " select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}\n", " show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}\n", " color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}\n", " \"\"\"\n", " \n", " # Create prediction and scored PDB files\n", " prediction_file = f\"{pdb_id}_binding_site_residues.txt\"\n", " with open(prediction_file, \"w\") as f:\n", " f.write(result_str)\n", " \n", " return pymol_commands, mol_vis, [prediction_file,scored_pdb]\n", "\n", "def molecule(input_pdb, residue_scores=None, segment='A'):\n", " # More granular scoring for visualization\n", " mol = read_mol(input_pdb) # Read PDB file content\n", "\n", " # Prepare high-scoring residues script if scores are provided\n", " high_score_script = \"\"\n", " if residue_scores is not None:\n", " # Filter residues based on their scores\n", " class1_score_residues = [resi for resi, score in residue_scores if 0.0 < score <= 0.2]\n", " class2_score_residues = [resi for resi, score in residue_scores if 0.2 < score <= 0.4]\n", " class3_score_residues = [resi for resi, score in residue_scores if 0.4 < score <= 0.6]\n", " class4_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.8]\n", " class5_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 1.0]\n", " \n", " high_score_script = \"\"\"\n", " // Load the original model and apply white cartoon style\n", " let chainModel = viewer.addModel(pdb, \"pdb\");\n", " chainModel.setStyle({}, {});\n", " chainModel.setStyle(\n", " {\"chain\": \"%s\"}, \n", " {\"cartoon\": {\"color\": \"white\"}}\n", " );\n", "\n", " // Create a new model for high-scoring residues and apply red sticks style\n", " let class1Model = viewer.addModel(pdb, \"pdb\");\n", " class1Model.setStyle({}, {});\n", " class1Model.setStyle(\n", " {\"chain\": \"%s\", \"resi\": [%s]}, \n", " {\"stick\": {\"color\": \"0xFFFFFF\", \"opacity\": 0.5}}\n", " );\n", "\n", " // Create a new model for high-scoring residues and apply red sticks style\n", " let class2Model = viewer.addModel(pdb, \"pdb\");\n", " class2Model.setStyle({}, {});\n", " class2Model.setStyle(\n", " {\"chain\": \"%s\", \"resi\": [%s]}, \n", " {\"stick\": {\"color\": \"0xFFD580\", \"opacity\": 0.7}}\n", " );\n", "\n", " // Create a new model for high-scoring residues and apply red sticks style\n", " let class3Model = viewer.addModel(pdb, \"pdb\");\n", " class3Model.setStyle({}, {});\n", " class3Model.setStyle(\n", " {\"chain\": \"%s\", \"resi\": [%s]}, \n", " {\"stick\": {\"color\": \"0xFFA500\", \"opacity\": 1}}\n", " );\n", "\n", " // Create a new model for high-scoring residues and apply red sticks style\n", " let class4Model = viewer.addModel(pdb, \"pdb\");\n", " class4Model.setStyle({}, {});\n", " class4Model.setStyle(\n", " {\"chain\": \"%s\", \"resi\": [%s]}, \n", " {\"stick\": {\"color\": \"0xFF4500\", \"opacity\": 1}}\n", " );\n", "\n", " // Create a new model for high-scoring residues and apply red sticks style\n", " let class5Model = viewer.addModel(pdb, \"pdb\");\n", " class5Model.setStyle({}, {});\n", " class5Model.setStyle(\n", " {\"chain\": \"%s\", \"resi\": [%s]}, \n", " {\"stick\": {\"color\": \"0xFF0000\", \"alpha\": 1}}\n", " );\n", "\n", " \"\"\" % (\n", " segment,\n", " segment,\n", " \", \".join(str(resi) for resi in class1_score_residues),\n", " segment,\n", " \", \".join(str(resi) for resi in class2_score_residues),\n", " segment,\n", " \", \".join(str(resi) for resi in class3_score_residues),\n", " segment,\n", " \", \".join(str(resi) for resi in class4_score_residues),\n", " segment,\n", " \", \".join(str(resi) for resi in class5_score_residues)\n", " )\n", " \n", " # Generate the full HTML content\n", " html_content = f\"\"\"\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \"\"\"\n", " \n", " # Return the HTML content within an iframe safely encoded for special characters\n", " return f''\n", "\n", "# Gradio UI\n", "with gr.Blocks(css=\"\"\"\n", " /* Customize Gradio button colors */\n", " #visualize-btn, #predict-btn {\n", " background-color: #FF7300; /* Deep orange */\n", " color: white;\n", " border-radius: 5px;\n", " padding: 10px;\n", " font-weight: bold;\n", " }\n", " #visualize-btn:hover, #predict-btn:hover {\n", " background-color: #CC5C00; /* Darkened orange on hover */\n", " }\n", "\"\"\") as demo:\n", " gr.Markdown(\"# Protein Binding Site Prediction\")\n", " \n", " # Mode selection\n", " mode = gr.Radio(\n", " choices=[\"PDB ID\", \"Upload File\"],\n", " value=\"PDB ID\",\n", " label=\"Input Mode\",\n", " info=\"Choose whether to input a PDB ID or upload a PDB/CIF file.\"\n", " )\n", "\n", " # Input components based on mode\n", " pdb_input = gr.Textbox(value=\"2F6V\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n", " pdb_file = gr.File(label=\"Upload PDB/CIF File\", visible=False)\n", " visualize_btn = gr.Button(\"Visualize Structure\", elem_id=\"visualize-btn\")\n", "\n", " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n", " {\n", " \"model\": 0,\n", " \"style\": \"cartoon\",\n", " \"color\": \"whiteCarbon\",\n", " \"residue_range\": \"\",\n", " \"around\": 0,\n", " \"byres\": False,\n", " }\n", " ])\n", "\n", " with gr.Row():\n", " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID (protein)\", placeholder=\"Enter Chain ID here...\",\n", " info=\"Choose in which chain to predict binding sites.\")\n", " prediction_btn = gr.Button(\"Predict Binding Site\", elem_id=\"predict-btn\")\n", "\n", " molecule_output = gr.HTML(label=\"Protein Structure\")\n", " explanation_vis = gr.Markdown(\"\"\"\n", " Score dependent colorcoding:\n", " - 0.0-0.2: white \n", " - 0.2–0.4: light orange \n", " - 0.4–0.6: orange\n", " - 0.6–0.8: orangered\n", " - 0.8–1.0: red\n", " \"\"\")\n", " predictions_output = gr.Textbox(label=\"Visualize Prediction with PyMol\")\n", " gr.Markdown(\"### Download:\\n- List of predicted binding site residues\\n- PDB with score in beta factor column\")\n", " download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n", " \n", " def process_interface(mode, pdb_id, pdb_file, chain_id):\n", " if mode == \"PDB ID\":\n", " return process_pdb(pdb_id, chain_id)\n", " elif mode == \"Upload File\":\n", " _, ext = os.path.splitext(pdb_file.name)\n", " file_path = os.path.join('./', f\"{_}{ext}\")\n", " if ext == '.cif':\n", " pdb_path = convert_cif_to_pdb(file_path)\n", " else:\n", " pdb_path= file_path\n", " return process_pdb(pdb_path, chain_id)\n", " else:\n", " return \"Error: Invalid mode selected\", None, None\n", "\n", " def fetch_interface(mode, pdb_id, pdb_file):\n", " if mode == \"PDB ID\":\n", " return fetch_pdb(pdb_id)\n", " elif mode == \"Upload File\":\n", " _, ext = os.path.splitext(pdb_file.name)\n", " file_path = os.path.join('./', f\"{_}{ext}\")\n", " #print(ext)\n", " if ext == '.cif':\n", " pdb_path = convert_cif_to_pdb(file_path)\n", " else:\n", " pdb_path= file_path\n", " #print(pdb_path)\n", " return pdb_path\n", " else:\n", " return \"Error: Invalid mode selected\"\n", "\n", " def toggle_mode(selected_mode):\n", " if selected_mode == \"PDB ID\":\n", " return gr.update(visible=True), gr.update(visible=False)\n", " else:\n", " return gr.update(visible=False), gr.update(visible=True)\n", "\n", " mode.change(\n", " toggle_mode,\n", " inputs=[mode],\n", " outputs=[pdb_input, pdb_file]\n", " )\n", "\n", " prediction_btn.click(\n", " process_interface, \n", " inputs=[mode, pdb_input, pdb_file, segment_input], \n", " outputs=[predictions_output, molecule_output, download_output]\n", " )\n", "\n", " visualize_btn.click(\n", " fetch_interface, \n", " inputs=[mode, pdb_input, pdb_file], \n", " outputs=molecule_output2\n", " )\n", "\n", " gr.Markdown(\"## Examples\")\n", " gr.Examples(\n", " examples=[\n", " [\"7RPZ\", \"A\"],\n", " [\"2IWI\", \"B\"],\n", " [\"7LCJ\", \"R\"]\n", " ],\n", " inputs=[pdb_input, segment_input],\n", " outputs=[predictions_output, molecule_output, download_output]\n", " )\n", "\n", "demo.launch(share=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "440c87ed-45c9-4501-b208-409cbfd7858b", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 21, "id": "d70c40b9-5d5a-4795-b2a2-149c4a57d16e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py:441: UserWarning: Examples will be cached but not all input components have example values. This may result in an exception being thrown by your function. If you do get an error while caching examples, make sure all of your inputs have example values for all of your examples or you provide default values for those particular parameters in your function.\n", " warnings.warn(\n", "INFO:__main__:Using cached structure: ./7rpz.cif\n", "INFO:__main__:Using cached structure: ./2iwi.cif\n", "INFO:__main__:Using cached structure: ./2f6v.cif\n", "INFO:httpx:HTTP Request: GET http://127.0.0.1:7862/gradio_api/startup-events \"HTTP/1.1 200 OK\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "* Running on local URL: http://127.0.0.1:7862\n", "Caching examples at: '/home/frohlkin/Projects/LargeLanguageModels/Publication/test_webpage/.gradio/cached_examples/148'\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:httpx:HTTP Request: HEAD http://127.0.0.1:7862/ \"HTTP/1.1 200 OK\"\n", "INFO:httpx:HTTP Request: GET https://api.gradio.app/pkg-version \"HTTP/1.1 200 OK\"\n", "INFO:httpx:HTTP Request: GET https://api.gradio.app/v3/tunnel-request \"HTTP/1.1 200 OK\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "* Running on public URL: https://de785d7cce806497e9.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:httpx:HTTP Request: HEAD https://de785d7cce806497e9.gradio.live \"HTTP/1.1 200 OK\"\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/routes.py\", line 990, in predict\n", " output = await route_utils.call_process_api(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/route_utils.py\", line 322, in call_process_api\n", " output = await app.get_blocks().process_api(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 2047, in process_api\n", " result = await self.call_function(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 1594, in call_function\n", " prediction = await anyio.to_thread.run_sync( # type: ignore\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/to_thread.py\", line 56, in run_sync\n", " return await get_async_backend().run_sync_in_worker_thread(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 2405, in run_sync_in_worker_thread\n", " return await future\n", " ^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 914, in run\n", " result = context.run(func, *args)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/utils.py\", line 869, in wrapper\n", " response = f(*args, **kwargs)\n", " ^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 355, in load_example_with_output\n", " ) + self.load_from_cache(example_id)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 579, in load_from_cache\n", " output.append(component.read_from_flag(value_to_use))\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/components/base.py\", line 366, in read_from_flag\n", " return self.data_model.from_json(json.loads(payload))\n", " ^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/__init__.py\", line 346, in loads\n", " return _default_decoder.decode(s)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 337, in decode\n", " obj, end = self.raw_decode(s, idx=_w(s, 0).end())\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 355, in raw_decode\n", " raise JSONDecodeError(\"Expecting value\", s, err.value) from None\n", "json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)\n", "Traceback (most recent call last):\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/routes.py\", line 990, in predict\n", " output = await route_utils.call_process_api(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/route_utils.py\", line 322, in call_process_api\n", " output = await app.get_blocks().process_api(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 2047, in process_api\n", " result = await self.call_function(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 1594, in call_function\n", " prediction = await anyio.to_thread.run_sync( # type: ignore\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/to_thread.py\", line 56, in run_sync\n", " return await get_async_backend().run_sync_in_worker_thread(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 2405, in run_sync_in_worker_thread\n", " return await future\n", " ^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 914, in run\n", " result = context.run(func, *args)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/utils.py\", line 869, in wrapper\n", " response = f(*args, **kwargs)\n", " ^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 355, in load_example_with_output\n", " ) + self.load_from_cache(example_id)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 579, in load_from_cache\n", " output.append(component.read_from_flag(value_to_use))\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/components/base.py\", line 366, in read_from_flag\n", " return self.data_model.from_json(json.loads(payload))\n", " ^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/__init__.py\", line 346, in loads\n", " return _default_decoder.decode(s)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 337, in decode\n", " obj, end = self.raw_decode(s, idx=_w(s, 0).end())\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 355, in raw_decode\n", " raise JSONDecodeError(\"Expecting value\", s, err.value) from None\n", "json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)\n", "Traceback (most recent call last):\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/routes.py\", line 990, in predict\n", " output = await route_utils.call_process_api(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/route_utils.py\", line 322, in call_process_api\n", " output = await app.get_blocks().process_api(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 2047, in process_api\n", " result = await self.call_function(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 1594, in call_function\n", " prediction = await anyio.to_thread.run_sync( # type: ignore\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/to_thread.py\", line 56, in run_sync\n", " return await get_async_backend().run_sync_in_worker_thread(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 2405, in run_sync_in_worker_thread\n", " return await future\n", " ^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 914, in run\n", " result = context.run(func, *args)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/utils.py\", line 869, in wrapper\n", " response = f(*args, **kwargs)\n", " ^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 355, in load_example_with_output\n", " ) + self.load_from_cache(example_id)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 579, in load_from_cache\n", " output.append(component.read_from_flag(value_to_use))\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/components/base.py\", line 366, in read_from_flag\n", " return self.data_model.from_json(json.loads(payload))\n", " ^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/__init__.py\", line 346, in loads\n", " return _default_decoder.decode(s)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 337, in decode\n", " obj, end = self.raw_decode(s, idx=_w(s, 0).end())\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 355, in raw_decode\n", " raise JSONDecodeError(\"Expecting value\", s, err.value) from None\n", "json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)\n" ] } ], "source": [ "from datetime import datetime\n", "import gradio as gr\n", "import requests\n", "from Bio.PDB import PDBParser, MMCIFParser, PDBIO, Select, Structure\n", "from Bio.PDB.Polypeptide import is_aa\n", "from Bio.SeqUtils import seq1\n", "from typing import Optional, Tuple, Dict, List\n", "import numpy as np\n", "import os\n", "from gradio_molecule3d import Molecule3D\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "import re\n", "import pandas as pd\n", "import copy\n", "from scipy.special import expit\n", "import logging\n", "import tempfile\n", "\n", "# Set up logging\n", "logging.basicConfig(level=logging.INFO)\n", "logger = logging.getLogger(__name__)\n", "\n", "class StructureError(Exception):\n", " \"\"\"Custom exception for structure-related errors\"\"\"\n", " pass\n", "\n", "def normalize_scores(scores: np.ndarray) -> np.ndarray:\n", " \"\"\"Normalize scores to range [0,1]\"\"\"\n", " min_score = np.min(scores)\n", " max_score = np.max(scores)\n", " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n", "\n", "def read_mol(pdb_path: str) -> str:\n", " \"\"\"Read molecular structure file and return its content\"\"\"\n", " try:\n", " with open(pdb_path, 'r') as f:\n", " return f.read()\n", " except Exception as e:\n", " raise IOError(f\"Failed to read structure file: {e}\")\n", "\n", "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n", " \"\"\"Fetch structure file, trying multiple formats and sources\"\"\"\n", " try:\n", " # First try local cache\n", " for ext in ['.cif', '.pdb']:\n", " local_path = os.path.join(output_dir, f\"{pdb_id.lower()}{ext}\")\n", " if os.path.exists(local_path):\n", " logger.info(f\"Using cached structure: {local_path}\")\n", " return local_path\n", "\n", " # Try different download sources\n", " sources = [\n", " f\"https://files.rcsb.org/download/{pdb_id.upper()}.cif\",\n", " f\"https://files.rcsb.org/download/{pdb_id.upper()}.pdb\",\n", " f\"https://files.rcsb.org/download/{pdb_id.lower()}.cif\",\n", " f\"https://files.rcsb.org/download/{pdb_id.lower()}.pdb\"\n", " ]\n", "\n", " for url in sources:\n", " try:\n", " response = requests.get(url, timeout=10)\n", " if response.status_code == 200:\n", " ext = '.cif' if 'cif' in url else '.pdb'\n", " file_path = os.path.join(output_dir, f\"{pdb_id.lower()}{ext}\")\n", " with open(file_path, 'wb') as f:\n", " f.write(response.content)\n", " logger.info(f\"Successfully downloaded: {url}\")\n", " return file_path\n", " except Exception as e:\n", " logger.warning(f\"Failed to download from {url}: {e}\")\n", " continue\n", "\n", " raise StructureError(f\"Failed to fetch structure for PDB ID: {pdb_id}\")\n", " except Exception as e:\n", " raise StructureError(f\"Error fetching structure: {e}\")\n", "\n", "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n", " \"\"\"Convert CIF to PDB format with error handling\"\"\"\n", " try:\n", " pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n", " parser = MMCIFParser(QUIET=True)\n", " structure = parser.get_structure('protein', cif_path)\n", " io = PDBIO()\n", " io.set_structure(structure)\n", " io.save(pdb_path)\n", " return pdb_path\n", " except Exception as e:\n", " raise StructureError(f\"Failed to convert CIF to PDB: {e}\")\n", "\n", "def find_valid_chain(structure: Structure.Structure) -> Optional[str]:\n", " \"\"\"Find the first valid protein chain in the structure\"\"\"\n", " for model in structure:\n", " for chain in model:\n", " protein_residues = [res for res in chain if is_aa(res)]\n", " if len(protein_residues) > 0:\n", " return chain.id\n", " return None\n", "\n", "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n", " \"\"\"Create PDB file with selected chain and prediction scores in B-factor column\"\"\"\n", " class ResidueSelector(Select):\n", " def __init__(self, chain_id, selected_residues, scores_dict):\n", " self.chain_id = chain_id\n", " self.selected_residues = selected_residues\n", " self.scores_dict = scores_dict\n", " \n", " def accept_chain(self, chain):\n", " return chain.id == self.chain_id\n", " \n", " def accept_residue(self, residue):\n", " return residue.id[1] in self.selected_residues\n", "\n", " def accept_atom(self, atom):\n", " if atom.parent.id[1] in self.scores_dict:\n", " atom.bfactor = np.absolute(1-self.scores_dict[atom.parent.id[1]]) * 100\n", " return True\n", "\n", " try:\n", " parser = PDBParser(QUIET=True)\n", " structure = parser.get_structure('protein', input_pdb)\n", " output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb\"\n", " scores_dict = {resi: score for resi, score in residue_scores}\n", " \n", " io = PDBIO()\n", " selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n", " io.set_structure(structure[0])\n", " io.save(output_pdb, selector)\n", " \n", " return output_pdb\n", " except Exception as e:\n", " raise StructureError(f\"Failed to create chain-specific PDB: {e}\")\n", "\n", "def process_pdb(pdb_id_or_file: str, segment: str) -> Tuple[str, str, List[str]]:\n", " \"\"\"Process PDB/CIF file and generate visualizations and predictions\"\"\"\n", " try:\n", " # Handle input\n", " if pdb_id_or_file.endswith(('.pdb', '.cif')):\n", " pdb_path = pdb_id_or_file\n", " pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n", " else:\n", " pdb_id = pdb_id_or_file\n", " pdb_path = fetch_structure(pdb_id)\n", "\n", " if not pdb_path:\n", " raise StructureError(\"Failed to obtain structure file\")\n", "\n", " # Parse structure\n", " parser = MMCIFParser(QUIET=True) if pdb_path.endswith('.cif') else PDBParser(QUIET=True)\n", " structure = parser.get_structure('protein', pdb_path)\n", "\n", " # Handle chain selection\n", " if segment == 'auto' or not segment:\n", " segment = find_valid_chain(structure)\n", " if not segment:\n", " raise StructureError(\"No valid protein chains found in structure\")\n", " \n", " try:\n", " chain = structure[0][segment]\n", " except KeyError:\n", " valid_chain = find_valid_chain(structure)\n", " if valid_chain:\n", " chain = structure[0][valid_chain]\n", " segment = valid_chain\n", " logger.info(f\"Using alternative chain {segment}\")\n", " else:\n", " raise StructureError(f\"Invalid chain ID '{segment}'. Structure has no valid protein chains.\")\n", "\n", " # Process chain\n", " protein_residues = [res for res in chain if is_aa(res)]\n", " if not protein_residues:\n", " raise StructureError(f\"No amino acid residues found in chain {segment}\")\n", "\n", " sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n", " sequence_id = [res.id[1] for res in protein_residues]\n", " \n", " # Generate predictions (currently random)\n", " scores = np.random.rand(len(sequence))\n", " normalized_scores = normalize_scores(scores)\n", " residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n", "\n", " # Generate outputs\n", " result_str = generate_results_string(pdb_id, segment, protein_residues, normalized_scores, sequence)\n", " scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n", " mol_vis = molecule(pdb_path, residue_scores, segment)\n", " pymol_commands = generate_pymol_commands(pdb_id, segment, residue_scores, pdb_path)\n", "\n", " # Save results\n", " prediction_file = f\"{pdb_id}_binding_site_residues.txt\"\n", " with open(prediction_file, \"w\") as f:\n", " f.write(result_str)\n", "\n", " return pymol_commands, mol_vis, [prediction_file, scored_pdb]\n", "\n", " except StructureError as e:\n", " return str(e), None, None\n", " except Exception as e:\n", " return f\"An unexpected error occurred: {str(e)}\", None, None\n", "\n", "def generate_results_string(pdb_id: str, segment: str, protein_residues: list, \n", " normalized_scores: np.ndarray, sequence: str) -> str:\n", " \"\"\"Generate formatted results string with predictions\"\"\"\n", " score_brackets = {\n", " \"0.0-0.2\": (0.0, 0.2),\n", " \"0.2-0.4\": (0.2, 0.4),\n", " \"0.4-0.6\": (0.4, 0.6),\n", " \"0.6-0.8\": (0.6, 0.8),\n", " \"0.8-1.0\": (0.8, 1.0)\n", " }\n", " \n", " residues_by_bracket = {bracket: [] for bracket in score_brackets}\n", " \n", " # Categorize residues\n", " for i, score in enumerate(normalized_scores):\n", " for bracket, (lower, upper) in score_brackets.items():\n", " if lower <= score < upper:\n", " residues_by_bracket[bracket].append(protein_residues[i])\n", " break\n", " \n", " # Format results\n", " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", " result_str = f\"\"\"Prediction Results\n", "========================\n", "PDB: {pdb_id}\n", "Chain: {segment}\n", "Date: {current_time}\n", "\n", "Analysis by Score Brackets\n", "========================\n", "\"\"\"\n", " \n", " for bracket, residues in residues_by_bracket.items():\n", " if residues: # Only show brackets with residues\n", " result_str += f\"\\nBracket {bracket}:\\n\"\n", " result_str += \"ResName ResNum Code Score\\n\"\n", " result_str += \"-\" * 30 + \"\\n\"\n", " result_str += \"\\n\".join([\n", " f\"{res.resname:6} {res.id[1]:6} {sequence[i]:4} {normalized_scores[i]:6.2f}\" \n", " for i, res in enumerate(protein_residues) if res in residues\n", " ])\n", " result_str += \"\\n\"\n", " \n", " return result_str\n", "\n", "def generate_pymol_commands(pdb_id: str, segment: str, residue_scores: list, pdb_path: str) -> str:\n", " \"\"\"Generate PyMOL visualization commands\"\"\"\n", " # Group residues by score ranges\n", " score_groups = {\n", " \"very_low\": [], \"low\": [], \"medium\": [], \"high\": [], \"very_high\": []\n", " }\n", " \n", " for resi, score in residue_scores:\n", " if score <= 0.2:\n", " score_groups[\"very_low\"].append(str(resi))\n", " elif score <= 0.4:\n", " score_groups[\"low\"].append(str(resi))\n", " elif score <= 0.6:\n", " score_groups[\"medium\"].append(str(resi))\n", " elif score <= 0.8:\n", " score_groups[\"high\"].append(str(resi))\n", " else:\n", " score_groups[\"very_high\"].append(str(resi))\n", "\n", " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", " commands = f\"\"\"# PyMOL Script for {pdb_id} Chain {segment}\n", "# Generated: {current_time}\n", "\n", "# Load structure and set initial display\n", "load {os.path.abspath(pdb_path)}, protein\n", "bg_color white\n", "hide everything\n", "show cartoon, chain {segment}\n", "color white, chain {segment}\n", "\n", "# Create selection groups by score\n", "\"\"\"\n", " \n", " color_scheme = {\n", " \"very_low\": \"white\",\n", " \"low\": \"lightorange\",\n", " \"medium\": \"orange\",\n", " \"high\": \"orangered\",\n", " \"very_high\": \"red\"\n", " }\n", " \n", " for group, residues in score_groups.items():\n", " if residues:\n", " resi_str = \"+\".join(residues)\n", " commands += f\"\"\"\n", "# {group.replace('_', ' ').title()} scoring residues\n", "select {group}, chain {segment} and resi {resi_str}\n", "show sticks, {group}\n", "color {color_scheme[group]}, {group}\"\"\"\n", " \n", " commands += \"\"\"\n", "\n", "# Center and zoom\n", "center chain {}\n", "zoom chain {}\n", "\"\"\"\n", "\n", " return commands\n", "\n", "def molecule(input_pdb: str, residue_scores: list = None, segment: str = 'A') -> str:\n", " \"\"\"Generate interactive 3D molecule visualization\"\"\"\n", " try:\n", " mol = read_mol(input_pdb)\n", " except Exception as e:\n", " return f'
Error loading structure: {str(e)}
'\n", "\n", " # Prepare residue groups for visualization\n", " vis_groups = {\n", " \"class1\": [], # 0.0-0.2\n", " \"class2\": [], # 0.2-0.4\n", " \"class3\": [], # 0.4-0.6\n", " \"class4\": [], # 0.6-0.8\n", " \"class5\": [] # 0.8-1.0\n", " }\n", "\n", " if residue_scores:\n", " for resi, score in residue_scores:\n", " if score <= 0.2:\n", " vis_groups[\"class1\"].append(resi)\n", " elif score <= 0.4:\n", " vis_groups[\"class2\"].append(resi)\n", " elif score <= 0.6:\n", " vis_groups[\"class3\"].append(resi)\n", " elif score <= 0.8:\n", " vis_groups[\"class4\"].append(resi)\n", " else:\n", " vis_groups[\"class5\"].append(resi)\n", "\n", " # Generate visualization script\n", " vis_script = f\"\"\"\n", " // Base model setup\n", " let chainModel = viewer.addModel(pdb, \"pdb\");\n", " chainModel.setStyle({{}}, {{}});\n", " chainModel.setStyle(\n", " {{\"chain\": \"{segment}\"}}, \n", " {{\"cartoon\": {{\"color\": \"white\"}}}}\n", " );\n", " \"\"\"\n", "\n", " # Color schemes for different score ranges\n", " color_schemes = {\n", " \"class1\": {\"color\": \"0xFFFFFF\", \"opacity\": 0.5}, # White\n", " \"class2\": {\"color\": \"0xFFD580\", \"opacity\": 0.7}, # Light orange\n", " \"class3\": {\"color\": \"0xFFA500\", \"opacity\": 1.0}, # Orange\n", " \"class4\": {\"color\": \"0xFF4500\", \"opacity\": 1.0}, # Orange red\n", " \"class5\": {\"color\": \"0xFF0000\", \"opacity\": 1.0} # Red\n", " }\n", "\n", " # Add visualization for each group\n", " for group, residues in vis_groups.items():\n", " if residues:\n", " color_scheme = color_schemes[group]\n", " vis_script += f\"\"\"\n", " let {group}Model = viewer.addModel(pdb, \"pdb\");\n", " {group}Model.setStyle({{}}, {{}});\n", " {group}Model.setStyle(\n", " {{\"chain\": \"{segment}\", \"resi\": [{\", \".join(map(str, residues))}]}},\n", " {{\"stick\": {{\"color\": \"{color_scheme[\"color\"]}\", \"opacity\": {color_scheme[\"opacity\"]}}}}}\n", " );\n", " \"\"\"\n", "\n", " # Generate full HTML with enhanced controls and information\n", " html_content = f\"\"\"\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", " \n", " \n", " \n", " \n", "
\n", "
\n", "
Very High (0.8-1.0)
\n", "
High (0.6-0.8)
\n", "
Medium (0.4-0.6)
\n", "
Low (0.2-0.4)
\n", "
Very Low (0.0-0.2)
\n", "
\n", "
\n", " \n", " \n", " \n", " \"\"\"\n", " \n", " return f''\n", "\n", "# Gradio UI\n", "def create_ui():\n", " with gr.Blocks(title=\"Protein Binding Site Prediction\", theme=gr.themes.Base()) as demo:\n", " gr.Markdown(\"\"\"\n", " # Protein Binding Site Prediction\n", " \n", " This tool helps you visualize and analyze potential binding sites in protein structures.\n", " You can either:\n", " 1. Enter a PDB ID (e.g., \"4BDU\")\n", " 2. Upload your own PDB/CIF file\n", " \n", " The tool will analyze the structure and show predictions using a color gradient from white (low probability) to red (high probability).\n", " \"\"\")\n", " \n", " with gr.Row():\n", " with gr.Column(scale=2):\n", " # Input components\n", " mode = gr.Radio(\n", " choices=[\"PDB ID\", \"Upload File\"],\n", " value=\"PDB ID\",\n", " label=\"Input Mode\",\n", " info=\"Choose whether to input a PDB ID or upload a PDB/CIF file\"\n", " )\n", " \n", " with gr.Group():\n", " pdb_input = gr.Textbox(\n", " value=\"4BDU\",\n", " label=\"PDB ID\",\n", " placeholder=\"Enter PDB ID (e.g., 4BDU)\",\n", " info=\"Enter a valid PDB ID from the Protein Data Bank\"\n", " )\n", " pdb_file = gr.File(\n", " label=\"Upload PDB/CIF File\",\n", " file_types=[\".pdb\", \".cif\"],\n", " visible=False\n", " )\n", " \n", " segment_input = gr.Textbox(\n", " value=\"A\",\n", " label=\"Chain ID\",\n", " placeholder=\"Enter Chain ID or leave empty for automatic selection\",\n", " info=\"Specify which protein chain to analyze, or leave empty for automatic selection\"\n", " )\n", "\n", " with gr.Column(scale=1):\n", " visualize_btn = gr.Button(\"Visualize Structure\", variant=\"primary\")\n", " prediction_btn = gr.Button(\"Predict Binding Site\", variant=\"secondary\")\n", " \n", " gr.Markdown(\"\"\"\n", " ### Color Legend\n", " - White: Very Low (0.0-0.2)\n", " - Light Orange: Low (0.2-0.4)\n", " - Orange: Medium (0.4-0.6)\n", " - Orange Red: High (0.6-0.8)\n", " - Red: Very High (0.8-1.0)\n", " \"\"\")\n", "\n", " with gr.Tab(\"3D Visualization\"):\n", " molecule_output = gr.HTML(label=\"Interactive 3D Structure\")\n", " \n", " with gr.Tab(\"Analysis Results\"):\n", " predictions_output = gr.Textbox(\n", " label=\"PyMOL Visualization Commands\",\n", " info=\"Copy these commands into PyMOL to recreate the visualization\"\n", " )\n", " download_output = gr.File(\n", " label=\"Download Results\",\n", " file_count=\"multiple\"\n", " )\n", "\n", " # Error message container\n", " error_output = gr.Markdown(visible=False)\n", "\n", " # Mode change handler\n", " def toggle_mode(selected_mode):\n", " return {\n", " pdb_input: gr.update(visible=selected_mode == \"PDB ID\"),\n", " pdb_file: gr.update(visible=selected_mode == \"Upload File\")\n", " }\n", "\n", " mode.change(\n", " toggle_mode,\n", " inputs=[mode],\n", " outputs=[pdb_input, pdb_file]\n", " )\n", "\n", " # Process handlers\n", " def handle_visualization(mode, pdb_id, pdb_file):\n", " try:\n", " result = fetch_interface(mode, pdb_id, pdb_file)\n", " if isinstance(result, str) and result.startswith(\"Error\"):\n", " return None, gr.update(visible=True, value=f\"```\\n{result}\\n```\")\n", " return result, gr.update(visible=False)\n", " except Exception as e:\n", " return None, gr.update(visible=True, value=f\"```\\nError: {str(e)}\\n```\")\n", "\n", " def handle_prediction(mode, pdb_id, pdb_file, chain_id):\n", " try:\n", " predictions, vis, downloads = process_interface(mode, pdb_id, pdb_file, chain_id)\n", " if isinstance(predictions, str) and \"Error\" in predictions:\n", " return (\n", " None,\n", " None,\n", " None,\n", " gr.update(visible=True, value=f\"```\\n{predictions}\\n```\")\n", " )\n", " return (\n", " predictions,\n", " vis,\n", " downloads,\n", " gr.update(visible=False)\n", " )\n", " except Exception as e:\n", " error_msg = f\"\"\"Error processing structure:\n", "```\n", "{str(e)}\n", "\n", "Troubleshooting tips:\n", "1. Check if the PDB ID is valid\n", "2. Ensure the Chain ID exists in the structure\n", "3. Try leaving Chain ID empty for automatic selection\n", "4. If uploading a file, ensure it's a valid PDB/CIF format\n", "```\"\"\"\n", " return None, None, None, gr.update(visible=True, value=error_msg)\n", "\n", " def fetch_interface(mode, pdb_id, pdb_file):\n", " try:\n", " if mode == \"PDB ID\":\n", " if not pdb_id or len(pdb_id.strip()) != 4:\n", " raise ValueError(\"Please enter a valid 4-character PDB ID\")\n", " return fetch_pdb(pdb_id.strip())\n", " elif mode == \"Upload File\":\n", " if not pdb_file:\n", " raise ValueError(\"Please upload a PDB or CIF file\")\n", " _, ext = os.path.splitext(pdb_file.name)\n", " if ext.lower() not in ['.pdb', '.cif']:\n", " raise ValueError(\"Only .pdb and .cif files are supported\")\n", " \n", " # Create temp directory for file handling\n", " with tempfile.TemporaryDirectory() as temp_dir:\n", " temp_path = os.path.join(temp_dir, os.path.basename(pdb_file.name))\n", " with open(temp_path, 'wb') as f:\n", " f.write(pdb_file.read())\n", " \n", " if ext.lower() == '.cif':\n", " return convert_cif_to_pdb(temp_path)\n", " return temp_path\n", " else:\n", " raise ValueError(\"Invalid mode selected\")\n", " except Exception as e:\n", " return f\"Error: {str(e)}\"\n", "\n", " # Connect event handlers\n", " visualize_btn.click(\n", " handle_visualization,\n", " inputs=[mode, pdb_input, pdb_file],\n", " outputs=[molecule_output, error_output]\n", " )\n", "\n", " prediction_btn.click(\n", " handle_prediction,\n", " inputs=[mode, pdb_input, pdb_file, segment_input],\n", " outputs=[predictions_output, molecule_output, download_output, error_output]\n", " )\n", "\n", " # Add examples\n", " gr.Examples(\n", " examples=[\n", " [\"PDB ID\", \"7RPZ\", None, \"A\"],\n", " [\"PDB ID\", \"2IWI\", None, \"B\"],\n", " [\"PDB ID\", \"2F6V\", None, \"A\"]\n", " ],\n", " inputs=[mode, pdb_input, pdb_file, segment_input],\n", " outputs=[predictions_output, molecule_output, download_output, error_output],\n", " fn=handle_prediction,\n", " cache_examples=True\n", " )\n", "\n", " # Add documentation\n", " gr.Markdown(\"\"\"\n", " ## Usage Instructions\n", " \n", " 1. **Input Structure:**\n", " - Enter a PDB ID (e.g., \"4BDU\") or upload your own structure file\n", " - The tool supports both PDB (.pdb) and mmCIF (.cif) formats\n", " \n", " 2. **Select Chain:**\n", " - Enter a specific chain ID (e.g., \"A\")\n", " - Leave empty for automatic selection of the first valid protein chain\n", " \n", " 3. **Analyze:**\n", " - Click \"Visualize Structure\" to view the 3D structure\n", " - Click \"Predict Binding Site\" to perform binding site analysis\n", " \n", " 4. **Results:**\n", " - Interactive 3D visualization with color-coded predictions\n", " - PyMOL commands for external visualization\n", " - Downloadable results files\n", " \n", " ## Troubleshooting\n", " \n", " If you encounter issues:\n", " 1. Ensure your PDB ID is valid and exists in the PDB database\n", " 2. Check that your uploaded file is a valid PDB/CIF format\n", " 3. Try automatic chain selection if your specified chain isn't found\n", " 4. Clear your browser cache if visualizations don't load\n", " \"\"\")\n", "\n", " return demo\n", "\n", "if __name__ == \"__main__\":\n", " demo = create_ui()\n", " demo.launch(share=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "9125d1c4-e2ae-4e40-ba36-7ae944512b8e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "85c0728a-a15b-4118-b920-5f55a2f5f79a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python (LLM)", "language": "python", "name": "llm" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 5 }