ThorbenFroehlking committed on
Commit
fd6cc24
·
1 Parent(s): e5b8e7f
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -82,6 +82,10 @@ def process_pdb(pdb_id, segment):
82
  for residue in chain
83
  if residue.get_resname().strip() in aa_dict
84
  )
 
 
 
 
85
 
86
  # Prepare input for model prediction
87
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
@@ -92,6 +96,9 @@ def process_pdb(pdb_id, segment):
92
  scores = expit(outputs[:, 1] - outputs[:, 0])
93
  normalized_scores = normalize_scores(scores)
94
 
 
 
 
95
  result_str = "\n".join([
96
  f"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}"
97
  for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict
@@ -102,14 +109,18 @@ def process_pdb(pdb_id, segment):
102
  with open(prediction_file, "w") as f:
103
  f.write(result_str)
104
 
105
- return result_str, molecule(pdb_path, normalized_scores, segment), prediction_file
106
 
107
- def molecule(input_pdb, scores=None, segment='A'):
108
  mol = read_mol(input_pdb) # Read PDB file content
109
 
110
  # Prepare high-scoring residues script if scores are provided
111
  high_score_script = ""
112
- if scores is not None:
 
 
 
 
113
  high_score_script = """
114
  // Reset all styles first
115
  viewer.getModel(0).setStyle({}, {});
@@ -127,16 +138,16 @@ def molecule(input_pdb, scores=None, segment='A'):
127
  {"stick": {"color": "red"}}
128
  );
129
 
130
- // Highlight high-scoring residues only for the selected chain
131
- let highScoreResidues2 = [%s];
132
  viewer.getModel(0).setStyle(
133
- {"chain": "%s", "resi": highScoreResidues2},
134
  {"stick": {"color": "orange"}}
135
  );
136
  """ % (segment,
137
- ", ".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),
138
  segment,
139
- ", ".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),
140
  segment)
141
 
142
  html_content = f"""
@@ -179,7 +190,7 @@ def molecule(input_pdb, scores=None, segment='A'):
179
  function(atom, viewer, event, container) {{
180
  if (!atom.label) {{
181
  atom.label = viewer.addLabel(
182
- atom.resn + ":" + atom.atom,
183
  {{
184
  position: atom,
185
  backgroundColor: 'mintcream',
@@ -246,8 +257,8 @@ with gr.Blocks() as demo:
246
  gr.Markdown("## Examples")
247
  gr.Examples(
248
  examples=[
249
- ["2IWI", "A"],
250
- ["7RPZ", "B"],
251
  ["3TJN", "C"]
252
  ],
253
  inputs=[pdb_input, segment_input],
 
82
  for residue in chain
83
  if residue.get_resname().strip() in aa_dict
84
  )
85
+ sequence2 = [
86
+ (res.id[1], res) for res in chain
87
+ if res.get_resname().strip() in aa_dict
88
+ ]
89
 
90
  # Prepare input for model prediction
91
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
 
96
  scores = expit(outputs[:, 1] - outputs[:, 0])
97
  normalized_scores = normalize_scores(scores)
98
 
99
+ # Zip residues with scores to track the residue ID and score
100
+ residue_scores = [(resi, score) for (resi, _), score in zip(sequence2, normalized_scores)]
101
+
102
  result_str = "\n".join([
103
  f"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}"
104
  for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict
 
109
  with open(prediction_file, "w") as f:
110
  f.write(result_str)
111
 
112
+ return result_str, molecule(pdb_path, residue_scores, segment), prediction_file
113
 
114
+ def molecule(input_pdb, residue_scores=None, segment='A'):
115
  mol = read_mol(input_pdb) # Read PDB file content
116
 
117
  # Prepare high-scoring residues script if scores are provided
118
  high_score_script = ""
119
+ if residue_scores is not None:
120
+ # Split residues into high- and medium-score groups by threshold
121
+ high_score_residues = [resi for resi, score in residue_scores if score > 0.75]
122
+ mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]
123
+
124
  high_score_script = """
125
  // Reset all styles first
126
  viewer.getModel(0).setStyle({}, {});
 
138
  {"stick": {"color": "red"}}
139
  );
140
 
141
+ // Highlight medium-scoring residues only for the selected chain
142
+ let midScoreResidues = [%s];
143
  viewer.getModel(0).setStyle(
144
+ {"chain": "%s", "resi": midScoreResidues},
145
  {"stick": {"color": "orange"}}
146
  );
147
  """ % (segment,
148
+ ", ".join(str(resi) for resi in high_score_residues),
149
  segment,
150
+ ", ".join(str(resi) for resi in mid_score_residues),
151
  segment)
152
 
153
  html_content = f"""
 
190
  function(atom, viewer, event, container) {{
191
  if (!atom.label) {{
192
  atom.label = viewer.addLabel(
193
+ atom.resn + ":" +atom.resi + ":" + atom.atom,
194
  {{
195
  position: atom,
196
  backgroundColor: 'mintcream',
 
257
  gr.Markdown("## Examples")
258
  gr.Examples(
259
  examples=[
260
+ ["7RPZ", "A"],
261
+ ["2IWI", "B"],
262
  ["3TJN", "C"]
263
  ],
264
  inputs=[pdb_input, segment_input],
.ipynb_checkpoints/test2-checkpoint.ipynb CHANGED
@@ -473,7 +473,7 @@
473
  },
474
  {
475
  "cell_type": "code",
476
- "execution_count": 11,
477
  "id": "d62be1b5-762e-4b69-aed4-e4ba2a44482f",
478
  "metadata": {},
479
  "outputs": [
@@ -481,7 +481,7 @@
481
  "name": "stdout",
482
  "output_type": "stream",
483
  "text": [
484
- "* Running on local URL: http://127.0.0.1:7867\n",
485
  "\n",
486
  "To create a public link, set `share=True` in `launch()`.\n"
487
  ]
@@ -489,7 +489,7 @@
489
  {
490
  "data": {
491
  "text/html": [
492
- "<div><iframe src=\"http://127.0.0.1:7867/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
493
  ],
494
  "text/plain": [
495
  "<IPython.core.display.HTML object>"
@@ -502,7 +502,7 @@
502
  "data": {
503
  "text/plain": []
504
  },
505
- "execution_count": 11,
506
  "metadata": {},
507
  "output_type": "execute_result"
508
  }
@@ -647,7 +647,7 @@
647
  " function(atom, viewer, event, container) {{\n",
648
  " if (!atom.label) {{\n",
649
  " atom.label = viewer.addLabel(\n",
650
- " atom.resn + \":\" + atom.atom, \n",
651
  " {{\n",
652
  " position: atom, \n",
653
  " backgroundColor: 'mintcream', \n",
@@ -727,16 +727,294 @@
727
  },
728
  {
729
  "cell_type": "code",
730
- "execution_count": null,
731
  "id": "30f35243-852f-4771-9a4b-5cdd198552b5",
732
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
733
  "outputs": [],
734
  "source": []
735
  },
736
  {
737
  "cell_type": "code",
738
  "execution_count": null,
739
- "id": "5eca6754-4aa1-463f-881a-25d2a0d6bb5b",
 
 
 
 
 
 
 
 
740
  "metadata": {},
741
  "outputs": [],
742
  "source": [
@@ -809,7 +1087,7 @@
809
  " except KeyError:\n",
810
  " return \"Invalid Chain ID\", None, None\n",
811
  " \n",
812
- " # Comprehensive amino acid mapping\n",
813
  " aa_dict = {\n",
814
  " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
815
  " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
@@ -819,9 +1097,14 @@
819
  " }\n",
820
  " \n",
821
  " # Exclude non-amino acid residues\n",
822
- " sequence = [\n",
823
- " residue for residue in chain \n",
 
824
  " if residue.get_resname().strip() in aa_dict\n",
 
 
 
 
825
  " ]\n",
826
  " \n",
827
  " # Prepare input for model prediction\n",
@@ -833,24 +1116,31 @@
833
  " scores = expit(outputs[:, 1] - outputs[:, 0])\n",
834
  " normalized_scores = normalize_scores(scores)\n",
835
  "\n",
836
- " result_str = \"\\n\".join(\n",
837
- " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n",
838
- " for res, score in zip(sequence, normalized_scores)\n",
839
- " )\n",
 
 
 
840
  " \n",
841
  " # Save the predictions to a file\n",
842
  " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
843
  " with open(prediction_file, \"w\") as f:\n",
844
  " f.write(result_str)\n",
845
  " \n",
846
- " return result_str, molecule(pdb_path, random_scores, segment), prediction_file\n",
847
  "\n",
848
- "def molecule(input_pdb, scores=None, segment='A'):\n",
849
  " mol = read_mol(input_pdb) # Read PDB file content\n",
850
  " \n",
851
  " # Prepare high-scoring residues script if scores are provided\n",
852
  " high_score_script = \"\"\n",
853
- " if scores is not None:\n",
 
 
 
 
854
  " high_score_script = \"\"\"\n",
855
  " // Reset all styles first\n",
856
  " viewer.getModel(0).setStyle({}, {});\n",
@@ -868,16 +1158,16 @@
868
  " {\"stick\": {\"color\": \"red\"}}\n",
869
  " );\n",
870
  "\n",
871
- " // Highlight high-scoring residues only for the selected chain\n",
872
- " let highScoreResidues2 = [%s];\n",
873
  " viewer.getModel(0).setStyle(\n",
874
- " {\"chain\": \"%s\", \"resi\": highScoreResidues2}, \n",
875
  " {\"stick\": {\"color\": \"orange\"}}\n",
876
  " );\n",
877
  " \"\"\" % (segment, \n",
878
- " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n",
879
  " segment,\n",
880
- " \", \".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),\n",
881
  " segment)\n",
882
  " \n",
883
  " html_content = f\"\"\"\n",
@@ -920,7 +1210,7 @@
920
  " function(atom, viewer, event, container) {{\n",
921
  " if (!atom.label) {{\n",
922
  " atom.label = viewer.addLabel(\n",
923
- " atom.resn + \":\" + atom.atom, \n",
924
  " {{\n",
925
  " position: atom, \n",
926
  " backgroundColor: 'mintcream', \n",
@@ -987,21 +1277,21 @@
987
  " gr.Markdown(\"## Examples\")\n",
988
  " gr.Examples(\n",
989
  " examples=[\n",
990
- " [\"2IWI\", \"A\"],\n",
991
- " [\"7RPZ\", \"B\"],\n",
992
  " [\"3TJN\", \"C\"]\n",
993
  " ],\n",
994
  " inputs=[pdb_input, segment_input],\n",
995
  " outputs=[predictions_output, molecule_output, download_output]\n",
996
  " )\n",
997
  "\n",
998
- "demo.launch()"
999
  ]
1000
  },
1001
  {
1002
  "cell_type": "code",
1003
  "execution_count": null,
1004
- "id": "95046d1c-ec7c-4e3e-8a98-1802cb09a25b",
1005
  "metadata": {},
1006
  "outputs": [],
1007
  "source": []
@@ -1009,11 +1299,18 @@
1009
  {
1010
  "cell_type": "code",
1011
  "execution_count": null,
1012
- "id": "a37cbe6f-d57f-41e5-8ae1-38258da39d47",
1013
  "metadata": {},
1014
  "outputs": [],
1015
  "source": [
1016
  "import gradio as gr\n",
 
 
 
 
 
 
 
1017
  "from model_loader import load_model\n",
1018
  "\n",
1019
  "import torch\n",
@@ -1022,8 +1319,6 @@
1022
  "from torch.utils.data import DataLoader\n",
1023
  "\n",
1024
  "import re\n",
1025
- "import numpy as np\n",
1026
- "import os\n",
1027
  "import pandas as pd\n",
1028
  "import copy\n",
1029
  "\n",
@@ -1035,18 +1330,6 @@
1035
  "\n",
1036
  "from scipy.special import expit\n",
1037
  "\n",
1038
- "import requests\n",
1039
- "\n",
1040
- "from gradio_molecule3d import Molecule3D\n",
1041
- "\n",
1042
- "# Biopython imports\n",
1043
- "from Bio.PDB import PDBParser, Select, PDBIO\n",
1044
- "from Bio.PDB.DSSP import DSSP\n",
1045
- "from Bio.PDB import PDBList\n",
1046
- "\n",
1047
- "from matplotlib import cm # For color mapping\n",
1048
- "from matplotlib.colors import Normalize\n",
1049
- "\n",
1050
  "# Load model and move to device\n",
1051
  "checkpoint = 'ThorbenF/prot_t5_xl_uniref50'\n",
1052
  "max_length = 1500\n",
@@ -1055,23 +1338,26 @@
1055
  "model.to(device)\n",
1056
  "model.eval()\n",
1057
  "\n",
1058
- "# Function to fetch a PDB file\n",
 
 
 
 
 
 
 
 
 
1059
  "def fetch_pdb(pdb_id):\n",
1060
  " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
1061
- " pdb_path = f'pdb_files/{pdb_id}.pdb'\n",
1062
- " os.makedirs('pdb_files', exist_ok=True)\n",
1063
  " response = requests.get(pdb_url)\n",
1064
  " if response.status_code == 200:\n",
1065
  " with open(pdb_path, 'wb') as f:\n",
1066
  " f.write(response.content)\n",
1067
  " return pdb_path\n",
1068
- " return None\n",
1069
- "\n",
1070
- "\n",
1071
- "def normalize_scores(scores):\n",
1072
- " min_score = np.min(scores)\n",
1073
- " max_score = np.max(scores)\n",
1074
- " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
1075
  "\n",
1076
  "def process_pdb(pdb_id, segment):\n",
1077
  " pdb_path = fetch_pdb(pdb_id)\n",
@@ -1080,9 +1366,13 @@
1080
  " \n",
1081
  " parser = PDBParser(QUIET=1)\n",
1082
  " structure = parser.get_structure('protein', pdb_path)\n",
1083
- " chain = structure[0][segment]\n",
1084
  " \n",
1085
- " # Comprehensive amino acid mapping\n",
 
 
 
 
 
1086
  " aa_dict = {\n",
1087
  " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
1088
  " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
@@ -1106,67 +1396,171 @@
1106
  " # Calculate scores and normalize them\n",
1107
  " scores = expit(outputs[:, 1] - outputs[:, 0])\n",
1108
  " normalized_scores = normalize_scores(scores)\n",
1109
- " \n",
1110
- " # Prepare the result string, including only amino acid residues\n",
1111
  " result_str = \"\\n\".join([\n",
1112
  " f\"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
1113
  " for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict\n",
1114
  " ])\n",
1115
  " \n",
1116
- " # Save predictions to file\n",
1117
- " with open(f\"{pdb_id}_predictions.txt\", \"w\") as f:\n",
 
1118
  " f.write(result_str)\n",
1119
  " \n",
1120
- " return result_str, pdb_path, f\"{pdb_id}_predictions.txt\"\n",
1121
  "\n",
1122
- "reps = [{\"model\": 0, \"style\": \"cartoon\", \"color\": \"spectrum\"}]\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1123
  "\n",
1124
  "# Gradio UI\n",
1125
  "with gr.Blocks() as demo:\n",
1126
- " gr.Markdown(\"# Protein Binding Site Prediction\")\n",
 
 
 
 
 
1127
  "\n",
1128
  " with gr.Row():\n",
1129
- " pdb_input = gr.Textbox(value=\"2IWI\",\n",
1130
- " label=\"PDB ID\",\n",
1131
- " placeholder=\"Enter PDB ID here...\")\n",
1132
- " segment_input = gr.Textbox(value=\"A\",\n",
1133
- " label=\"Chain ID (Segment)\",\n",
1134
- " placeholder=\"Enter Chain ID here...\")\n",
1135
- " visualize_btn = gr.Button(\"Visualize Sructure\")\n",
1136
- " prediction_btn = gr.Button(\"Predict Ligand Binding Site\")\n",
1137
- "\n",
1138
- " molecule_output = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
1139
  " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
1140
  " download_output = gr.File(label=\"Download Predictions\")\n",
1141
- "\n",
1142
- " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)\n",
1143
- " prediction_btn.click(\n",
1144
- " process_pdb, \n",
1145
- " inputs=[pdb_input, segment_input], \n",
1146
- " outputs=[predictions_output, molecule_output, download_output]\n",
1147
- " )\n",
1148
- "\n",
1149
  " gr.Markdown(\"## Examples\")\n",
1150
  " gr.Examples(\n",
1151
  " examples=[\n",
1152
- " [\"2IWI\"],\n",
1153
- " [\"7RPZ\"],\n",
1154
- " [\"3TJN\"]\n",
1155
  " ],\n",
1156
- " inputs=[pdb_input, segment_input], \n",
1157
  " outputs=[predictions_output, molecule_output, download_output]\n",
1158
  " )\n",
1159
  "\n",
1160
  "demo.launch(share=True)"
1161
  ]
1162
- },
1163
- {
1164
- "cell_type": "code",
1165
- "execution_count": null,
1166
- "id": "4c61bac4-4f2e-4f4a-aa1f-30dca209747c",
1167
- "metadata": {},
1168
- "outputs": [],
1169
- "source": []
1170
  }
1171
  ],
1172
  "metadata": {
 
473
  },
474
  {
475
  "cell_type": "code",
476
+ "execution_count": 1,
477
  "id": "d62be1b5-762e-4b69-aed4-e4ba2a44482f",
478
  "metadata": {},
479
  "outputs": [
 
481
  "name": "stdout",
482
  "output_type": "stream",
483
  "text": [
484
+ "* Running on local URL: http://127.0.0.1:7860\n",
485
  "\n",
486
  "To create a public link, set `share=True` in `launch()`.\n"
487
  ]
 
489
  {
490
  "data": {
491
  "text/html": [
492
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
493
  ],
494
  "text/plain": [
495
  "<IPython.core.display.HTML object>"
 
502
  "data": {
503
  "text/plain": []
504
  },
505
+ "execution_count": 1,
506
  "metadata": {},
507
  "output_type": "execute_result"
508
  }
 
647
  " function(atom, viewer, event, container) {{\n",
648
  " if (!atom.label) {{\n",
649
  " atom.label = viewer.addLabel(\n",
650
+ " atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
651
  " {{\n",
652
  " position: atom, \n",
653
  " backgroundColor: 'mintcream', \n",
 
727
  },
728
  {
729
  "cell_type": "code",
730
+ "execution_count": 4,
731
  "id": "30f35243-852f-4771-9a4b-5cdd198552b5",
732
  "metadata": {},
733
+ "outputs": [
734
+ {
735
+ "name": "stdout",
736
+ "output_type": "stream",
737
+ "text": [
738
+ "* Running on local URL: http://127.0.0.1:7863\n",
739
+ "\n",
740
+ "To create a public link, set `share=True` in `launch()`.\n"
741
+ ]
742
+ },
743
+ {
744
+ "data": {
745
+ "text/html": [
746
+ "<div><iframe src=\"http://127.0.0.1:7863/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
747
+ ],
748
+ "text/plain": [
749
+ "<IPython.core.display.HTML object>"
750
+ ]
751
+ },
752
+ "metadata": {},
753
+ "output_type": "display_data"
754
+ },
755
+ {
756
+ "data": {
757
+ "text/plain": []
758
+ },
759
+ "execution_count": 4,
760
+ "metadata": {},
761
+ "output_type": "execute_result"
762
+ }
763
+ ],
764
+ "source": [
765
+ "import gradio as gr\n",
766
+ "import requests\n",
767
+ "from Bio.PDB import PDBParser\n",
768
+ "import numpy as np\n",
769
+ "import os\n",
770
+ "from gradio_molecule3d import Molecule3D\n",
771
+ "\n",
772
+ "def read_mol(pdb_path):\n",
773
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
774
+ " with open(pdb_path, 'r') as f:\n",
775
+ " return f.read()\n",
776
+ "\n",
777
+ "def fetch_pdb(pdb_id):\n",
778
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
779
+ " pdb_path = f'{pdb_id}.pdb'\n",
780
+ " response = requests.get(pdb_url)\n",
781
+ " if response.status_code == 200:\n",
782
+ " with open(pdb_path, 'wb') as f:\n",
783
+ " f.write(response.content)\n",
784
+ " return pdb_path\n",
785
+ " else:\n",
786
+ " return None\n",
787
+ "\n",
788
+ "def process_pdb(pdb_id, segment):\n",
789
+ " pdb_path = fetch_pdb(pdb_id)\n",
790
+ " if not pdb_path:\n",
791
+ " return \"Failed to fetch PDB file\", None, None\n",
792
+ " \n",
793
+ " parser = PDBParser(QUIET=1)\n",
794
+ " structure = parser.get_structure('protein', pdb_path)\n",
795
+ " \n",
796
+ " try:\n",
797
+ " chain = structure[0][segment]\n",
798
+ " except KeyError:\n",
799
+ " return \"Invalid Chain ID\", None, None\n",
800
+ " \n",
801
+ " # Comprehensive amino acid mapping\n",
802
+ " aa_dict = {\n",
803
+ " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
804
+ " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
805
+ " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
806
+ " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
807
+ " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
808
+ " }\n",
809
+ " \n",
810
+ " # Exclude non-amino acid residues and create a list of (resi, residue) pairs\n",
811
+ " sequence = [\n",
812
+ " (res.id[1], res) for res in chain\n",
813
+ " if res.get_resname().strip() in aa_dict\n",
814
+ " ]\n",
815
+ "\n",
816
+ " random_scores = np.random.rand(len(sequence))\n",
817
+ " \n",
818
+ " # Zip residues with scores to track the residue ID and score\n",
819
+ " residue_scores = [(resi, score) for (resi, _), score in zip(sequence, random_scores)]\n",
820
+ " \n",
821
+ " result_str = \"\\n\".join(\n",
822
+ " f\"{aa_dict[chain[resi].get_resname()]} {resi} {score:.2f}\"\n",
823
+ " for resi, score in residue_scores\n",
824
+ " )\n",
825
+ " \n",
826
+ " # Save the predictions to a file\n",
827
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
828
+ " with open(prediction_file, \"w\") as f:\n",
829
+ " f.write(result_str)\n",
830
+ " \n",
831
+ " return result_str, molecule(pdb_path, residue_scores, segment), prediction_file\n",
832
+ "\n",
833
+ "def molecule(input_pdb, residue_scores=None, segment='A'):\n",
834
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
835
+ " \n",
836
+ " # Prepare high-scoring residues script if scores are provided\n",
837
+ " high_score_script = \"\"\n",
838
+ " if residue_scores is not None:\n",
839
+ " # Split residues into high- and medium-score groups by threshold\n",
840
+ " high_score_residues = [resi for resi, score in residue_scores if score > 0.9]\n",
841
+ " mid_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 0.9]\n",
842
+ " \n",
843
+ " high_score_script = \"\"\"\n",
844
+ " // Reset all styles first\n",
845
+ " viewer.getModel(0).setStyle({}, {});\n",
846
+ " \n",
847
+ " // Show only the selected chain\n",
848
+ " viewer.getModel(0).setStyle(\n",
849
+ " {\"chain\": \"%s\"}, \n",
850
+ " { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
851
+ " );\n",
852
+ " \n",
853
+ " // Highlight high-scoring residues only for the selected chain\n",
854
+ " let highScoreResidues = [%s];\n",
855
+ " viewer.getModel(0).setStyle(\n",
856
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n",
857
+ " {\"stick\": {\"color\": \"red\"}}\n",
858
+ " );\n",
859
+ "\n",
860
+ " // Highlight medium-scoring residues only for the selected chain\n",
861
+ " let midScoreResidues = [%s];\n",
862
+ " viewer.getModel(0).setStyle(\n",
863
+ " {\"chain\": \"%s\", \"resi\": midScoreResidues}, \n",
864
+ " {\"stick\": {\"color\": \"orange\"}}\n",
865
+ " );\n",
866
+ " \"\"\" % (segment, \n",
867
+ " \", \".join(str(resi) for resi in high_score_residues),\n",
868
+ " segment,\n",
869
+ " \", \".join(str(resi) for resi in mid_score_residues),\n",
870
+ " segment)\n",
871
+ " \n",
872
+ " html_content = f\"\"\"\n",
873
+ " <!DOCTYPE html>\n",
874
+ " <html>\n",
875
+ " <head> \n",
876
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
877
+ " <style>\n",
878
+ " .mol-container {{\n",
879
+ " width: 100%;\n",
880
+ " height: 700px;\n",
881
+ " position: relative;\n",
882
+ " }}\n",
883
+ " </style>\n",
884
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
885
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
886
+ " </head>\n",
887
+ " <body>\n",
888
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
889
+ " <script>\n",
890
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
891
+ " $(document).ready(function () {{\n",
892
+ " let element = $(\"#container\");\n",
893
+ " let config = {{ backgroundColor: \"white\" }};\n",
894
+ " let viewer = $3Dmol.createViewer(element, config);\n",
895
+ " viewer.addModel(pdb, \"pdb\");\n",
896
+ " \n",
897
+ " // Reset all styles and show only selected chain\n",
898
+ " viewer.getModel(0).setStyle(\n",
899
+ " {{\"chain\": \"{segment}\"}}, \n",
900
+ " {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }}\n",
901
+ " );\n",
902
+ " \n",
903
+ " {high_score_script}\n",
904
+ " \n",
905
+ " // Add hover functionality\n",
906
+ " viewer.setHoverable(\n",
907
+ " {{}}, \n",
908
+ " true, \n",
909
+ " function(atom, viewer, event, container) {{\n",
910
+ " if (!atom.label) {{\n",
911
+ " atom.label = viewer.addLabel(\n",
912
+ " atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
913
+ " {{\n",
914
+ " position: atom, \n",
915
+ " backgroundColor: 'mintcream', \n",
916
+ " fontColor: 'black',\n",
917
+ " fontSize: 12,\n",
918
+ " padding: 2\n",
919
+ " }}\n",
920
+ " );\n",
921
+ " }}\n",
922
+ " }},\n",
923
+ " function(atom, viewer) {{\n",
924
+ " if (atom.label) {{\n",
925
+ " viewer.removeLabel(atom.label);\n",
926
+ " delete atom.label;\n",
927
+ " }}\n",
928
+ " }}\n",
929
+ " );\n",
930
+ " \n",
931
+ " viewer.zoomTo();\n",
932
+ " viewer.render();\n",
933
+ " viewer.zoom(0.8, 2000);\n",
934
+ " }});\n",
935
+ " </script>\n",
936
+ " </body>\n",
937
+ " </html>\n",
938
+ " \"\"\"\n",
939
+ " \n",
940
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
941
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
942
+ "\n",
943
+ "reps = [\n",
944
+ " {\n",
945
+ " \"model\": 0,\n",
946
+ " \"style\": \"cartoon\",\n",
947
+ " \"color\": \"whiteCarbon\",\n",
948
+ " \"residue_range\": \"\",\n",
949
+ " \"around\": 0,\n",
950
+ " \"byres\": False,\n",
951
+ " }\n",
952
+ " ]\n",
953
+ "\n",
954
+ "# Gradio UI\n",
955
+ "with gr.Blocks() as demo:\n",
956
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
957
+ " with gr.Row():\n",
958
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
959
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
960
+ "\n",
961
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
962
+ "\n",
963
+ " with gr.Row():\n",
964
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
965
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
966
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
967
+ "\n",
968
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
969
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
970
+ " download_output = gr.File(label=\"Download Predictions\")\n",
971
+ " \n",
972
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
973
+ " \n",
974
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
975
+ " \n",
976
+ " gr.Markdown(\"## Examples\")\n",
977
+ " gr.Examples(\n",
978
+ " examples=[\n",
979
+ " [\"2IWI\", \"A\"],\n",
980
+ " [\"7RPZ\", \"B\"],\n",
981
+ " [\"3TJN\", \"C\"]\n",
982
+ " ],\n",
983
+ " inputs=[pdb_input, segment_input],\n",
984
+ " outputs=[predictions_output, molecule_output, download_output]\n",
985
+ " )\n",
986
+ "\n",
987
+ "demo.launch()"
988
+ ]
989
+ },
990
+ {
991
+ "cell_type": "code",
992
+ "execution_count": null,
993
+ "id": "6f17feec-0347-4f9d-acd4-ae681c3ed425",
994
+ "metadata": {},
995
+ "outputs": [],
996
+ "source": []
997
+ },
998
+ {
999
+ "cell_type": "code",
1000
+ "execution_count": null,
1001
+ "id": "63201f38-adde-4b12-a8d3-f23474d045cf",
1002
+ "metadata": {},
1003
  "outputs": [],
1004
  "source": []
1005
  },
1006
  {
1007
  "cell_type": "code",
1008
  "execution_count": null,
1009
+ "id": "5ccbf398-5ef2-4955-98db-99f904f8daa4",
1010
+ "metadata": {},
1011
+ "outputs": [],
1012
+ "source": []
1013
+ },
1014
+ {
1015
+ "cell_type": "code",
1016
+ "execution_count": null,
1017
+ "id": "4c61bac4-4f2e-4f4a-aa1f-30dca209747c",
1018
  "metadata": {},
1019
  "outputs": [],
1020
  "source": [
 
1087
  " except KeyError:\n",
1088
  " return \"Invalid Chain ID\", None, None\n",
1089
  " \n",
1090
+ " \n",
1091
  " aa_dict = {\n",
1092
  " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
1093
  " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
 
1097
  " }\n",
1098
  " \n",
1099
  " # Exclude non-amino acid residues\n",
1100
+ " sequence = \"\".join(\n",
1101
+ " aa_dict[residue.get_resname().strip()] \n",
1102
+ " for residue in chain \n",
1103
  " if residue.get_resname().strip() in aa_dict\n",
1104
+ " )\n",
1105
+ " sequence2 = [\n",
1106
+ " (res.id[1], res) for res in chain\n",
1107
+ " if res.get_resname().strip() in aa_dict\n",
1108
  " ]\n",
1109
  " \n",
1110
  " # Prepare input for model prediction\n",
 
1116
  " scores = expit(outputs[:, 1] - outputs[:, 0])\n",
1117
  " normalized_scores = normalize_scores(scores)\n",
1118
  "\n",
1119
+ " # Zip residues with scores to track the residue ID and score\n",
1120
+ " residue_scores = [(resi, score) for (resi, _), score in zip(sequence2, normalized_scores)]\n",
1121
+ " \n",
1122
+ " result_str = \"\\n\".join([\n",
1123
+ " f\"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
1124
+ " for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict\n",
1125
+ " ])\n",
1126
  " \n",
1127
  " # Save the predictions to a file\n",
1128
  " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
1129
  " with open(prediction_file, \"w\") as f:\n",
1130
  " f.write(result_str)\n",
1131
  " \n",
1132
+ " return result_str, molecule(pdb_path, residue_scores, segment), prediction_file\n",
1133
  "\n",
1134
+ "def molecule(input_pdb, residue_scores=None, segment='A'):\n",
1135
  " mol = read_mol(input_pdb) # Read PDB file content\n",
1136
  " \n",
1137
  " # Prepare high-scoring residues script if scores are provided\n",
1138
  " high_score_script = \"\"\n",
1139
+ " if residue_scores is not None:\n",
1140
+ " # Split residues into high- and medium-score groups by threshold\n",
1141
+ " high_score_residues = [resi for resi, score in residue_scores if score > 0.75]\n",
1142
+ " mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]\n",
1143
+ " \n",
1144
  " high_score_script = \"\"\"\n",
1145
  " // Reset all styles first\n",
1146
  " viewer.getModel(0).setStyle({}, {});\n",
 
1158
  " {\"stick\": {\"color\": \"red\"}}\n",
1159
  " );\n",
1160
  "\n",
1161
+ " // Highlight medium-scoring residues only for the selected chain\n",
1162
+ " let midScoreResidues = [%s];\n",
1163
  " viewer.getModel(0).setStyle(\n",
1164
+ " {\"chain\": \"%s\", \"resi\": midScoreResidues}, \n",
1165
  " {\"stick\": {\"color\": \"orange\"}}\n",
1166
  " );\n",
1167
  " \"\"\" % (segment, \n",
1168
+ " \", \".join(str(resi) for resi in high_score_residues),\n",
1169
  " segment,\n",
1170
+ " \", \".join(str(resi) for resi in mid_score_residues),\n",
1171
  " segment)\n",
1172
  " \n",
1173
  " html_content = f\"\"\"\n",
 
1210
  " function(atom, viewer, event, container) {{\n",
1211
  " if (!atom.label) {{\n",
1212
  " atom.label = viewer.addLabel(\n",
1213
+ " atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
1214
  " {{\n",
1215
  " position: atom, \n",
1216
  " backgroundColor: 'mintcream', \n",
 
1277
  " gr.Markdown(\"## Examples\")\n",
1278
  " gr.Examples(\n",
1279
  " examples=[\n",
1280
+ " [\"7RPZ\", \"A\"],\n",
1281
+ " [\"2IWI\", \"B\"],\n",
1282
  " [\"3TJN\", \"C\"]\n",
1283
  " ],\n",
1284
  " inputs=[pdb_input, segment_input],\n",
1285
  " outputs=[predictions_output, molecule_output, download_output]\n",
1286
  " )\n",
1287
  "\n",
1288
+ "demo.launch(share=True)"
1289
  ]
1290
  },
1291
  {
1292
  "cell_type": "code",
1293
  "execution_count": null,
1294
+ "id": "b61d06ec-a4ee-4f65-925f-d2688730416a",
1295
  "metadata": {},
1296
  "outputs": [],
1297
  "source": []
 
1299
  {
1300
  "cell_type": "code",
1301
  "execution_count": null,
1302
+ "id": "4d67d69f-1f53-4bcc-8905-8d29384c4e20",
1303
  "metadata": {},
1304
  "outputs": [],
1305
  "source": [
1306
  "import gradio as gr\n",
1307
+ "import requests\n",
1308
+ "from Bio.PDB import PDBParser\n",
1309
+ "import numpy as np\n",
1310
+ "import os\n",
1311
+ "from gradio_molecule3d import Molecule3D\n",
1312
+ "\n",
1313
+ "\n",
1314
  "from model_loader import load_model\n",
1315
  "\n",
1316
  "import torch\n",
 
1319
  "from torch.utils.data import DataLoader\n",
1320
  "\n",
1321
  "import re\n",
 
 
1322
  "import pandas as pd\n",
1323
  "import copy\n",
1324
  "\n",
 
1330
  "\n",
1331
  "from scipy.special import expit\n",
1332
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
1333
  "# Load model and move to device\n",
1334
  "checkpoint = 'ThorbenF/prot_t5_xl_uniref50'\n",
1335
  "max_length = 1500\n",
 
1338
  "model.to(device)\n",
1339
  "model.eval()\n",
1340
  "\n",
1341
+ "def normalize_scores(scores):\n",
1342
+ " min_score = np.min(scores)\n",
1343
+ " max_score = np.max(scores)\n",
1344
+ " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
1345
+ " \n",
1346
+ "def read_mol(pdb_path):\n",
1347
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
1348
+ " with open(pdb_path, 'r') as f:\n",
1349
+ " return f.read()\n",
1350
+ "\n",
1351
  "def fetch_pdb(pdb_id):\n",
1352
  " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
1353
+ " pdb_path = f'{pdb_id}.pdb'\n",
 
1354
  " response = requests.get(pdb_url)\n",
1355
  " if response.status_code == 200:\n",
1356
  " with open(pdb_path, 'wb') as f:\n",
1357
  " f.write(response.content)\n",
1358
  " return pdb_path\n",
1359
+ " else:\n",
1360
+ " return None\n",
 
 
 
 
 
1361
  "\n",
1362
  "def process_pdb(pdb_id, segment):\n",
1363
  " pdb_path = fetch_pdb(pdb_id)\n",
 
1366
  " \n",
1367
  " parser = PDBParser(QUIET=1)\n",
1368
  " structure = parser.get_structure('protein', pdb_path)\n",
 
1369
  " \n",
1370
+ " try:\n",
1371
+ " chain = structure[0][segment]\n",
1372
+ " except KeyError:\n",
1373
+ " return \"Invalid Chain ID\", None, None\n",
1374
+ " \n",
1375
+ " \n",
1376
  " aa_dict = {\n",
1377
  " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
1378
  " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
 
1396
  " # Calculate scores and normalize them\n",
1397
  " scores = expit(outputs[:, 1] - outputs[:, 0])\n",
1398
  " normalized_scores = normalize_scores(scores)\n",
1399
+ "\n",
 
1400
  " result_str = \"\\n\".join([\n",
1401
  " f\"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
1402
  " for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict\n",
1403
  " ])\n",
1404
  " \n",
1405
+ " # Save the predictions to a file\n",
1406
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
1407
+ " with open(prediction_file, \"w\") as f:\n",
1408
  " f.write(result_str)\n",
1409
  " \n",
1410
+ " return result_str, molecule(pdb_path, normalized_scores, segment), prediction_file\n",
1411
  "\n",
1412
+ "def molecule(input_pdb, scores=None, segment='A'):\n",
1413
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
1414
+ " \n",
1415
+ " # Prepare high-scoring residues script if scores are provided\n",
1416
+ " high_score_script = \"\"\n",
1417
+ " if scores is not None:\n",
1418
+ " high_score_script = \"\"\"\n",
1419
+ " // Reset all styles first\n",
1420
+ " viewer.getModel(0).setStyle({}, {});\n",
1421
+ " \n",
1422
+ " // Show only the selected chain\n",
1423
+ " viewer.getModel(0).setStyle(\n",
1424
+ " {\"chain\": \"%s\"}, \n",
1425
+ " { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
1426
+ " );\n",
1427
+ " \n",
1428
+ " // Highlight high-scoring residues only for the selected chain\n",
1429
+ " let highScoreResidues = [%s];\n",
1430
+ " viewer.getModel(0).setStyle(\n",
1431
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n",
1432
+ " {\"stick\": {\"color\": \"red\"}}\n",
1433
+ " );\n",
1434
+ "\n",
1435
+ " // Highlight high-scoring residues only for the selected chain\n",
1436
+ " let highScoreResidues2 = [%s];\n",
1437
+ " viewer.getModel(0).setStyle(\n",
1438
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues2}, \n",
1439
+ " {\"stick\": {\"color\": \"orange\"}}\n",
1440
+ " );\n",
1441
+ " \"\"\" % (segment, \n",
1442
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n",
1443
+ " segment,\n",
1444
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),\n",
1445
+ " segment)\n",
1446
+ " \n",
1447
+ " html_content = f\"\"\"\n",
1448
+ " <!DOCTYPE html>\n",
1449
+ " <html>\n",
1450
+ " <head> \n",
1451
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
1452
+ " <style>\n",
1453
+ " .mol-container {{\n",
1454
+ " width: 100%;\n",
1455
+ " height: 700px;\n",
1456
+ " position: relative;\n",
1457
+ " }}\n",
1458
+ " </style>\n",
1459
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
1460
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
1461
+ " </head>\n",
1462
+ " <body>\n",
1463
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
1464
+ " <script>\n",
1465
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
1466
+ " $(document).ready(function () {{\n",
1467
+ " let element = $(\"#container\");\n",
1468
+ " let config = {{ backgroundColor: \"white\" }};\n",
1469
+ " let viewer = $3Dmol.createViewer(element, config);\n",
1470
+ " viewer.addModel(pdb, \"pdb\");\n",
1471
+ " \n",
1472
+ " // Reset all styles and show only selected chain\n",
1473
+ " viewer.getModel(0).setStyle(\n",
1474
+ " {{\"chain\": \"{segment}\"}}, \n",
1475
+ " {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }}\n",
1476
+ " );\n",
1477
+ " \n",
1478
+ " {high_score_script}\n",
1479
+ " \n",
1480
+ " // Add hover functionality\n",
1481
+ " viewer.setHoverable(\n",
1482
+ " {{}}, \n",
1483
+ " true, \n",
1484
+ " function(atom, viewer, event, container) {{\n",
1485
+ " if (!atom.label) {{\n",
1486
+ " atom.label = viewer.addLabel(\n",
1487
+ " atom.resn + \":\" + atom.atom, \n",
1488
+ " {{\n",
1489
+ " position: atom, \n",
1490
+ " backgroundColor: 'mintcream', \n",
1491
+ " fontColor: 'black',\n",
1492
+ " fontSize: 12,\n",
1493
+ " padding: 2\n",
1494
+ " }}\n",
1495
+ " );\n",
1496
+ " }}\n",
1497
+ " }},\n",
1498
+ " function(atom, viewer) {{\n",
1499
+ " if (atom.label) {{\n",
1500
+ " viewer.removeLabel(atom.label);\n",
1501
+ " delete atom.label;\n",
1502
+ " }}\n",
1503
+ " }}\n",
1504
+ " );\n",
1505
+ " \n",
1506
+ " viewer.zoomTo();\n",
1507
+ " viewer.render();\n",
1508
+ " viewer.zoom(0.8, 2000);\n",
1509
+ " }});\n",
1510
+ " </script>\n",
1511
+ " </body>\n",
1512
+ " </html>\n",
1513
+ " \"\"\"\n",
1514
+ " \n",
1515
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
1516
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
1517
+ "\n",
1518
+ "reps = [\n",
1519
+ " {\n",
1520
+ " \"model\": 0,\n",
1521
+ " \"style\": \"cartoon\",\n",
1522
+ " \"color\": \"whiteCarbon\",\n",
1523
+ " \"residue_range\": \"\",\n",
1524
+ " \"around\": 0,\n",
1525
+ " \"byres\": False,\n",
1526
+ " }\n",
1527
+ " ]\n",
1528
  "\n",
1529
  "# Gradio UI\n",
1530
  "with gr.Blocks() as demo:\n",
1531
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
1532
+ " with gr.Row():\n",
1533
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
1534
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
1535
+ "\n",
1536
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
1537
  "\n",
1538
  " with gr.Row():\n",
1539
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
1540
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
1541
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
1542
+ "\n",
1543
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
 
 
 
 
 
1544
  " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
1545
  " download_output = gr.File(label=\"Download Predictions\")\n",
1546
+ " \n",
1547
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
1548
+ " \n",
1549
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
1550
+ " \n",
 
 
 
1551
  " gr.Markdown(\"## Examples\")\n",
1552
  " gr.Examples(\n",
1553
  " examples=[\n",
1554
+ " [\"2IWI\", \"A\"],\n",
1555
+ " [\"7RPZ\", \"B\"],\n",
1556
+ " [\"3TJN\", \"C\"]\n",
1557
  " ],\n",
1558
+ " inputs=[pdb_input, segment_input],\n",
1559
  " outputs=[predictions_output, molecule_output, download_output]\n",
1560
  " )\n",
1561
  "\n",
1562
  "demo.launch(share=True)"
1563
  ]
 
 
 
 
 
 
 
 
1564
  }
1565
  ],
1566
  "metadata": {
2IWI.pdb ADDED
The diff for this file is too large to render. See raw diff
 
2IWI_predictions.txt ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Y 32 0.32
2
+ R 33 0.91
3
+ L 34 0.65
4
+ G 35 0.50
5
+ P 36 0.82
6
+ L 37 0.90
7
+ L 38 0.19
8
+ G 39 0.33
9
+ K 40 0.10
10
+ G 41 0.23
11
+ G 42 0.54
12
+ F 43 0.58
13
+ G 44 0.25
14
+ T 45 0.06
15
+ V 46 0.10
16
+ F 47 0.20
17
+ A 48 0.34
18
+ G 49 0.10
19
+ H 50 0.59
20
+ R 51 0.12
21
+ L 52 0.03
22
+ T 53 0.86
23
+ D 54 0.08
24
+ R 55 0.57
25
+ L 56 0.96
26
+ Q 57 0.75
27
+ V 58 0.91
28
+ A 59 0.80
29
+ I 60 0.49
30
+ K 61 0.52
31
+ V 62 0.29
32
+ I 63 0.87
33
+ P 64 0.60
34
+ R 65 0.99
35
+ N 66 0.50
36
+ R 67 0.51
37
+ V 68 0.79
38
+ L 69 0.16
39
+ V 78 0.06
40
+ T 79 0.89
41
+ C 80 0.33
42
+ P 81 0.40
43
+ L 82 0.84
44
+ E 83 0.07
45
+ V 84 0.47
46
+ A 85 0.67
47
+ L 86 0.89
48
+ L 87 0.86
49
+ W 88 0.04
50
+ K 89 0.34
51
+ V 90 0.53
52
+ G 91 0.83
53
+ A 92 0.80
54
+ G 93 0.85
55
+ G 94 0.42
56
+ G 95 0.08
57
+ H 96 0.24
58
+ P 97 0.78
59
+ G 98 0.38
60
+ V 99 0.39
61
+ I 100 0.21
62
+ R 101 0.77
63
+ L 102 0.61
64
+ L 103 0.50
65
+ D 104 0.13
66
+ W 105 0.76
67
+ F 106 0.45
68
+ F 112 0.89
69
+ M 113 0.39
70
+ L 114 0.11
71
+ V 115 0.56
72
+ L 116 0.04
73
+ E 117 0.62
74
+ R 118 0.39
75
+ P 119 0.72
76
+ L 120 0.38
77
+ P 121 0.35
78
+ A 122 0.03
79
+ Q 123 0.85
80
+ D 124 0.49
81
+ L 125 0.19
82
+ F 126 0.78
83
+ D 127 0.52
84
+ Y 128 0.88
85
+ I 129 0.85
86
+ T 130 0.82
87
+ E 131 0.27
88
+ K 132 0.67
89
+ G 133 0.41
90
+ P 134 0.95
91
+ L 135 0.36
92
+ G 136 0.52
93
+ E 137 0.14
94
+ G 138 0.95
95
+ P 139 0.57
96
+ S 140 0.27
97
+ R 141 0.92
98
+ C 142 0.13
99
+ F 143 0.18
100
+ F 144 0.12
101
+ G 145 0.32
102
+ Q 146 0.35
103
+ V 147 0.95
104
+ V 148 0.89
105
+ A 149 0.76
106
+ A 150 0.43
107
+ I 151 0.09
108
+ Q 152 0.89
109
+ H 153 0.54
110
+ C 154 0.47
111
+ H 155 0.05
112
+ S 156 0.10
113
+ R 157 0.64
114
+ G 158 0.32
115
+ V 159 0.41
116
+ V 160 0.18
117
+ H 161 0.63
118
+ R 162 0.14
119
+ D 163 0.03
120
+ I 164 0.63
121
+ K 165 0.97
122
+ D 166 0.73
123
+ E 167 0.96
124
+ N 168 0.25
125
+ I 169 0.37
126
+ L 170 0.79
127
+ I 171 0.26
128
+ D 172 0.80
129
+ L 173 0.98
130
+ R 174 0.06
131
+ R 175 0.56
132
+ G 176 0.29
133
+ C 177 0.43
134
+ A 178 0.17
135
+ K 179 0.52
136
+ L 180 0.51
137
+ I 181 0.54
138
+ D 182 0.04
139
+ F 183 0.33
140
+ G 184 0.05
141
+ S 185 0.92
142
+ G 186 0.92
143
+ A 187 0.83
144
+ L 188 0.49
145
+ L 189 0.88
146
+ H 190 0.60
147
+ D 191 0.17
148
+ E 192 0.17
149
+ P 193 0.31
150
+ Y 194 0.61
151
+ T 195 0.02
152
+ D 196 0.11
153
+ F 197 0.33
154
+ D 198 0.85
155
+ G 199 0.82
156
+ T 200 0.10
157
+ R 201 0.69
158
+ V 202 0.70
159
+ Y 203 0.21
160
+ S 204 0.80
161
+ P 205 0.65
162
+ P 206 0.75
163
+ E 207 0.01
164
+ W 208 0.81
165
+ I 209 0.83
166
+ S 210 0.72
167
+ R 211 0.80
168
+ H 212 0.64
169
+ Q 213 0.36
170
+ Y 214 0.54
171
+ H 215 0.97
172
+ A 216 0.75
173
+ L 217 0.54
174
+ P 218 0.25
175
+ A 219 0.04
176
+ T 220 0.28
177
+ V 221 0.46
178
+ W 222 0.67
179
+ S 223 0.24
180
+ L 224 0.05
181
+ G 225 0.65
182
+ I 226 0.42
183
+ L 227 0.46
184
+ L 228 0.12
185
+ Y 229 0.68
186
+ D 230 0.82
187
+ M 231 0.51
188
+ V 232 0.75
189
+ C 233 0.41
190
+ G 234 0.54
191
+ D 235 0.43
192
+ I 236 0.09
193
+ P 237 0.12
194
+ F 238 0.80
195
+ E 239 0.57
196
+ R 240 0.42
197
+ D 241 0.34
198
+ Q 242 0.08
199
+ E 243 0.40
200
+ I 244 0.68
201
+ L 245 0.09
202
+ E 246 0.75
203
+ A 247 0.38
204
+ E 248 0.68
205
+ L 249 0.62
206
+ H 250 0.56
207
+ F 251 0.08
208
+ P 252 0.60
209
+ A 253 0.12
210
+ H 254 0.77
211
+ V 255 0.92
212
+ S 256 0.67
213
+ P 257 0.48
214
+ D 258 0.27
215
+ C 259 0.90
216
+ C 260 0.16
217
+ A 261 0.50
218
+ L 262 0.78
219
+ I 263 0.11
220
+ R 264 0.67
221
+ R 265 0.85
222
+ C 266 0.80
223
+ L 267 0.11
224
+ A 268 0.95
225
+ P 269 0.30
226
+ K 270 0.34
227
+ P 271 0.85
228
+ S 272 0.94
229
+ S 273 0.04
230
+ R 274 0.83
231
+ P 275 0.68
232
+ S 276 0.16
233
+ L 277 0.13
234
+ E 278 0.74
235
+ E 279 0.28
236
+ I 280 0.45
237
+ L 281 0.46
238
+ L 282 0.23
239
+ D 283 0.24
240
+ P 284 0.58
241
+ W 285 0.78
242
+ M 286 0.59
243
+ Q 287 0.30
244
+ T 288 0.30
__pycache__/model_loader.cpython-312.pyc ADDED
Binary file (32.5 kB). View file
 
app.py CHANGED
@@ -82,6 +82,10 @@ def process_pdb(pdb_id, segment):
82
  for residue in chain
83
  if residue.get_resname().strip() in aa_dict
84
  )
 
 
 
 
85
 
86
  # Prepare input for model prediction
87
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
@@ -92,6 +96,9 @@ def process_pdb(pdb_id, segment):
92
  scores = expit(outputs[:, 1] - outputs[:, 0])
93
  normalized_scores = normalize_scores(scores)
94
 
 
 
 
95
  result_str = "\n".join([
96
  f"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}"
97
  for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict
@@ -102,14 +109,18 @@ def process_pdb(pdb_id, segment):
102
  with open(prediction_file, "w") as f:
103
  f.write(result_str)
104
 
105
- return result_str, molecule(pdb_path, normalized_scores, segment), prediction_file
106
 
107
- def molecule(input_pdb, scores=None, segment='A'):
108
  mol = read_mol(input_pdb) # Read PDB file content
109
 
110
  # Prepare high-scoring residues script if scores are provided
111
  high_score_script = ""
112
- if scores is not None:
 
 
 
 
113
  high_score_script = """
114
  // Reset all styles first
115
  viewer.getModel(0).setStyle({}, {});
@@ -127,16 +138,16 @@ def molecule(input_pdb, scores=None, segment='A'):
127
  {"stick": {"color": "red"}}
128
  );
129
 
130
- // Highlight high-scoring residues only for the selected chain
131
- let highScoreResidues2 = [%s];
132
  viewer.getModel(0).setStyle(
133
- {"chain": "%s", "resi": highScoreResidues2},
134
  {"stick": {"color": "orange"}}
135
  );
136
  """ % (segment,
137
- ", ".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),
138
  segment,
139
- ", ".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),
140
  segment)
141
 
142
  html_content = f"""
@@ -179,7 +190,7 @@ def molecule(input_pdb, scores=None, segment='A'):
179
  function(atom, viewer, event, container) {{
180
  if (!atom.label) {{
181
  atom.label = viewer.addLabel(
182
- atom.resn + ":" + atom.atom,
183
  {{
184
  position: atom,
185
  backgroundColor: 'mintcream',
@@ -246,8 +257,8 @@ with gr.Blocks() as demo:
246
  gr.Markdown("## Examples")
247
  gr.Examples(
248
  examples=[
249
- ["2IWI", "A"],
250
- ["7RPZ", "B"],
251
  ["3TJN", "C"]
252
  ],
253
  inputs=[pdb_input, segment_input],
 
82
  for residue in chain
83
  if residue.get_resname().strip() in aa_dict
84
  )
85
+ sequence2 = [
86
+ (res.id[1], res) for res in chain
87
+ if res.get_resname().strip() in aa_dict
88
+ ]
89
 
90
  # Prepare input for model prediction
91
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
 
96
  scores = expit(outputs[:, 1] - outputs[:, 0])
97
  normalized_scores = normalize_scores(scores)
98
 
99
+ # Zip residues with scores to track the residue ID and score
100
+ residue_scores = [(resi, score) for (resi, _), score in zip(sequence2, normalized_scores)]
101
+
102
  result_str = "\n".join([
103
  f"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}"
104
  for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict
 
109
  with open(prediction_file, "w") as f:
110
  f.write(result_str)
111
 
112
+ return result_str, molecule(pdb_path, residue_scores, segment), prediction_file
113
 
114
+ def molecule(input_pdb, residue_scores=None, segment='A'):
115
  mol = read_mol(input_pdb) # Read PDB file content
116
 
117
  # Prepare high-scoring residues script if scores are provided
118
  high_score_script = ""
119
+ if residue_scores is not None:
120
+ # Sort residues based on their scores
121
+ high_score_residues = [resi for resi, score in residue_scores if score > 0.75]
122
+ mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]
123
+
124
  high_score_script = """
125
  // Reset all styles first
126
  viewer.getModel(0).setStyle({}, {});
 
138
  {"stick": {"color": "red"}}
139
  );
140
 
141
+ // Highlight medium-scoring residues only for the selected chain
142
+ let midScoreResidues = [%s];
143
  viewer.getModel(0).setStyle(
144
+ {"chain": "%s", "resi": midScoreResidues},
145
  {"stick": {"color": "orange"}}
146
  );
147
  """ % (segment,
148
+ ", ".join(str(resi) for resi in high_score_residues),
149
  segment,
150
+ ", ".join(str(resi) for resi in mid_score_residues),
151
  segment)
152
 
153
  html_content = f"""
 
190
  function(atom, viewer, event, container) {{
191
  if (!atom.label) {{
192
  atom.label = viewer.addLabel(
193
+ atom.resn + ":" +atom.resi + ":" + atom.atom,
194
  {{
195
  position: atom,
196
  backgroundColor: 'mintcream',
 
257
  gr.Markdown("## Examples")
258
  gr.Examples(
259
  examples=[
260
+ ["7RPZ", "A"],
261
+ ["2IWI", "B"],
262
  ["3TJN", "C"]
263
  ],
264
  inputs=[pdb_input, segment_input],
test2.ipynb CHANGED
@@ -473,7 +473,7 @@
473
  },
474
  {
475
  "cell_type": "code",
476
- "execution_count": 11,
477
  "id": "d62be1b5-762e-4b69-aed4-e4ba2a44482f",
478
  "metadata": {},
479
  "outputs": [
@@ -481,7 +481,7 @@
481
  "name": "stdout",
482
  "output_type": "stream",
483
  "text": [
484
- "* Running on local URL: http://127.0.0.1:7867\n",
485
  "\n",
486
  "To create a public link, set `share=True` in `launch()`.\n"
487
  ]
@@ -489,7 +489,7 @@
489
  {
490
  "data": {
491
  "text/html": [
492
- "<div><iframe src=\"http://127.0.0.1:7867/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
493
  ],
494
  "text/plain": [
495
  "<IPython.core.display.HTML object>"
@@ -502,7 +502,7 @@
502
  "data": {
503
  "text/plain": []
504
  },
505
- "execution_count": 11,
506
  "metadata": {},
507
  "output_type": "execute_result"
508
  }
@@ -647,7 +647,7 @@
647
  " function(atom, viewer, event, container) {{\n",
648
  " if (!atom.label) {{\n",
649
  " atom.label = viewer.addLabel(\n",
650
- " atom.resn + \":\" + atom.atom, \n",
651
  " {{\n",
652
  " position: atom, \n",
653
  " backgroundColor: 'mintcream', \n",
@@ -727,16 +727,294 @@
727
  },
728
  {
729
  "cell_type": "code",
730
- "execution_count": null,
731
  "id": "30f35243-852f-4771-9a4b-5cdd198552b5",
732
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
733
  "outputs": [],
734
  "source": []
735
  },
736
  {
737
  "cell_type": "code",
738
  "execution_count": null,
739
- "id": "5eca6754-4aa1-463f-881a-25d2a0d6bb5b",
 
 
 
 
 
 
 
 
740
  "metadata": {},
741
  "outputs": [],
742
  "source": [
@@ -809,7 +1087,7 @@
809
  " except KeyError:\n",
810
  " return \"Invalid Chain ID\", None, None\n",
811
  " \n",
812
- " # Comprehensive amino acid mapping\n",
813
  " aa_dict = {\n",
814
  " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
815
  " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
@@ -819,9 +1097,14 @@
819
  " }\n",
820
  " \n",
821
  " # Exclude non-amino acid residues\n",
822
- " sequence = [\n",
823
- " residue for residue in chain \n",
 
824
  " if residue.get_resname().strip() in aa_dict\n",
 
 
 
 
825
  " ]\n",
826
  " \n",
827
  " # Prepare input for model prediction\n",
@@ -833,24 +1116,31 @@
833
  " scores = expit(outputs[:, 1] - outputs[:, 0])\n",
834
  " normalized_scores = normalize_scores(scores)\n",
835
  "\n",
836
- " result_str = \"\\n\".join(\n",
837
- " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n",
838
- " for res, score in zip(sequence, normalized_scores)\n",
839
- " )\n",
 
 
 
840
  " \n",
841
  " # Save the predictions to a file\n",
842
  " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
843
  " with open(prediction_file, \"w\") as f:\n",
844
  " f.write(result_str)\n",
845
  " \n",
846
- " return result_str, molecule(pdb_path, random_scores, segment), prediction_file\n",
847
  "\n",
848
- "def molecule(input_pdb, scores=None, segment='A'):\n",
849
  " mol = read_mol(input_pdb) # Read PDB file content\n",
850
  " \n",
851
  " # Prepare high-scoring residues script if scores are provided\n",
852
  " high_score_script = \"\"\n",
853
- " if scores is not None:\n",
 
 
 
 
854
  " high_score_script = \"\"\"\n",
855
  " // Reset all styles first\n",
856
  " viewer.getModel(0).setStyle({}, {});\n",
@@ -868,16 +1158,16 @@
868
  " {\"stick\": {\"color\": \"red\"}}\n",
869
  " );\n",
870
  "\n",
871
- " // Highlight high-scoring residues only for the selected chain\n",
872
- " let highScoreResidues2 = [%s];\n",
873
  " viewer.getModel(0).setStyle(\n",
874
- " {\"chain\": \"%s\", \"resi\": highScoreResidues2}, \n",
875
  " {\"stick\": {\"color\": \"orange\"}}\n",
876
  " );\n",
877
  " \"\"\" % (segment, \n",
878
- " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n",
879
  " segment,\n",
880
- " \", \".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),\n",
881
  " segment)\n",
882
  " \n",
883
  " html_content = f\"\"\"\n",
@@ -920,7 +1210,7 @@
920
  " function(atom, viewer, event, container) {{\n",
921
  " if (!atom.label) {{\n",
922
  " atom.label = viewer.addLabel(\n",
923
- " atom.resn + \":\" + atom.atom, \n",
924
  " {{\n",
925
  " position: atom, \n",
926
  " backgroundColor: 'mintcream', \n",
@@ -987,21 +1277,21 @@
987
  " gr.Markdown(\"## Examples\")\n",
988
  " gr.Examples(\n",
989
  " examples=[\n",
990
- " [\"2IWI\", \"A\"],\n",
991
- " [\"7RPZ\", \"B\"],\n",
992
  " [\"3TJN\", \"C\"]\n",
993
  " ],\n",
994
  " inputs=[pdb_input, segment_input],\n",
995
  " outputs=[predictions_output, molecule_output, download_output]\n",
996
  " )\n",
997
  "\n",
998
- "demo.launch()"
999
  ]
1000
  },
1001
  {
1002
  "cell_type": "code",
1003
  "execution_count": null,
1004
- "id": "95046d1c-ec7c-4e3e-8a98-1802cb09a25b",
1005
  "metadata": {},
1006
  "outputs": [],
1007
  "source": []
@@ -1009,11 +1299,18 @@
1009
  {
1010
  "cell_type": "code",
1011
  "execution_count": null,
1012
- "id": "a37cbe6f-d57f-41e5-8ae1-38258da39d47",
1013
  "metadata": {},
1014
  "outputs": [],
1015
  "source": [
1016
  "import gradio as gr\n",
 
 
 
 
 
 
 
1017
  "from model_loader import load_model\n",
1018
  "\n",
1019
  "import torch\n",
@@ -1022,8 +1319,6 @@
1022
  "from torch.utils.data import DataLoader\n",
1023
  "\n",
1024
  "import re\n",
1025
- "import numpy as np\n",
1026
- "import os\n",
1027
  "import pandas as pd\n",
1028
  "import copy\n",
1029
  "\n",
@@ -1035,18 +1330,6 @@
1035
  "\n",
1036
  "from scipy.special import expit\n",
1037
  "\n",
1038
- "import requests\n",
1039
- "\n",
1040
- "from gradio_molecule3d import Molecule3D\n",
1041
- "\n",
1042
- "# Biopython imports\n",
1043
- "from Bio.PDB import PDBParser, Select, PDBIO\n",
1044
- "from Bio.PDB.DSSP import DSSP\n",
1045
- "from Bio.PDB import PDBList\n",
1046
- "\n",
1047
- "from matplotlib import cm # For color mapping\n",
1048
- "from matplotlib.colors import Normalize\n",
1049
- "\n",
1050
  "# Load model and move to device\n",
1051
  "checkpoint = 'ThorbenF/prot_t5_xl_uniref50'\n",
1052
  "max_length = 1500\n",
@@ -1055,23 +1338,26 @@
1055
  "model.to(device)\n",
1056
  "model.eval()\n",
1057
  "\n",
1058
- "# Function to fetch a PDB file\n",
 
 
 
 
 
 
 
 
 
1059
  "def fetch_pdb(pdb_id):\n",
1060
  " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
1061
- " pdb_path = f'pdb_files/{pdb_id}.pdb'\n",
1062
- " os.makedirs('pdb_files', exist_ok=True)\n",
1063
  " response = requests.get(pdb_url)\n",
1064
  " if response.status_code == 200:\n",
1065
  " with open(pdb_path, 'wb') as f:\n",
1066
  " f.write(response.content)\n",
1067
  " return pdb_path\n",
1068
- " return None\n",
1069
- "\n",
1070
- "\n",
1071
- "def normalize_scores(scores):\n",
1072
- " min_score = np.min(scores)\n",
1073
- " max_score = np.max(scores)\n",
1074
- " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
1075
  "\n",
1076
  "def process_pdb(pdb_id, segment):\n",
1077
  " pdb_path = fetch_pdb(pdb_id)\n",
@@ -1080,9 +1366,13 @@
1080
  " \n",
1081
  " parser = PDBParser(QUIET=1)\n",
1082
  " structure = parser.get_structure('protein', pdb_path)\n",
1083
- " chain = structure[0][segment]\n",
1084
  " \n",
1085
- " # Comprehensive amino acid mapping\n",
 
 
 
 
 
1086
  " aa_dict = {\n",
1087
  " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
1088
  " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
@@ -1106,67 +1396,171 @@
1106
  " # Calculate scores and normalize them\n",
1107
  " scores = expit(outputs[:, 1] - outputs[:, 0])\n",
1108
  " normalized_scores = normalize_scores(scores)\n",
1109
- " \n",
1110
- " # Prepare the result string, including only amino acid residues\n",
1111
  " result_str = \"\\n\".join([\n",
1112
  " f\"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
1113
  " for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict\n",
1114
  " ])\n",
1115
  " \n",
1116
- " # Save predictions to file\n",
1117
- " with open(f\"{pdb_id}_predictions.txt\", \"w\") as f:\n",
 
1118
  " f.write(result_str)\n",
1119
  " \n",
1120
- " return result_str, pdb_path, f\"{pdb_id}_predictions.txt\"\n",
1121
  "\n",
1122
- "reps = [{\"model\": 0, \"style\": \"cartoon\", \"color\": \"spectrum\"}]\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1123
  "\n",
1124
  "# Gradio UI\n",
1125
  "with gr.Blocks() as demo:\n",
1126
- " gr.Markdown(\"# Protein Binding Site Prediction\")\n",
 
 
 
 
 
1127
  "\n",
1128
  " with gr.Row():\n",
1129
- " pdb_input = gr.Textbox(value=\"2IWI\",\n",
1130
- " label=\"PDB ID\",\n",
1131
- " placeholder=\"Enter PDB ID here...\")\n",
1132
- " segment_input = gr.Textbox(value=\"A\",\n",
1133
- " label=\"Chain ID (Segment)\",\n",
1134
- " placeholder=\"Enter Chain ID here...\")\n",
1135
- " visualize_btn = gr.Button(\"Visualize Sructure\")\n",
1136
- " prediction_btn = gr.Button(\"Predict Ligand Binding Site\")\n",
1137
- "\n",
1138
- " molecule_output = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
1139
  " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
1140
  " download_output = gr.File(label=\"Download Predictions\")\n",
1141
- "\n",
1142
- " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)\n",
1143
- " prediction_btn.click(\n",
1144
- " process_pdb, \n",
1145
- " inputs=[pdb_input, segment_input], \n",
1146
- " outputs=[predictions_output, molecule_output, download_output]\n",
1147
- " )\n",
1148
- "\n",
1149
  " gr.Markdown(\"## Examples\")\n",
1150
  " gr.Examples(\n",
1151
  " examples=[\n",
1152
- " [\"2IWI\"],\n",
1153
- " [\"7RPZ\"],\n",
1154
- " [\"3TJN\"]\n",
1155
  " ],\n",
1156
- " inputs=[pdb_input, segment_input], \n",
1157
  " outputs=[predictions_output, molecule_output, download_output]\n",
1158
  " )\n",
1159
  "\n",
1160
  "demo.launch(share=True)"
1161
  ]
1162
- },
1163
- {
1164
- "cell_type": "code",
1165
- "execution_count": null,
1166
- "id": "4c61bac4-4f2e-4f4a-aa1f-30dca209747c",
1167
- "metadata": {},
1168
- "outputs": [],
1169
- "source": []
1170
  }
1171
  ],
1172
  "metadata": {
 
473
  },
474
  {
475
  "cell_type": "code",
476
+ "execution_count": 1,
477
  "id": "d62be1b5-762e-4b69-aed4-e4ba2a44482f",
478
  "metadata": {},
479
  "outputs": [
 
481
  "name": "stdout",
482
  "output_type": "stream",
483
  "text": [
484
+ "* Running on local URL: http://127.0.0.1:7860\n",
485
  "\n",
486
  "To create a public link, set `share=True` in `launch()`.\n"
487
  ]
 
489
  {
490
  "data": {
491
  "text/html": [
492
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
493
  ],
494
  "text/plain": [
495
  "<IPython.core.display.HTML object>"
 
502
  "data": {
503
  "text/plain": []
504
  },
505
+ "execution_count": 1,
506
  "metadata": {},
507
  "output_type": "execute_result"
508
  }
 
647
  " function(atom, viewer, event, container) {{\n",
648
  " if (!atom.label) {{\n",
649
  " atom.label = viewer.addLabel(\n",
650
+ " atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
651
  " {{\n",
652
  " position: atom, \n",
653
  " backgroundColor: 'mintcream', \n",
 
727
  },
728
  {
729
  "cell_type": "code",
730
+ "execution_count": 4,
731
  "id": "30f35243-852f-4771-9a4b-5cdd198552b5",
732
  "metadata": {},
733
+ "outputs": [
734
+ {
735
+ "name": "stdout",
736
+ "output_type": "stream",
737
+ "text": [
738
+ "* Running on local URL: http://127.0.0.1:7863\n",
739
+ "\n",
740
+ "To create a public link, set `share=True` in `launch()`.\n"
741
+ ]
742
+ },
743
+ {
744
+ "data": {
745
+ "text/html": [
746
+ "<div><iframe src=\"http://127.0.0.1:7863/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
747
+ ],
748
+ "text/plain": [
749
+ "<IPython.core.display.HTML object>"
750
+ ]
751
+ },
752
+ "metadata": {},
753
+ "output_type": "display_data"
754
+ },
755
+ {
756
+ "data": {
757
+ "text/plain": []
758
+ },
759
+ "execution_count": 4,
760
+ "metadata": {},
761
+ "output_type": "execute_result"
762
+ }
763
+ ],
764
+ "source": [
765
+ "import gradio as gr\n",
766
+ "import requests\n",
767
+ "from Bio.PDB import PDBParser\n",
768
+ "import numpy as np\n",
769
+ "import os\n",
770
+ "from gradio_molecule3d import Molecule3D\n",
771
+ "\n",
772
+ "def read_mol(pdb_path):\n",
773
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
774
+ " with open(pdb_path, 'r') as f:\n",
775
+ " return f.read()\n",
776
+ "\n",
777
+ "def fetch_pdb(pdb_id):\n",
778
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
779
+ " pdb_path = f'{pdb_id}.pdb'\n",
780
+ " response = requests.get(pdb_url)\n",
781
+ " if response.status_code == 200:\n",
782
+ " with open(pdb_path, 'wb') as f:\n",
783
+ " f.write(response.content)\n",
784
+ " return pdb_path\n",
785
+ " else:\n",
786
+ " return None\n",
787
+ "\n",
788
+ "def process_pdb(pdb_id, segment):\n",
789
+ " pdb_path = fetch_pdb(pdb_id)\n",
790
+ " if not pdb_path:\n",
791
+ " return \"Failed to fetch PDB file\", None, None\n",
792
+ " \n",
793
+ " parser = PDBParser(QUIET=1)\n",
794
+ " structure = parser.get_structure('protein', pdb_path)\n",
795
+ " \n",
796
+ " try:\n",
797
+ " chain = structure[0][segment]\n",
798
+ " except KeyError:\n",
799
+ " return \"Invalid Chain ID\", None, None\n",
800
+ " \n",
801
+ " # Comprehensive amino acid mapping\n",
802
+ " aa_dict = {\n",
803
+ " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
804
+ " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
805
+ " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
806
+ " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
807
+ " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
808
+ " }\n",
809
+ " \n",
810
+ " # Exclude non-amino acid residues and create a list of (resi, score) pairs\n",
811
+ " sequence = [\n",
812
+ " (res.id[1], res) for res in chain\n",
813
+ " if res.get_resname().strip() in aa_dict\n",
814
+ " ]\n",
815
+ "\n",
816
+ " random_scores = np.random.rand(len(sequence))\n",
817
+ " \n",
818
+ " # Zip residues with scores to track the residue ID and score\n",
819
+ " residue_scores = [(resi, score) for (resi, _), score in zip(sequence, random_scores)]\n",
820
+ " \n",
821
+ " result_str = \"\\n\".join(\n",
822
+ " f\"{aa_dict[chain[resi].get_resname()]} {resi} {score:.2f}\"\n",
823
+ " for resi, score in residue_scores\n",
824
+ " )\n",
825
+ " \n",
826
+ " # Save the predictions to a file\n",
827
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
828
+ " with open(prediction_file, \"w\") as f:\n",
829
+ " f.write(result_str)\n",
830
+ " \n",
831
+ " return result_str, molecule(pdb_path, residue_scores, segment), prediction_file\n",
832
+ "\n",
833
+ "def molecule(input_pdb, residue_scores=None, segment='A'):\n",
834
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
835
+ " \n",
836
+ " # Prepare high-scoring residues script if scores are provided\n",
837
+ " high_score_script = \"\"\n",
838
+ " if residue_scores is not None:\n",
839
+ " # Sort residues based on their scores\n",
840
+ " high_score_residues = [resi for resi, score in residue_scores if score > 0.9]\n",
841
+ " mid_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 0.9]\n",
842
+ " \n",
843
+ " high_score_script = \"\"\"\n",
844
+ " // Reset all styles first\n",
845
+ " viewer.getModel(0).setStyle({}, {});\n",
846
+ " \n",
847
+ " // Show only the selected chain\n",
848
+ " viewer.getModel(0).setStyle(\n",
849
+ " {\"chain\": \"%s\"}, \n",
850
+ " { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
851
+ " );\n",
852
+ " \n",
853
+ " // Highlight high-scoring residues only for the selected chain\n",
854
+ " let highScoreResidues = [%s];\n",
855
+ " viewer.getModel(0).setStyle(\n",
856
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n",
857
+ " {\"stick\": {\"color\": \"red\"}}\n",
858
+ " );\n",
859
+ "\n",
860
+ " // Highlight medium-scoring residues only for the selected chain\n",
861
+ " let midScoreResidues = [%s];\n",
862
+ " viewer.getModel(0).setStyle(\n",
863
+ " {\"chain\": \"%s\", \"resi\": midScoreResidues}, \n",
864
+ " {\"stick\": {\"color\": \"orange\"}}\n",
865
+ " );\n",
866
+ " \"\"\" % (segment, \n",
867
+ " \", \".join(str(resi) for resi in high_score_residues),\n",
868
+ " segment,\n",
869
+ " \", \".join(str(resi) for resi in mid_score_residues),\n",
870
+ " segment)\n",
871
+ " \n",
872
+ " html_content = f\"\"\"\n",
873
+ " <!DOCTYPE html>\n",
874
+ " <html>\n",
875
+ " <head> \n",
876
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
877
+ " <style>\n",
878
+ " .mol-container {{\n",
879
+ " width: 100%;\n",
880
+ " height: 700px;\n",
881
+ " position: relative;\n",
882
+ " }}\n",
883
+ " </style>\n",
884
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
885
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
886
+ " </head>\n",
887
+ " <body>\n",
888
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
889
+ " <script>\n",
890
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
891
+ " $(document).ready(function () {{\n",
892
+ " let element = $(\"#container\");\n",
893
+ " let config = {{ backgroundColor: \"white\" }};\n",
894
+ " let viewer = $3Dmol.createViewer(element, config);\n",
895
+ " viewer.addModel(pdb, \"pdb\");\n",
896
+ " \n",
897
+ " // Reset all styles and show only selected chain\n",
898
+ " viewer.getModel(0).setStyle(\n",
899
+ " {{\"chain\": \"{segment}\"}}, \n",
900
+ " {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }}\n",
901
+ " );\n",
902
+ " \n",
903
+ " {high_score_script}\n",
904
+ " \n",
905
+ " // Add hover functionality\n",
906
+ " viewer.setHoverable(\n",
907
+ " {{}}, \n",
908
+ " true, \n",
909
+ " function(atom, viewer, event, container) {{\n",
910
+ " if (!atom.label) {{\n",
911
+ " atom.label = viewer.addLabel(\n",
912
+ " atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
913
+ " {{\n",
914
+ " position: atom, \n",
915
+ " backgroundColor: 'mintcream', \n",
916
+ " fontColor: 'black',\n",
917
+ " fontSize: 12,\n",
918
+ " padding: 2\n",
919
+ " }}\n",
920
+ " );\n",
921
+ " }}\n",
922
+ " }},\n",
923
+ " function(atom, viewer) {{\n",
924
+ " if (atom.label) {{\n",
925
+ " viewer.removeLabel(atom.label);\n",
926
+ " delete atom.label;\n",
927
+ " }}\n",
928
+ " }}\n",
929
+ " );\n",
930
+ " \n",
931
+ " viewer.zoomTo();\n",
932
+ " viewer.render();\n",
933
+ " viewer.zoom(0.8, 2000);\n",
934
+ " }});\n",
935
+ " </script>\n",
936
+ " </body>\n",
937
+ " </html>\n",
938
+ " \"\"\"\n",
939
+ " \n",
940
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
941
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
942
+ "\n",
943
+ "reps = [\n",
944
+ " {\n",
945
+ " \"model\": 0,\n",
946
+ " \"style\": \"cartoon\",\n",
947
+ " \"color\": \"whiteCarbon\",\n",
948
+ " \"residue_range\": \"\",\n",
949
+ " \"around\": 0,\n",
950
+ " \"byres\": False,\n",
951
+ " }\n",
952
+ " ]\n",
953
+ "\n",
954
+ "# Gradio UI\n",
955
+ "with gr.Blocks() as demo:\n",
956
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
957
+ " with gr.Row():\n",
958
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
959
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
960
+ "\n",
961
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
962
+ "\n",
963
+ " with gr.Row():\n",
964
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
965
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
966
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
967
+ "\n",
968
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
969
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
970
+ " download_output = gr.File(label=\"Download Predictions\")\n",
971
+ " \n",
972
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
973
+ " \n",
974
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
975
+ " \n",
976
+ " gr.Markdown(\"## Examples\")\n",
977
+ " gr.Examples(\n",
978
+ " examples=[\n",
979
+ " [\"2IWI\", \"A\"],\n",
980
+ " [\"7RPZ\", \"B\"],\n",
981
+ " [\"3TJN\", \"C\"]\n",
982
+ " ],\n",
983
+ " inputs=[pdb_input, segment_input],\n",
984
+ " outputs=[predictions_output, molecule_output, download_output]\n",
985
+ " )\n",
986
+ "\n",
987
+ "demo.launch()"
988
+ ]
989
+ },
990
+ {
991
+ "cell_type": "code",
992
+ "execution_count": null,
993
+ "id": "6f17feec-0347-4f9d-acd4-ae681c3ed425",
994
+ "metadata": {},
995
+ "outputs": [],
996
+ "source": []
997
+ },
998
+ {
999
+ "cell_type": "code",
1000
+ "execution_count": null,
1001
+ "id": "63201f38-adde-4b12-a8d3-f23474d045cf",
1002
+ "metadata": {},
1003
  "outputs": [],
1004
  "source": []
1005
  },
1006
  {
1007
  "cell_type": "code",
1008
  "execution_count": null,
1009
+ "id": "5ccbf398-5ef2-4955-98db-99f904f8daa4",
1010
+ "metadata": {},
1011
+ "outputs": [],
1012
+ "source": []
1013
+ },
1014
+ {
1015
+ "cell_type": "code",
1016
+ "execution_count": null,
1017
+ "id": "4c61bac4-4f2e-4f4a-aa1f-30dca209747c",
1018
  "metadata": {},
1019
  "outputs": [],
1020
  "source": [
 
1087
  " except KeyError:\n",
1088
  " return \"Invalid Chain ID\", None, None\n",
1089
  " \n",
1090
+ " \n",
1091
  " aa_dict = {\n",
1092
  " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
1093
  " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
 
1097
  " }\n",
1098
  " \n",
1099
  " # Exclude non-amino acid residues\n",
1100
+ " sequence = \"\".join(\n",
1101
+ " aa_dict[residue.get_resname().strip()] \n",
1102
+ " for residue in chain \n",
1103
  " if residue.get_resname().strip() in aa_dict\n",
1104
+ " )\n",
1105
+ " sequence2 = [\n",
1106
+ " (res.id[1], res) for res in chain\n",
1107
+ " if res.get_resname().strip() in aa_dict\n",
1108
  " ]\n",
1109
  " \n",
1110
  " # Prepare input for model prediction\n",
 
1116
  " scores = expit(outputs[:, 1] - outputs[:, 0])\n",
1117
  " normalized_scores = normalize_scores(scores)\n",
1118
  "\n",
1119
+ " # Zip residues with scores to track the residue ID and score\n",
1120
+ " residue_scores = [(resi, score) for (resi, _), score in zip(sequence2, normalized_scores)]\n",
1121
+ " \n",
1122
+ " result_str = \"\\n\".join([\n",
1123
+ " f\"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
1124
+ " for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict\n",
1125
+ " ])\n",
1126
  " \n",
1127
  " # Save the predictions to a file\n",
1128
  " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
1129
  " with open(prediction_file, \"w\") as f:\n",
1130
  " f.write(result_str)\n",
1131
  " \n",
1132
+ " return result_str, molecule(pdb_path, residue_scores, segment), prediction_file\n",
1133
  "\n",
1134
+ "def molecule(input_pdb, residue_scores=None, segment='A'):\n",
1135
  " mol = read_mol(input_pdb) # Read PDB file content\n",
1136
  " \n",
1137
  " # Prepare high-scoring residues script if scores are provided\n",
1138
  " high_score_script = \"\"\n",
1139
+ " if residue_scores is not None:\n",
1140
+ " # Sort residues based on their scores\n",
1141
+ " high_score_residues = [resi for resi, score in residue_scores if score > 0.75]\n",
1142
+ " mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]\n",
1143
+ " \n",
1144
  " high_score_script = \"\"\"\n",
1145
  " // Reset all styles first\n",
1146
  " viewer.getModel(0).setStyle({}, {});\n",
 
1158
  " {\"stick\": {\"color\": \"red\"}}\n",
1159
  " );\n",
1160
  "\n",
1161
+ " // Highlight medium-scoring residues only for the selected chain\n",
1162
+ " let midScoreResidues = [%s];\n",
1163
  " viewer.getModel(0).setStyle(\n",
1164
+ " {\"chain\": \"%s\", \"resi\": midScoreResidues}, \n",
1165
  " {\"stick\": {\"color\": \"orange\"}}\n",
1166
  " );\n",
1167
  " \"\"\" % (segment, \n",
1168
+ " \", \".join(str(resi) for resi in high_score_residues),\n",
1169
  " segment,\n",
1170
+ " \", \".join(str(resi) for resi in mid_score_residues),\n",
1171
  " segment)\n",
1172
  " \n",
1173
  " html_content = f\"\"\"\n",
 
1210
  " function(atom, viewer, event, container) {{\n",
1211
  " if (!atom.label) {{\n",
1212
  " atom.label = viewer.addLabel(\n",
1213
+ " atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
1214
  " {{\n",
1215
  " position: atom, \n",
1216
  " backgroundColor: 'mintcream', \n",
 
1277
  " gr.Markdown(\"## Examples\")\n",
1278
  " gr.Examples(\n",
1279
  " examples=[\n",
1280
+ " [\"7RPZ\", \"A\"],\n",
1281
+ " [\"2IWI\", \"B\"],\n",
1282
  " [\"3TJN\", \"C\"]\n",
1283
  " ],\n",
1284
  " inputs=[pdb_input, segment_input],\n",
1285
  " outputs=[predictions_output, molecule_output, download_output]\n",
1286
  " )\n",
1287
  "\n",
1288
+ "demo.launch(share=True)"
1289
  ]
1290
  },
1291
  {
1292
  "cell_type": "code",
1293
  "execution_count": null,
1294
+ "id": "b61d06ec-a4ee-4f65-925f-d2688730416a",
1295
  "metadata": {},
1296
  "outputs": [],
1297
  "source": []
 
1299
  {
1300
  "cell_type": "code",
1301
  "execution_count": null,
1302
+ "id": "4d67d69f-1f53-4bcc-8905-8d29384c4e20",
1303
  "metadata": {},
1304
  "outputs": [],
1305
  "source": [
1306
  "import gradio as gr\n",
1307
+ "import requests\n",
1308
+ "from Bio.PDB import PDBParser\n",
1309
+ "import numpy as np\n",
1310
+ "import os\n",
1311
+ "from gradio_molecule3d import Molecule3D\n",
1312
+ "\n",
1313
+ "\n",
1314
  "from model_loader import load_model\n",
1315
  "\n",
1316
  "import torch\n",
 
1319
  "from torch.utils.data import DataLoader\n",
1320
  "\n",
1321
  "import re\n",
 
 
1322
  "import pandas as pd\n",
1323
  "import copy\n",
1324
  "\n",
 
1330
  "\n",
1331
  "from scipy.special import expit\n",
1332
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
1333
  "# Load model and move to device\n",
1334
  "checkpoint = 'ThorbenF/prot_t5_xl_uniref50'\n",
1335
  "max_length = 1500\n",
 
1338
  "model.to(device)\n",
1339
  "model.eval()\n",
1340
  "\n",
1341
+ "def normalize_scores(scores):\n",
1342
+ " min_score = np.min(scores)\n",
1343
+ " max_score = np.max(scores)\n",
1344
+ " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
1345
+ " \n",
1346
+ "def read_mol(pdb_path):\n",
1347
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
1348
+ " with open(pdb_path, 'r') as f:\n",
1349
+ " return f.read()\n",
1350
+ "\n",
1351
  "def fetch_pdb(pdb_id):\n",
1352
  " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
1353
+ " pdb_path = f'{pdb_id}.pdb'\n",
 
1354
  " response = requests.get(pdb_url)\n",
1355
  " if response.status_code == 200:\n",
1356
  " with open(pdb_path, 'wb') as f:\n",
1357
  " f.write(response.content)\n",
1358
  " return pdb_path\n",
1359
+ " else:\n",
1360
+ " return None\n",
 
 
 
 
 
1361
  "\n",
1362
  "def process_pdb(pdb_id, segment):\n",
1363
  " pdb_path = fetch_pdb(pdb_id)\n",
 
1366
  " \n",
1367
  " parser = PDBParser(QUIET=1)\n",
1368
  " structure = parser.get_structure('protein', pdb_path)\n",
 
1369
  " \n",
1370
+ " try:\n",
1371
+ " chain = structure[0][segment]\n",
1372
+ " except KeyError:\n",
1373
+ " return \"Invalid Chain ID\", None, None\n",
1374
+ " \n",
1375
+ " \n",
1376
  " aa_dict = {\n",
1377
  " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
1378
  " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
 
1396
  " # Calculate scores and normalize them\n",
1397
  " scores = expit(outputs[:, 1] - outputs[:, 0])\n",
1398
  " normalized_scores = normalize_scores(scores)\n",
1399
+ "\n",
 
1400
  " result_str = \"\\n\".join([\n",
1401
  " f\"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
1402
  " for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict\n",
1403
  " ])\n",
1404
  " \n",
1405
+ " # Save the predictions to a file\n",
1406
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
1407
+ " with open(prediction_file, \"w\") as f:\n",
1408
  " f.write(result_str)\n",
1409
  " \n",
1410
+ " return result_str, molecule(pdb_path, normalized_scores, segment), prediction_file\n",
1411
  "\n",
1412
+ "def molecule(input_pdb, scores=None, segment='A'):\n",
1413
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
1414
+ " \n",
1415
+ " # Prepare high-scoring residues script if scores are provided\n",
1416
+ " high_score_script = \"\"\n",
1417
+ " if scores is not None:\n",
1418
+ " high_score_script = \"\"\"\n",
1419
+ " // Reset all styles first\n",
1420
+ " viewer.getModel(0).setStyle({}, {});\n",
1421
+ " \n",
1422
+ " // Show only the selected chain\n",
1423
+ " viewer.getModel(0).setStyle(\n",
1424
+ " {\"chain\": \"%s\"}, \n",
1425
+ " { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
1426
+ " );\n",
1427
+ " \n",
1428
+ " // Highlight high-scoring residues only for the selected chain\n",
1429
+ " let highScoreResidues = [%s];\n",
1430
+ " viewer.getModel(0).setStyle(\n",
1431
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n",
1432
+ " {\"stick\": {\"color\": \"red\"}}\n",
1433
+ " );\n",
1434
+ "\n",
1435
+ " // Highlight high-scoring residues only for the selected chain\n",
1436
+ " let highScoreResidues2 = [%s];\n",
1437
+ " viewer.getModel(0).setStyle(\n",
1438
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues2}, \n",
1439
+ " {\"stick\": {\"color\": \"orange\"}}\n",
1440
+ " );\n",
1441
+ " \"\"\" % (segment, \n",
1442
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n",
1443
+ " segment,\n",
1444
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),\n",
1445
+ " segment)\n",
1446
+ " \n",
1447
+ " html_content = f\"\"\"\n",
1448
+ " <!DOCTYPE html>\n",
1449
+ " <html>\n",
1450
+ " <head> \n",
1451
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
1452
+ " <style>\n",
1453
+ " .mol-container {{\n",
1454
+ " width: 100%;\n",
1455
+ " height: 700px;\n",
1456
+ " position: relative;\n",
1457
+ " }}\n",
1458
+ " </style>\n",
1459
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
1460
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
1461
+ " </head>\n",
1462
+ " <body>\n",
1463
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
1464
+ " <script>\n",
1465
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
1466
+ " $(document).ready(function () {{\n",
1467
+ " let element = $(\"#container\");\n",
1468
+ " let config = {{ backgroundColor: \"white\" }};\n",
1469
+ " let viewer = $3Dmol.createViewer(element, config);\n",
1470
+ " viewer.addModel(pdb, \"pdb\");\n",
1471
+ " \n",
1472
+ " // Reset all styles and show only selected chain\n",
1473
+ " viewer.getModel(0).setStyle(\n",
1474
+ " {{\"chain\": \"{segment}\"}}, \n",
1475
+ " {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }}\n",
1476
+ " );\n",
1477
+ " \n",
1478
+ " {high_score_script}\n",
1479
+ " \n",
1480
+ " // Add hover functionality\n",
1481
+ " viewer.setHoverable(\n",
1482
+ " {{}}, \n",
1483
+ " true, \n",
1484
+ " function(atom, viewer, event, container) {{\n",
1485
+ " if (!atom.label) {{\n",
1486
+ " atom.label = viewer.addLabel(\n",
1487
+ " atom.resn + \":\" + atom.atom, \n",
1488
+ " {{\n",
1489
+ " position: atom, \n",
1490
+ " backgroundColor: 'mintcream', \n",
1491
+ " fontColor: 'black',\n",
1492
+ " fontSize: 12,\n",
1493
+ " padding: 2\n",
1494
+ " }}\n",
1495
+ " );\n",
1496
+ " }}\n",
1497
+ " }},\n",
1498
+ " function(atom, viewer) {{\n",
1499
+ " if (atom.label) {{\n",
1500
+ " viewer.removeLabel(atom.label);\n",
1501
+ " delete atom.label;\n",
1502
+ " }}\n",
1503
+ " }}\n",
1504
+ " );\n",
1505
+ " \n",
1506
+ " viewer.zoomTo();\n",
1507
+ " viewer.render();\n",
1508
+ " viewer.zoom(0.8, 2000);\n",
1509
+ " }});\n",
1510
+ " </script>\n",
1511
+ " </body>\n",
1512
+ " </html>\n",
1513
+ " \"\"\"\n",
1514
+ " \n",
1515
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
1516
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
1517
+ "\n",
1518
+ "reps = [\n",
1519
+ " {\n",
1520
+ " \"model\": 0,\n",
1521
+ " \"style\": \"cartoon\",\n",
1522
+ " \"color\": \"whiteCarbon\",\n",
1523
+ " \"residue_range\": \"\",\n",
1524
+ " \"around\": 0,\n",
1525
+ " \"byres\": False,\n",
1526
+ " }\n",
1527
+ " ]\n",
1528
  "\n",
1529
  "# Gradio UI\n",
1530
  "with gr.Blocks() as demo:\n",
1531
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
1532
+ " with gr.Row():\n",
1533
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
1534
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
1535
+ "\n",
1536
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
1537
  "\n",
1538
  " with gr.Row():\n",
1539
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
1540
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
1541
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
1542
+ "\n",
1543
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
 
 
 
 
 
1544
  " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
1545
  " download_output = gr.File(label=\"Download Predictions\")\n",
1546
+ " \n",
1547
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
1548
+ " \n",
1549
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
1550
+ " \n",
 
 
 
1551
  " gr.Markdown(\"## Examples\")\n",
1552
  " gr.Examples(\n",
1553
  " examples=[\n",
1554
+ " [\"2IWI\", \"A\"],\n",
1555
+ " [\"7RPZ\", \"B\"],\n",
1556
+ " [\"3TJN\", \"C\"]\n",
1557
  " ],\n",
1558
+ " inputs=[pdb_input, segment_input],\n",
1559
  " outputs=[predictions_output, molecule_output, download_output]\n",
1560
  " )\n",
1561
  "\n",
1562
  "demo.launch(share=True)"
1563
  ]
 
 
 
 
 
 
 
 
1564
  }
1565
  ],
1566
  "metadata": {