ThorbenF commited on
Commit
b081fc7
·
1 Parent(s): d2ee732
.gradio/cached_examples/148/log.csv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ PyMOL Visualization Commands,Interactive 3D Structure,Download Results,component 3,timestamp
2
+ Failed to create chain-specific PDB: invalid literal for int() with base 10: 'THR',,,"{'visible': False, '__type__': 'update'}",2024-12-27 14:49:10.613413
3
+ Failed to create chain-specific PDB: invalid literal for int() with base 10: 'YR',,,"{'visible': False, '__type__': 'update'}",2024-12-27 14:49:10.680800
4
+ Failed to create chain-specific PDB: invalid literal for int() with base 10: 'LU',,,"{'visible': False, '__type__': 'update'}",2024-12-27 14:49:10.744581
.gradio/cached_examples/72/log.csv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ PyMOL Visualization Commands,Interactive 3D Structure,Download Results,component 3,timestamp
2
+ Failed to create chain-specific PDB: invalid literal for int() with base 10: 'THR',,,"{'visible': False, '__type__': 'update'}",2024-12-27 14:47:04.996998
3
+ Failed to create chain-specific PDB: invalid literal for int() with base 10: 'YR',,,"{'visible': False, '__type__': 'update'}",2024-12-27 14:47:05.557021
4
+ Failed to create chain-specific PDB: invalid literal for int() with base 10: 'LU',,,"{'visible': False, '__type__': 'update'}",2024-12-27 14:47:06.032725
.ipynb_checkpoints/4BDU-checkpoint.pdb DELETED
The diff for this file is too large to render. See raw diff
 
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -139,30 +139,6 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
139
 
140
  return output_pdb
141
 
142
- def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):
143
- """
144
- Calculate the geometric center of high-scoring residues
145
- """
146
- parser = PDBParser(QUIET=True)
147
- structure = parser.get_structure('protein', pdb_path)
148
-
149
- # Collect coordinates of CA atoms from high-scoring residues
150
- coords = []
151
- for model in structure:
152
- for chain in model:
153
- if chain.id == chain_id:
154
- for residue in chain:
155
- if residue.id[1] in high_score_residues:
156
- if 'CA' in residue: # Use alpha carbon as representative
157
- ca_atom = residue['CA']
158
- coords.append(ca_atom.coord)
159
-
160
- # Calculate geometric center
161
- if coords:
162
- center = np.mean(coords, axis=0)
163
- return center
164
- return None
165
-
166
  def process_pdb(pdb_id_or_file, segment):
167
  # Determine if input is a PDB ID or file path
168
  if pdb_id_or_file.endswith('.pdb'):
@@ -194,7 +170,11 @@ def process_pdb(pdb_id_or_file, segment):
194
  protein_residues = [res for res in chain if is_aa(res)]
195
  sequence = "".join(seq1(res.resname) for res in protein_residues)
196
  sequence_id = [res.id[1] for res in protein_residues]
197
-
 
 
 
 
198
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
199
  with torch.no_grad():
200
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
@@ -300,7 +280,6 @@ def molecule(input_pdb, residue_scores=None, segment='A'):
300
  class4_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.8]
301
  class5_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 1.0]
302
 
303
-
304
  high_score_script = """
305
  // Load the original model and apply white cartoon style
306
  let chainModel = viewer.addModel(pdb, "pdb");
@@ -430,7 +409,19 @@ def molecule(input_pdb, residue_scores=None, segment='A'):
430
  return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
431
 
432
  # Gradio UI
433
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
434
  gr.Markdown("# Protein Binding Site Prediction")
435
 
436
  # Mode selection
@@ -442,9 +433,9 @@ with gr.Blocks() as demo:
442
  )
443
 
444
  # Input components based on mode
445
- pdb_input = gr.Textbox(value="4BDU", label="PDB ID", placeholder="Enter PDB ID here...")
446
  pdb_file = gr.File(label="Upload PDB/CIF File", visible=False)
447
- visualize_btn = gr.Button("Visualize Structure")
448
 
449
  molecule_output2 = Molecule3D(label="Protein Structure", reps=[
450
  {
@@ -458,8 +449,9 @@ with gr.Blocks() as demo:
458
  ])
459
 
460
  with gr.Row():
461
- segment_input = gr.Textbox(value="A", label="Chain ID", placeholder="Enter Chain ID here...")
462
- prediction_btn = gr.Button("Predict Binding Site")
 
463
 
464
  molecule_output = gr.HTML(label="Protein Structure")
465
  explanation_vis = gr.Markdown("""
@@ -533,7 +525,7 @@ with gr.Blocks() as demo:
533
  examples=[
534
  ["7RPZ", "A"],
535
  ["2IWI", "B"],
536
- ["2F6V", "A"]
537
  ],
538
  inputs=[pdb_input, segment_input],
539
  outputs=[predictions_output, molecule_output, download_output]
 
139
 
140
  return output_pdb
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  def process_pdb(pdb_id_or_file, segment):
143
  # Determine if input is a PDB ID or file path
144
  if pdb_id_or_file.endswith('.pdb'):
 
170
  protein_residues = [res for res in chain if is_aa(res)]
171
  sequence = "".join(seq1(res.resname) for res in protein_residues)
172
  sequence_id = [res.id[1] for res in protein_residues]
173
+
174
+ visualized_sequence = "".join(seq1(res.resname) for res in protein_residues)
175
+ if sequence != visualized_sequence:
176
+ raise ValueError("The visualized sequence does not match the prediction sequence")
177
+
178
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
179
  with torch.no_grad():
180
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
 
280
  class4_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.8]
281
  class5_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 1.0]
282
 
 
283
  high_score_script = """
284
  // Load the original model and apply white cartoon style
285
  let chainModel = viewer.addModel(pdb, "pdb");
 
409
  return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
410
 
411
  # Gradio UI
412
+ with gr.Blocks(css="""
413
+ /* Customize Gradio button colors */
414
+ #visualize-btn, #predict-btn {
415
+ background-color: #FF7300; /* Deep orange */
416
+ color: white;
417
+ border-radius: 5px;
418
+ padding: 10px;
419
+ font-weight: bold;
420
+ }
421
+ #visualize-btn:hover, #predict-btn:hover {
422
+ background-color: #CC5C00; /* Darkened orange on hover */
423
+ }
424
+ """) as demo:
425
  gr.Markdown("# Protein Binding Site Prediction")
426
 
427
  # Mode selection
 
433
  )
434
 
435
  # Input components based on mode
436
+ pdb_input = gr.Textbox(value="2F6V", label="PDB ID", placeholder="Enter PDB ID here...")
437
  pdb_file = gr.File(label="Upload PDB/CIF File", visible=False)
438
+ visualize_btn = gr.Button("Visualize Structure", elem_id="visualize-btn")
439
 
440
  molecule_output2 = Molecule3D(label="Protein Structure", reps=[
441
  {
 
449
  ])
450
 
451
  with gr.Row():
452
+ segment_input = gr.Textbox(value="A", label="Chain ID (protein)", placeholder="Enter Chain ID here...",
453
+ info="Choose in which chain to predict binding sites.")
454
+ prediction_btn = gr.Button("Predict Binding Site", elem_id="predict-btn")
455
 
456
  molecule_output = gr.HTML(label="Protein Structure")
457
  explanation_vis = gr.Markdown("""
 
525
  examples=[
526
  ["7RPZ", "A"],
527
  ["2IWI", "B"],
528
+ ["7LCJ", "R"]
529
  ],
530
  inputs=[pdb_input, segment_input],
531
  outputs=[predictions_output, molecule_output, download_output]
.ipynb_checkpoints/test-checkpoint.ipynb CHANGED
@@ -2,23 +2,24 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 3,
6
- "id": "1f8ea359-674c-4263-9c2a-7a8e7e464249",
7
  "metadata": {},
8
  "outputs": [
9
  {
10
  "name": "stdout",
11
  "output_type": "stream",
12
  "text": [
13
- "* Running on local URL: http://127.0.0.1:7862\n",
 
14
  "\n",
15
- "To create a public link, set `share=True` in `launch()`.\n"
16
  ]
17
  },
18
  {
19
  "data": {
20
  "text/html": [
21
- "<div><iframe src=\"http://127.0.0.1:7862/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
22
  ],
23
  "text/plain": [
24
  "<IPython.core.display.HTML object>"
@@ -31,501 +32,344 @@
31
  "data": {
32
  "text/plain": []
33
  },
34
- "execution_count": 3,
35
  "metadata": {},
36
  "output_type": "execute_result"
37
  }
38
  ],
39
  "source": [
 
40
  "import gradio as gr\n",
41
  "import requests\n",
42
- "from Bio.PDB import PDBParser\n",
43
- "from gradio_molecule3d import Molecule3D\n",
 
 
44
  "import numpy as np\n",
45
- "\n",
46
- "# Function to fetch a PDB file from RCSB PDB\n",
47
- "def fetch_pdb(pdb_id):\n",
48
- " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
49
- " pdb_path = f'{pdb_id}.pdb'\n",
50
- " response = requests.get(pdb_url)\n",
51
- " if response.status_code == 200:\n",
52
- " with open(pdb_path, 'wb') as f:\n",
53
- " f.write(response.content)\n",
54
- " return pdb_path\n",
55
- " else:\n",
56
- " return None\n",
57
- "\n",
58
- "# Function to process the PDB file and return random predictions\n",
59
- "def process_pdb(pdb_id, segment):\n",
60
- " pdb_path = fetch_pdb(pdb_id)\n",
61
- " if not pdb_path:\n",
62
- " return \"Failed to fetch PDB file\", None, None\n",
63
- "\n",
64
- " parser = PDBParser(QUIET=True)\n",
65
- " structure = parser.get_structure('protein', pdb_path)\n",
66
- " \n",
67
- " try:\n",
68
- " chain = structure[0][segment]\n",
69
- " except KeyError:\n",
70
- " return \"Invalid Chain ID\", None, None\n",
71
- "\n",
72
- " sequence = [residue.get_resname() for residue in chain if residue.id[0] == ' ']\n",
73
- " random_scores = np.random.rand(len(sequence))\n",
74
- "\n",
75
- " result_str = \"\\n\".join(\n",
76
- " f\"{seq} {res.id[1]} {score:.2f}\" \n",
77
- " for seq, res, score in zip(sequence, chain, random_scores)\n",
78
- " )\n",
79
- "\n",
80
- " # Save the predictions to a file\n",
81
- " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
82
- " with open(prediction_file, \"w\") as f:\n",
83
- " f.write(result_str)\n",
84
- " \n",
85
- " return result_str, pdb_path, prediction_file\n",
86
- "\n",
87
- "#reps = [{\"model\": 0, \"style\": \"cartoon\", \"color\": \"spectrum\"}]\n",
88
- "\n",
89
- "reps = [\n",
90
- " {\n",
91
- " \"model\": 0,\n",
92
- " \"style\": \"cartoon\",\n",
93
- " \"color\": \"whiteCarbon\",\n",
94
- " \"residue_range\": \"\",\n",
95
- " \"around\": 0,\n",
96
- " \"byres\": False,\n",
97
- " },\n",
98
- " {\n",
99
- " \"model\": 0,\n",
100
- " \"chain\": \"A\",\n",
101
- " \"resname\": \"HIS\",\n",
102
- " \"style\": \"stick\",\n",
103
- " \"color\": \"red\"\n",
104
- " }\n",
105
- " ]\n",
106
- "\n",
107
- "\n",
108
- "# Gradio UI\n",
109
- "with gr.Blocks() as demo:\n",
110
- " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
111
- "\n",
112
- " with gr.Row():\n",
113
- " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
114
- " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
115
- " visualize_btn = gr.Button(\"Visualize Structure\")\n",
116
- " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
117
- "\n",
118
- " molecule_output = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
119
- " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
120
- " download_output = gr.File(label=\"Download Predictions\")\n",
121
- "\n",
122
- " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)\n",
123
- " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
124
- "\n",
125
- " gr.Markdown(\"## Examples\")\n",
126
- " gr.Examples(\n",
127
- " examples=[\n",
128
- " [\"2IWI\", \"A\"],\n",
129
- " [\"7RPZ\", \"B\"],\n",
130
- " [\"3TJN\", \"C\"]\n",
131
- " ],\n",
132
- " inputs=[pdb_input, segment_input],\n",
133
- " outputs=[predictions_output, molecule_output, download_output]\n",
134
- " )\n",
135
- "\n",
136
- "demo.launch()"
137
- ]
138
- },
139
- {
140
- "cell_type": "code",
141
- "execution_count": null,
142
- "id": "bd50ff2e-ed03-498e-8af2-73c0fb8ea07e",
143
- "metadata": {},
144
- "outputs": [],
145
- "source": []
146
- },
147
- {
148
- "cell_type": "raw",
149
- "id": "88affe12-7c48-4bd6-9e46-32cdffa729fe",
150
- "metadata": {},
151
- "source": [
152
- "import gradio as gr\n",
153
  "from gradio_molecule3d import Molecule3D\n",
154
  "\n",
 
155
  "\n",
156
- "example = Molecule3D().example_value()\n",
157
- "\n",
 
 
158
  "\n",
159
- "reps = [\n",
160
- " {\n",
161
- " \"model\": 0,\n",
162
- " \"style\": \"cartoon\",\n",
163
- " \"color\": \"whiteCarbon\",\n",
164
- " \"residue_range\": \"\",\n",
165
- " \"around\": 0,\n",
166
- " \"byres\": False,\n",
167
- " },\n",
168
- " {\n",
169
- " \"model\": 0,\n",
170
- " \"chain\": \"A\",\n",
171
- " \"resname\": \"HIS\",\n",
172
- " \"style\": \"stick\",\n",
173
- " \"color\": \"red\"\n",
174
- " }\n",
175
- " ]\n",
176
  "\n",
 
 
177
  "\n",
 
178
  "\n",
179
- "def predict(x):\n",
180
- " print(\"predict function\", x)\n",
181
- " print(x.name)\n",
182
- " return x\n",
183
  "\n",
184
- "with gr.Blocks() as demo:\n",
185
- " gr.Markdown(\"# Molecule3D\")\n",
186
- " inp = Molecule3D(label=\"Molecule3D\", reps=reps)\n",
187
- " out = Molecule3D(label=\"Output\", reps=reps)\n",
188
- "\n",
189
- " btn = gr.Button(\"Predict\")\n",
190
- " gr.Markdown(\"\"\" \n",
191
- " You can configure the default rendering of the molecule by adding a list of representations\n",
192
- " <pre>\n",
193
- " reps = [\n",
194
- " {\n",
195
- " \"model\": 0,\n",
196
- " \"style\": \"cartoon\",\n",
197
- " \"color\": \"whiteCarbon\",\n",
198
- " \"residue_range\": \"\",\n",
199
- " \"around\": 0,\n",
200
- " \"byres\": False,\n",
201
- " },\n",
202
- " {\n",
203
- " \"model\": 0,\n",
204
- " \"chain\": \"A\",\n",
205
- " \"resname\": \"HIS\",\n",
206
- " \"style\": \"stick\",\n",
207
- " \"color\": \"red\"\n",
208
- " }\n",
209
- " ]\n",
210
- " </pre>\n",
211
- " \"\"\")\n",
212
- " btn.click(predict, inputs=inp, outputs=out)\n",
213
- "\n",
214
- "\n",
215
- "if __name__ == \"__main__\":\n",
216
- " demo.launch()"
217
- ]
218
- },
219
- {
220
- "cell_type": "code",
221
- "execution_count": null,
222
- "id": "d27cc368-26a0-42c2-a68a-8833de7bb4a0",
223
- "metadata": {},
224
- "outputs": [],
225
- "source": []
226
- },
227
- {
228
- "cell_type": "raw",
229
- "id": "2b970adb-3152-427f-bb58-b92974ff406e",
230
- "metadata": {},
231
- "source": [
232
- "import gradio as gr\n",
233
- "import os\n",
234
- "import requests\n",
235
- "from Bio.PDB import PDBParser, PDBIO\n",
236
- "import biotite.structure.io as bsio\n",
237
  "\n",
238
  "def read_mol(pdb_path):\n",
239
  " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
240
  " with open(pdb_path, 'r') as f:\n",
241
  " return f.read()\n",
242
  "\n",
243
- "# Function to fetch or upload the PDB file\n",
244
- "def get_pdb(pdb_code=\"\", filepath=\"\"):\n",
245
- " if pdb_code and len(pdb_code) == 4:\n",
246
- " pdb_file = f\"{pdb_code}.pdb\"\n",
247
- " if not os.path.exists(pdb_file):\n",
248
- " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n",
249
- " return pdb_file\n",
250
- " elif filepath is not None:\n",
251
- " return filepath\n",
252
  " else:\n",
253
  " return None\n",
254
  "\n",
255
- "def molecule(input_pdb):\n",
256
- " mol = read_mol(input_pdb) # Read PDB file content\n",
257
- " \n",
258
- " html_content = f\"\"\"\n",
259
- " <!DOCTYPE html>\n",
260
- " <html>\n",
261
- " <head> \n",
262
- " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
263
- " <style>\n",
264
- " .mol-container {{\n",
265
- " width: 100%;\n",
266
- " height: 700px;\n",
267
- " position: relative;\n",
268
- " }}\n",
269
- " </style>\n",
270
- " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
271
- " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
272
- " </head>\n",
273
- " <body>\n",
274
- " <div id=\"container\" class=\"mol-container\"></div>\n",
275
- " <script>\n",
276
- " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
277
- " $(document).ready(function () {{\n",
278
- " let element = $(\"#container\");\n",
279
- " let config = {{ backgroundColor: \"white\" }};\n",
280
- " let viewer = $3Dmol.createViewer(element, config);\n",
281
- " viewer.addModel(pdb, \"pdb\");\n",
282
- " viewer.getModel(0).setStyle({{}}, {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }});\n",
283
- " viewer.zoomTo();\n",
284
- " viewer.render();\n",
285
- " viewer.zoom(0.8, 2000);\n",
286
- " }});\n",
287
- " </script>\n",
288
- " </body>\n",
289
- " </html>\n",
290
  " \"\"\"\n",
291
- " \n",
292
- " # Return the HTML content within an iframe safely encoded for special characters\n",
293
- " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  "\n",
295
- "# Gradio function to update the visualization\n",
296
- "def update(inp, file):\n",
297
- " pdb_path = get_pdb(inp, file)\n",
298
- " if pdb_path:\n",
299
- " return molecule(pdb_path)\n",
300
- " else:\n",
301
- " return \"Invalid input. Please provide a valid PDB code or upload a PDB file.\"\n",
 
 
 
 
302
  "\n",
303
- "# Gradio UI\n",
304
- "demo = gr.Blocks()\n",
305
- "with demo:\n",
306
- " gr.Markdown(\"# PDB Viewer using 3Dmol.js\")\n",
307
- " with gr.Row():\n",
308
- " with gr.Column():\n",
309
- " inp = gr.Textbox(\n",
310
- " placeholder=\"PDB Code or upload file below\", label=\"Input structure\"\n",
311
- " )\n",
312
- " file = gr.File(file_count=\"single\")\n",
313
- " btn = gr.Button(\"View structure\")\n",
314
- " mol = gr.HTML()\n",
315
- " btn.click(fn=update, inputs=[inp, file], outputs=mol)\n",
316
  "\n",
317
- "# Launch the Gradio interface \n",
318
- "demo.launch(debug=True)"
319
- ]
320
- },
321
- {
322
- "cell_type": "code",
323
- "execution_count": null,
324
- "id": "ee215c16-a1fb-450f-bb93-37aaee6fb3f1",
325
- "metadata": {},
326
- "outputs": [],
327
- "source": []
328
- },
329
- {
330
- "cell_type": "raw",
331
- "id": "050aa2e8-2dbe-4a28-8692-58ca7c50fccd",
332
- "metadata": {},
333
- "source": [
334
- "import gradio as gr\n",
335
- "import os\n",
336
- "import requests\n",
337
- "import numpy as np\n",
338
- "from Bio.PDB import PDBParser\n",
339
  "\n",
340
- "def read_mol(pdb_path):\n",
341
- " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
342
- " with open(pdb_path, 'r') as f:\n",
343
- " return f.read()\n",
 
 
 
 
 
 
 
 
344
  "\n",
345
- "# Function to fetch a PDB file from RCSB PDB\n",
346
- "def fetch_pdb(pdb_id):\n",
347
- " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
348
- " pdb_path = f'{pdb_id}.pdb'\n",
349
- " response = requests.get(pdb_url)\n",
350
- " if response.status_code == 200:\n",
351
- " with open(pdb_path, 'wb') as f:\n",
352
- " f.write(response.content)\n",
353
- " return molecule(pdb_path)\n",
354
- " else:\n",
355
- " return None\n",
 
 
356
  "\n",
357
- "# Function to process the PDB file and return random predictions\n",
358
- "def process_pdb(pdb_id, segment):\n",
359
- " pdb_path = fetch_pdb(pdb_id)\n",
 
 
 
 
 
 
360
  " if not pdb_path:\n",
361
  " return \"Failed to fetch PDB file\", None, None\n",
362
  " \n",
363
- " parser = PDBParser(QUIET=True)\n",
364
- " structure = parser.get_structure('protein', pdb_path)\n",
 
365
  " \n",
366
  " try:\n",
 
 
 
 
 
 
 
367
  " chain = structure[0][segment]\n",
368
  " except KeyError:\n",
369
  " return \"Invalid Chain ID\", None, None\n",
370
  " \n",
371
- " sequence = [residue.get_resname() for residue in chain if residue.id[0] == ' ']\n",
372
- " random_scores = np.random.rand(len(sequence))\n",
373
- " result_str = \"\\n\".join(\n",
374
- " f\"{seq} {res.id[1]} {score:.2f}\" \n",
375
- " for seq, res, score in zip(sequence, chain, random_scores)\n",
376
- " )\n",
377
- " \n",
378
- " # Save the predictions to a file\n",
379
- " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
380
- " with open(prediction_file, \"w\") as f:\n",
381
- " f.write(result_str)\n",
382
  " \n",
383
- " return result_str, molecule(pdb_path), prediction_file\n",
 
384
  "\n",
385
- "def molecule(input_pdb):\n",
386
- " mol = read_mol(input_pdb) # Read PDB file content\n",
387
  " \n",
388
- " html_content = f\"\"\"\n",
389
- " <!DOCTYPE html>\n",
390
- " <html>\n",
391
- " <head> \n",
392
- " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
393
- " <style>\n",
394
- " .mol-container {{\n",
395
- " width: 100%;\n",
396
- " height: 700px;\n",
397
- " position: relative;\n",
398
- " }}\n",
399
- " </style>\n",
400
- " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
401
- " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
402
- " </head>\n",
403
- " <body>\n",
404
- " <div id=\"container\" class=\"mol-container\"></div>\n",
405
- " <script>\n",
406
- " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
407
- " $(document).ready(function () {{\n",
408
- " let element = $(\"#container\");\n",
409
- " let config = {{ backgroundColor: \"white\" }};\n",
410
- " let viewer = $3Dmol.createViewer(element, config);\n",
411
- " viewer.addModel(pdb, \"pdb\");\n",
412
- " \n",
413
- " // Set cartoon representation with white carbon color scheme\n",
414
- " viewer.getModel(0).setStyle({{}}, {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }});\n",
415
- " \n",
416
- " // Highlight specific histidine residues in red stick representation\n",
417
- " viewer.getModel(0).setStyle(\n",
418
- " {{\"resn\": \"HIS\"}}, \n",
419
- " {{\"stick\": {{\"color\": \"red\"}}}}\n",
420
- " );\n",
421
- " \n",
422
- " viewer.zoomTo();\n",
423
- " viewer.render();\n",
424
- " viewer.zoom(0.8, 2000);\n",
425
- " }});\n",
426
- " </script>\n",
427
- " </body>\n",
428
- " </html>\n",
429
- " \"\"\"\n",
430
  " \n",
431
- " # Return the HTML content within an iframe safely encoded for special characters\n",
432
- " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
433
- "\n",
434
- "# Gradio UI\n",
435
- "with gr.Blocks() as demo:\n",
436
- " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
437
- " with gr.Row():\n",
438
- " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
439
- " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
440
- " visualize_btn = gr.Button(\"Visualize Structure\")\n",
441
- " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
442
  " \n",
443
- " # Use HTML output instead of Molecule3D\n",
444
- " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
445
- " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
446
- " download_output = gr.File(label=\"Download Predictions\")\n",
 
 
447
  " \n",
448
- " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)\n",
449
- " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
 
 
450
  " \n",
451
- " gr.Markdown(\"## Examples\")\n",
452
- " gr.Examples(\n",
453
- " examples=[\n",
454
- " [\"2IWI\", \"A\"],\n",
455
- " [\"7RPZ\", \"B\"],\n",
456
- " [\"3TJN\", \"C\"]\n",
457
- " ],\n",
458
- " inputs=[pdb_input, segment_input],\n",
459
- " outputs=[predictions_output, molecule_output, download_output]\n",
460
- " )\n",
461
  "\n",
462
- "demo.launch(debug=True)"
463
- ]
464
- },
465
- {
466
- "cell_type": "code",
467
- "execution_count": null,
468
- "id": "9a5facd9-855c-4b35-8dd3-2c0c8c7dd356",
469
- "metadata": {},
470
- "outputs": [],
471
- "source": []
472
- },
473
- {
474
- "cell_type": "raw",
475
- "id": "a762170f-92a9-473d-b18d-53607a780e3b",
476
- "metadata": {},
477
- "source": [
478
- "import gradio as gr\n",
479
- "import requests\n",
480
- "from Bio.PDB import PDBParser\n",
481
- "import numpy as np\n",
482
- "import os\n",
483
  "\n",
484
- "def read_mol(pdb_path):\n",
485
- " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
486
- " with open(pdb_path, 'r') as f:\n",
487
- " return f.read()\n",
488
  "\n",
489
- "# Function to fetch a PDB file from RCSB PDB\n",
490
- "def fetch_pdb(pdb_id):\n",
491
- " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
492
- " pdb_path = f'{pdb_id}.pdb'\n",
493
- " response = requests.get(pdb_url)\n",
494
- " if response.status_code == 200:\n",
495
- " with open(pdb_path, 'wb') as f:\n",
496
- " f.write(response.content)\n",
497
- " return pdb_path\n",
498
- " else:\n",
499
- " return None\n",
500
- "\n",
501
- "# Function to process the PDB file and return random predictions\n",
502
- "def process_pdb(pdb_id, segment):\n",
503
- " pdb_path = fetch_pdb(pdb_id)\n",
504
- " if not pdb_path:\n",
505
- " return \"Failed to fetch PDB file\", None, None\n",
506
- " parser = PDBParser(QUIET=True)\n",
507
- " structure = parser.get_structure('protein', pdb_path)\n",
508
  " \n",
509
- " try:\n",
510
- " chain = structure[0][segment]\n",
511
- " except KeyError:\n",
512
- " return \"Invalid Chain ID\", None, None\n",
513
- " sequence = [residue.get_resname() for residue in chain if residue.id[0] == ' ']\n",
514
- " random_scores = np.random.rand(len(sequence))\n",
515
- " result_str = \"\\n\".join(\n",
516
- " f\"{seq} {res.id[1]} {score:.2f}\" \n",
517
- " for seq, res, score in zip(sequence, chain, random_scores)\n",
518
- " )\n",
519
- " # Save the predictions to a file\n",
520
- " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
  " with open(prediction_file, \"w\") as f:\n",
522
  " f.write(result_str)\n",
523
  " \n",
524
- " return result_str, molecule(pdb_path), prediction_file\n",
525
  "\n",
526
- "def molecule(input_pdb):\n",
 
527
  " mol = read_mol(input_pdb) # Read PDB file content\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
  " \n",
 
529
  " html_content = f\"\"\"\n",
530
  " <!DOCTYPE html>\n",
531
  " <html>\n",
@@ -549,15 +393,33 @@
549
  " let element = $(\"#container\");\n",
550
  " let config = {{ backgroundColor: \"white\" }};\n",
551
  " let viewer = $3Dmol.createViewer(element, config);\n",
552
- " viewer.addModel(pdb, \"pdb\");\n",
553
  " \n",
554
- " // Set cartoon representation with white carbon color scheme\n",
555
- " viewer.getModel(0).setStyle({{}}, {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }});\n",
556
  " \n",
557
- " // Highlight specific histidine residues in red stick representation\n",
558
- " viewer.getModel(0).setStyle(\n",
559
- " {{\"resn\": \"HIS\"}}, \n",
560
- " {{\"stick\": {{\"color\": \"red\"}}}}\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
  " );\n",
562
  " \n",
563
  " viewer.zoomTo();\n",
@@ -573,68 +435,194 @@
573
  " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
574
  "\n",
575
  "# Gradio UI\n",
576
- "with gr.Blocks() as demo:\n",
577
- " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
578
- " with gr.Row():\n",
579
- " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
580
- " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
581
- " visualize_btn = gr.Button(\"Visualize Structure\")\n",
582
- " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
 
 
 
 
 
 
 
583
  " \n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
584
  " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
585
- " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
586
- " download_output = gr.File(label=\"Download Predictions\")\n",
 
 
 
 
 
 
 
 
 
587
  " \n",
588
- " # Update to explicitly use molecule() function for visualization\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
  " visualize_btn.click(\n",
590
- " fn=lambda pdb_id: molecule(fetch_pdb(pdb_id)), \n",
591
- " inputs=[pdb_input], \n",
592
- " outputs=molecule_output\n",
593
  " )\n",
594
- " \n",
595
- " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
596
- " \n",
597
  " gr.Markdown(\"## Examples\")\n",
598
  " gr.Examples(\n",
599
  " examples=[\n",
600
- " [\"2IWI\", \"A\"],\n",
601
- " [\"7RPZ\", \"B\"],\n",
602
- " [\"3TJN\", \"C\"]\n",
603
  " ],\n",
604
  " inputs=[pdb_input, segment_input],\n",
605
  " outputs=[predictions_output, molecule_output, download_output]\n",
606
  " )\n",
607
  "\n",
608
- "demo.launch()"
609
  ]
610
  },
611
  {
612
  "cell_type": "code",
613
  "execution_count": null,
614
- "id": "15527a58-c449-4da0-8fab-3baaede15e41",
615
  "metadata": {},
616
  "outputs": [],
617
  "source": []
618
  },
619
  {
620
  "cell_type": "code",
621
- "execution_count": 2,
622
- "id": "9ef3e330-cb88-4c29-b84a-2f8652883cfc",
623
  "metadata": {},
624
  "outputs": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
  {
626
  "name": "stdout",
627
  "output_type": "stream",
628
  "text": [
629
- "* Running on local URL: http://127.0.0.1:7860\n",
630
  "\n",
631
- "To create a public link, set `share=True` in `launch()`.\n"
 
 
 
 
 
 
 
632
  ]
633
  },
634
  {
635
  "data": {
636
  "text/html": [
637
- "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
638
  ],
639
  "text/plain": [
640
  "<IPython.core.display.HTML object>"
@@ -644,79 +632,510 @@
644
  "output_type": "display_data"
645
  },
646
  {
647
- "data": {
648
- "text/plain": []
649
- },
650
- "execution_count": 2,
651
- "metadata": {},
652
- "output_type": "execute_result"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
  }
654
  ],
655
  "source": [
 
656
  "import gradio as gr\n",
657
  "import requests\n",
658
- "from Bio.PDB import PDBParser\n",
 
 
 
659
  "import numpy as np\n",
660
  "import os\n",
661
  "from gradio_molecule3d import Molecule3D\n",
 
 
 
 
 
 
 
 
 
 
662
  "\n",
663
- "def read_mol(pdb_path):\n",
664
- " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
665
- " with open(pdb_path, 'r') as f:\n",
666
- " return f.read()\n",
667
  "\n",
668
- "def fetch_pdb(pdb_id):\n",
669
- " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
670
- " pdb_path = f'{pdb_id}.pdb'\n",
671
- " response = requests.get(pdb_url)\n",
672
- " if response.status_code == 200:\n",
673
- " with open(pdb_path, 'wb') as f:\n",
674
- " f.write(response.content)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675
  " return pdb_path\n",
676
- " else:\n",
677
- " return None\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
678
  "\n",
679
- "def process_pdb(pdb_id, segment):\n",
680
- " pdb_path = fetch_pdb(pdb_id)\n",
681
- " if not pdb_path:\n",
682
- " return \"Failed to fetch PDB file\", None, None\n",
683
- " parser = PDBParser(QUIET=True)\n",
684
- " structure = parser.get_structure('protein', pdb_path)\n",
685
- " \n",
686
  " try:\n",
687
- " chain = structure[0][segment]\n",
688
- " except KeyError:\n",
689
- " return \"Invalid Chain ID\", None, None\n",
690
- " sequence = [residue.get_resname() for residue in chain if residue.id[0] == ' ']\n",
691
- " random_scores = np.random.rand(len(sequence))\n",
692
- " result_str = \"\\n\".join(\n",
693
- " f\"{seq} {res.id[1]} {score:.2f}\" \n",
694
- " for seq, res, score in zip(sequence, chain, random_scores)\n",
695
- " )\n",
696
- " # Save the predictions to a file\n",
697
- " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
698
- " with open(prediction_file, \"w\") as f:\n",
699
- " f.write(result_str)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
700
  " \n",
701
- " return result_str, molecule(pdb_path, random_scores), prediction_file\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
702
  "\n",
703
- "def molecule(input_pdb, scores=None):\n",
704
- " mol = read_mol(input_pdb) # Read PDB file content\n",
 
705
  " \n",
706
- " # Prepare high-scoring residues script if scores are provided\n",
707
- " high_score_script = \"\"\n",
708
- " if scores is not None:\n",
709
- " high_score_script = \"\"\"\n",
710
- " // Highlight residues with high scores\n",
711
- " let highScoreResidues = [{}];\n",
712
- " viewer.getModel(0).setStyle(\n",
713
- " {{\"resi\": highScoreResidues}}, \n",
714
- " {{\"stick\": {{\"color\": \"red\"}}}}\n",
715
- " );\n",
716
- " \"\"\".format(\n",
717
- " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8)\n",
718
- " )\n",
 
 
 
 
 
 
719
  " \n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
720
  " html_content = f\"\"\"\n",
721
  " <!DOCTYPE html>\n",
722
  " <html>\n",
@@ -728,95 +1147,343 @@
728
  " height: 700px;\n",
729
  " position: relative;\n",
730
  " }}\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731
  " </style>\n",
732
  " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
733
  " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
734
  " </head>\n",
735
  " <body>\n",
736
- " <div id=\"container\" class=\"mol-container\"></div>\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
737
  " <script>\n",
738
- " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
 
 
 
739
  " $(document).ready(function () {{\n",
740
  " let element = $(\"#container\");\n",
741
  " let config = {{ backgroundColor: \"white\" }};\n",
742
- " let viewer = $3Dmol.createViewer(element, config);\n",
743
- " viewer.addModel(pdb, \"pdb\");\n",
744
  " \n",
745
- " // Set cartoon representation with white carbon color scheme\n",
746
- " viewer.getModel(0).setStyle({{}}, {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }});\n",
747
  " \n",
748
- " {high_score_script}\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
749
  " \n",
750
  " viewer.zoomTo();\n",
751
  " viewer.render();\n",
752
  " viewer.zoom(0.8, 2000);\n",
753
  " }});\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
  " </script>\n",
755
  " </body>\n",
756
  " </html>\n",
757
  " \"\"\"\n",
758
  " \n",
759
- " # Return the HTML content within an iframe safely encoded for special characters\n",
760
  " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
761
  "\n",
762
- "reps = [\n",
763
- " {\n",
764
- " \"model\": 0,\n",
765
- " \"style\": \"cartoon\",\n",
766
- " \"color\": \"whiteCarbon\",\n",
767
- " \"residue_range\": \"\",\n",
768
- " \"around\": 0,\n",
769
- " \"byres\": False,\n",
770
- " }\n",
771
- " ]\n",
772
  "# Gradio UI\n",
773
- "with gr.Blocks() as demo:\n",
774
- " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
775
- " with gr.Row():\n",
776
- " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
777
- " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
778
- " visualize_btn = gr.Button(\"Visualize Structure\")\n",
779
- " #prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780
  "\n",
781
- " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
 
 
 
 
 
 
 
 
 
 
 
782
  "\n",
783
- " with gr.Row():\n",
784
- " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
785
- " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
786
- " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
 
 
 
 
 
 
 
 
787
  "\n",
788
- " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
789
- " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
790
- " download_output = gr.File(label=\"Download Predictions\")\n",
791
- " \n",
792
- " #visualize_btn.click(\n",
793
- " # fn=lambda pdb_id: molecule(fetch_pdb(pdb_id)), \n",
794
- " # inputs=[pdb_input], \n",
795
- " # outputs=molecule_output\n",
796
- " #)\n",
797
- " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
798
- " \n",
799
- " \n",
800
- " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
801
- " \n",
802
- " gr.Markdown(\"## Examples\")\n",
803
- " gr.Examples(\n",
804
- " examples=[\n",
805
- " [\"2IWI\", \"A\"],\n",
806
- " [\"7RPZ\", \"B\"],\n",
807
- " [\"3TJN\", \"C\"]\n",
808
- " ],\n",
809
- " inputs=[pdb_input, segment_input],\n",
810
- " outputs=[predictions_output, molecule_output, download_output]\n",
811
- " )\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
812
  "\n",
813
- "demo.launch()"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
814
  ]
815
  },
816
  {
817
  "cell_type": "code",
818
  "execution_count": null,
819
- "id": "14605615-8610-4d9e-841b-db7618cde844",
 
 
 
 
 
 
 
 
820
  "metadata": {},
821
  "outputs": [],
822
  "source": []
@@ -838,7 +1505,7 @@
838
  "name": "python",
839
  "nbconvert_exporter": "python",
840
  "pygments_lexer": "ipython3",
841
- "version": "3.12.7"
842
  }
843
  },
844
  "nbformat": 4,
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 29,
6
+ "id": "e776d9d6-417e-46d4-8061-846c055e1f8a",
7
  "metadata": {},
8
  "outputs": [
9
  {
10
  "name": "stdout",
11
  "output_type": "stream",
12
  "text": [
13
+ "* Running on local URL: http://127.0.0.1:7873\n",
14
+ "* Running on public URL: https://120000a6aa9d78e04c.gradio.live\n",
15
  "\n",
16
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
17
  ]
18
  },
19
  {
20
  "data": {
21
  "text/html": [
22
+ "<div><iframe src=\"https://120000a6aa9d78e04c.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
23
  ],
24
  "text/plain": [
25
  "<IPython.core.display.HTML object>"
 
32
  "data": {
33
  "text/plain": []
34
  },
35
+ "execution_count": 29,
36
  "metadata": {},
37
  "output_type": "execute_result"
38
  }
39
  ],
40
  "source": [
41
+ "from datetime import datetime\n",
42
  "import gradio as gr\n",
43
  "import requests\n",
44
+ "from Bio.PDB import PDBParser, MMCIFParser, PDBIO, Select\n",
45
+ "from Bio.PDB.Polypeptide import is_aa\n",
46
+ "from Bio.SeqUtils import seq1\n",
47
+ "from typing import Optional, Tuple\n",
48
  "import numpy as np\n",
49
+ "import os\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  "from gradio_molecule3d import Molecule3D\n",
51
  "\n",
52
+ "#from model_loader import load_model\n",
53
  "\n",
54
+ "import torch\n",
55
+ "import torch.nn as nn\n",
56
+ "import torch.nn.functional as F\n",
57
+ "from torch.utils.data import DataLoader\n",
58
  "\n",
59
+ "import re\n",
60
+ "import pandas as pd\n",
61
+ "import copy\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  "\n",
63
+ "#import transformers\n",
64
+ "#from transformers import AutoTokenizer, DataCollatorForTokenClassification\n",
65
  "\n",
66
+ "#from datasets import Dataset\n",
67
  "\n",
68
+ "from scipy.special import expit\n",
 
 
 
69
  "\n",
70
+ "def normalize_scores(scores):\n",
71
+ " min_score = np.min(scores)\n",
72
+ " max_score = np.max(scores)\n",
73
+ " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  "\n",
75
  "def read_mol(pdb_path):\n",
76
  " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
77
  " with open(pdb_path, 'r') as f:\n",
78
  " return f.read()\n",
79
  "\n",
80
+ "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n",
81
+ " \"\"\"\n",
82
+ " Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n",
83
+ " If a structure file already exists locally, it uses that.\n",
84
+ " \"\"\"\n",
85
+ " file_path = download_structure(pdb_id, output_dir)\n",
86
+ " if file_path:\n",
87
+ " return file_path\n",
 
88
  " else:\n",
89
  " return None\n",
90
  "\n",
91
+ "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  " \"\"\"\n",
93
+ " Attempt to download the structure file in CIF or PDB format.\n",
94
+ " Returns the path to the downloaded file, or None if download fails.\n",
95
+ " \"\"\"\n",
96
+ " for ext in ['.cif', '.pdb']:\n",
97
+ " file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n",
98
+ " if os.path.exists(file_path):\n",
99
+ " return file_path\n",
100
+ " url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n",
101
+ " try:\n",
102
+ " response = requests.get(url, timeout=10)\n",
103
+ " if response.status_code == 200:\n",
104
+ " with open(file_path, 'wb') as f:\n",
105
+ " f.write(response.content)\n",
106
+ " return file_path\n",
107
+ " except Exception as e:\n",
108
+ " print(f\"Download error for {pdb_id}{ext}: {e}\")\n",
109
+ " return None\n",
110
  "\n",
111
+ "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n",
112
+ " \"\"\"\n",
113
+ " Convert a CIF file to PDB format using BioPython and return the PDB file path.\n",
114
+ " \"\"\"\n",
115
+ " pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n",
116
+ " parser = MMCIFParser(QUIET=True)\n",
117
+ " structure = parser.get_structure('protein', cif_path)\n",
118
+ " io = PDBIO()\n",
119
+ " io.set_structure(structure)\n",
120
+ " io.save(pdb_path)\n",
121
+ " return pdb_path\n",
122
  "\n",
123
+ "def fetch_pdb(pdb_id):\n",
124
+ " pdb_path = fetch_structure(pdb_id)\n",
125
+ " if not pdb_path:\n",
126
+ " return None\n",
127
+ " _, ext = os.path.splitext(pdb_path)\n",
128
+ " if ext == '.cif':\n",
129
+ " pdb_path = convert_cif_to_pdb(pdb_path)\n",
130
+ " return pdb_path\n",
 
 
 
 
 
131
  "\n",
132
+ "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n",
133
+ " \"\"\"\n",
134
+ " Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores\n",
135
+ " \"\"\"\n",
136
+ " # Read the original PDB file\n",
137
+ " parser = PDBParser(QUIET=True)\n",
138
+ " structure = parser.get_structure('protein', input_pdb)\n",
139
+ " \n",
140
+ " # Prepare a new structure with only the specified chain and selected residues\n",
141
+ " output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb\"\n",
142
+ " \n",
143
+ " # Create scores dictionary for easy lookup\n",
144
+ " scores_dict = {resi: score for resi, score in residue_scores}\n",
 
 
 
 
 
 
 
 
 
145
  "\n",
146
+ " # Create a custom Select class\n",
147
+ " class ResidueSelector(Select):\n",
148
+ " def __init__(self, chain_id, selected_residues, scores_dict):\n",
149
+ " self.chain_id = chain_id\n",
150
+ " self.selected_residues = selected_residues\n",
151
+ " self.scores_dict = scores_dict\n",
152
+ " \n",
153
+ " def accept_chain(self, chain):\n",
154
+ " return chain.id == self.chain_id\n",
155
+ " \n",
156
+ " def accept_residue(self, residue):\n",
157
+ " return residue.id[1] in self.selected_residues\n",
158
  "\n",
159
+ " def accept_atom(self, atom):\n",
160
+ " if atom.parent.id[1] in self.scores_dict:\n",
161
+ " atom.bfactor = np.absolute(1-self.scores_dict[atom.parent.id[1]]) * 100\n",
162
+ " return True\n",
163
+ "\n",
164
+ " # Prepare output PDB with selected chain and residues, modified B-factors\n",
165
+ " io = PDBIO()\n",
166
+ " selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n",
167
+ " \n",
168
+ " io.set_structure(structure[0])\n",
169
+ " io.save(output_pdb, selector)\n",
170
+ " \n",
171
+ " return output_pdb\n",
172
  "\n",
173
+ "def process_pdb(pdb_id_or_file, segment):\n",
174
+ " # Determine if input is a PDB ID or file path\n",
175
+ " if pdb_id_or_file.endswith('.pdb'):\n",
176
+ " pdb_path = pdb_id_or_file\n",
177
+ " pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n",
178
+ " else:\n",
179
+ " pdb_id = pdb_id_or_file\n",
180
+ " pdb_path = fetch_pdb(pdb_id)\n",
181
+ " \n",
182
  " if not pdb_path:\n",
183
  " return \"Failed to fetch PDB file\", None, None\n",
184
  " \n",
185
+ " # Determine the file format and choose the appropriate parser\n",
186
+ " _, ext = os.path.splitext(pdb_path)\n",
187
+ " parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n",
188
  " \n",
189
  " try:\n",
190
+ " # Parse the structure file\n",
191
+ " structure = parser.get_structure('protein', pdb_path)\n",
192
+ " except Exception as e:\n",
193
+ " return f\"Error parsing structure file: {e}\", None, None\n",
194
+ " \n",
195
+ " # Extract the specified chain\n",
196
+ " try:\n",
197
  " chain = structure[0][segment]\n",
198
  " except KeyError:\n",
199
  " return \"Invalid Chain ID\", None, None\n",
200
  " \n",
201
+ " protein_residues = [res for res in chain if is_aa(res)]\n",
202
+ " sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n",
203
+ " sequence_id = [res.id[1] for res in protein_residues]\n",
204
+ "\n",
205
+ " visualized_sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n",
206
+ " if sequence != visualized_sequence:\n",
207
+ " raise ValueError(\"The visualized sequence does not match the prediction sequence\")\n",
208
+ " \n",
209
+ " scores = np.random.rand(len(sequence))\n",
210
+ " normalized_scores = normalize_scores(scores)\n",
 
211
  " \n",
212
+ " # Zip residues with scores to track the residue ID and score\n",
213
+ " residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n",
214
  "\n",
 
 
215
  " \n",
216
+ " # Define the score brackets\n",
217
+ " score_brackets = {\n",
218
+ " \"0.0-0.2\": (0.0, 0.2),\n",
219
+ " \"0.2-0.4\": (0.2, 0.4),\n",
220
+ " \"0.4-0.6\": (0.4, 0.6),\n",
221
+ " \"0.6-0.8\": (0.6, 0.8),\n",
222
+ " \"0.8-1.0\": (0.8, 1.0)\n",
223
+ " }\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  " \n",
225
+ " # Initialize a dictionary to store residues by bracket\n",
226
+ " residues_by_bracket = {bracket: [] for bracket in score_brackets}\n",
 
 
 
 
 
 
 
 
 
227
  " \n",
228
+ " # Categorize residues into brackets\n",
229
+ " for resi, score in residue_scores:\n",
230
+ " for bracket, (lower, upper) in score_brackets.items():\n",
231
+ " if lower <= score < upper:\n",
232
+ " residues_by_bracket[bracket].append(resi)\n",
233
+ " break\n",
234
  " \n",
235
+ " # Preparing the result string\n",
236
+ " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
237
+ " result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n",
238
+ " result_str += \"Residues by Score Brackets:\\n\\n\"\n",
239
  " \n",
240
+ " # Add residues for each bracket\n",
241
+ " for bracket, residues in residues_by_bracket.items():\n",
242
+ " result_str += f\"Bracket {bracket}:\\n\"\n",
243
+ " result_str += \"Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\\n\"\n",
244
+ " result_str += \"\\n\".join([\n",
245
+ " f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
246
+ " for i, res in enumerate(protein_residues) if res.id[1] in residues\n",
247
+ " ])\n",
248
+ " result_str += \"\\n\\n\"\n",
 
249
  "\n",
250
+ " # Create chain-specific PDB with scores in B-factor\n",
251
+ " scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  "\n",
253
+ " # Molecule visualization with updated script with color mapping\n",
254
+ " mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)\n",
 
 
255
  "\n",
256
+ " # Improved PyMOL command suggestions\n",
257
+ " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
258
+ " pymol_commands = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  " \n",
260
+ " pymol_commands += f\"\"\"\n",
261
+ " # PyMOL Visualization Commands\n",
262
+ " load {os.path.abspath(pdb_path)}, protein\n",
263
+ " hide everything, all\n",
264
+ " show cartoon, chain {segment}\n",
265
+ " color white, chain {segment}\n",
266
+ " \"\"\"\n",
267
+ " \n",
268
+ " # Define colors for each score bracket\n",
269
+ " bracket_colors = {\n",
270
+ " \"0.0-0.2\": \"white\",\n",
271
+ " \"0.2-0.4\": \"lightorange\",\n",
272
+ " \"0.4-0.6\": \"orange\",\n",
273
+ " \"0.6-0.8\": \"orangered\",\n",
274
+ " \"0.8-1.0\": \"red\"\n",
275
+ " }\n",
276
+ " \n",
277
+ " # Add PyMOL commands for each score bracket\n",
278
+ " for bracket, residues in residues_by_bracket.items():\n",
279
+ " if residues: # Only add commands if there are residues in this bracket\n",
280
+ " color = bracket_colors[bracket]\n",
281
+ " resi_list = '+'.join(map(str, residues))\n",
282
+ " pymol_commands += f\"\"\"\n",
283
+ " select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}\n",
284
+ " show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}\n",
285
+ " color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}\n",
286
+ " \"\"\"\n",
287
+ " \n",
288
+ " # Create prediction and scored PDB files\n",
289
+ " prediction_file = f\"{pdb_id}_binding_site_residues.txt\"\n",
290
  " with open(prediction_file, \"w\") as f:\n",
291
  " f.write(result_str)\n",
292
  " \n",
293
+ " return pymol_commands, mol_vis, [prediction_file,scored_pdb]\n",
294
  "\n",
295
+ "def molecule(input_pdb, residue_scores=None, segment='A'):\n",
296
+ " # More granular scoring for visualization\n",
297
  " mol = read_mol(input_pdb) # Read PDB file content\n",
298
+ "\n",
299
+ " # Prepare high-scoring residues script if scores are provided\n",
300
+ " high_score_script = \"\"\n",
301
+ " if residue_scores is not None:\n",
302
+ " # Filter residues based on their scores\n",
303
+ " class1_score_residues = [resi for resi, score in residue_scores if 0.0 < score <= 0.2]\n",
304
+ " class2_score_residues = [resi for resi, score in residue_scores if 0.2 < score <= 0.4]\n",
305
+ " class3_score_residues = [resi for resi, score in residue_scores if 0.4 < score <= 0.6]\n",
306
+ " class4_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.8]\n",
307
+ " class5_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 1.0]\n",
308
+ " \n",
309
+ " high_score_script = \"\"\"\n",
310
+ " // Load the original model and apply white cartoon style\n",
311
+ " let chainModel = viewer.addModel(pdb, \"pdb\");\n",
312
+ " chainModel.setStyle({}, {});\n",
313
+ " chainModel.setStyle(\n",
314
+ " {\"chain\": \"%s\"}, \n",
315
+ " {\"cartoon\": {\"color\": \"white\"}}\n",
316
+ " );\n",
317
+ "\n",
318
+ " // Create a new model for high-scoring residues and apply red sticks style\n",
319
+ " let class1Model = viewer.addModel(pdb, \"pdb\");\n",
320
+ " class1Model.setStyle({}, {});\n",
321
+ " class1Model.setStyle(\n",
322
+ " {\"chain\": \"%s\", \"resi\": [%s]}, \n",
323
+ " {\"stick\": {\"color\": \"0xFFFFFF\", \"opacity\": 0.5}}\n",
324
+ " );\n",
325
+ "\n",
326
+ " // Create a new model for high-scoring residues and apply red sticks style\n",
327
+ " let class2Model = viewer.addModel(pdb, \"pdb\");\n",
328
+ " class2Model.setStyle({}, {});\n",
329
+ " class2Model.setStyle(\n",
330
+ " {\"chain\": \"%s\", \"resi\": [%s]}, \n",
331
+ " {\"stick\": {\"color\": \"0xFFD580\", \"opacity\": 0.7}}\n",
332
+ " );\n",
333
+ "\n",
334
+ " // Create a new model for high-scoring residues and apply red sticks style\n",
335
+ " let class3Model = viewer.addModel(pdb, \"pdb\");\n",
336
+ " class3Model.setStyle({}, {});\n",
337
+ " class3Model.setStyle(\n",
338
+ " {\"chain\": \"%s\", \"resi\": [%s]}, \n",
339
+ " {\"stick\": {\"color\": \"0xFFA500\", \"opacity\": 1}}\n",
340
+ " );\n",
341
+ "\n",
342
+ " // Create a new model for high-scoring residues and apply red sticks style\n",
343
+ " let class4Model = viewer.addModel(pdb, \"pdb\");\n",
344
+ " class4Model.setStyle({}, {});\n",
345
+ " class4Model.setStyle(\n",
346
+ " {\"chain\": \"%s\", \"resi\": [%s]}, \n",
347
+ " {\"stick\": {\"color\": \"0xFF4500\", \"opacity\": 1}}\n",
348
+ " );\n",
349
+ "\n",
350
+ " // Create a new model for high-scoring residues and apply red sticks style\n",
351
+ " let class5Model = viewer.addModel(pdb, \"pdb\");\n",
352
+ " class5Model.setStyle({}, {});\n",
353
+ " class5Model.setStyle(\n",
354
+ " {\"chain\": \"%s\", \"resi\": [%s]}, \n",
355
+ " {\"stick\": {\"color\": \"0xFF0000\", \"alpha\": 1}}\n",
356
+ " );\n",
357
+ "\n",
358
+ " \"\"\" % (\n",
359
+ " segment,\n",
360
+ " segment,\n",
361
+ " \", \".join(str(resi) for resi in class1_score_residues),\n",
362
+ " segment,\n",
363
+ " \", \".join(str(resi) for resi in class2_score_residues),\n",
364
+ " segment,\n",
365
+ " \", \".join(str(resi) for resi in class3_score_residues),\n",
366
+ " segment,\n",
367
+ " \", \".join(str(resi) for resi in class4_score_residues),\n",
368
+ " segment,\n",
369
+ " \", \".join(str(resi) for resi in class5_score_residues)\n",
370
+ " )\n",
371
  " \n",
372
+ " # Generate the full HTML content\n",
373
  " html_content = f\"\"\"\n",
374
  " <!DOCTYPE html>\n",
375
  " <html>\n",
 
393
  " let element = $(\"#container\");\n",
394
  " let config = {{ backgroundColor: \"white\" }};\n",
395
  " let viewer = $3Dmol.createViewer(element, config);\n",
 
396
  " \n",
397
+ " {high_score_script}\n",
 
398
  " \n",
399
+ " // Add hover functionality\n",
400
+ " viewer.setHoverable(\n",
401
+ " {{}}, \n",
402
+ " true, \n",
403
+ " function(atom, viewer, event, container) {{\n",
404
+ " if (!atom.label) {{\n",
405
+ " atom.label = viewer.addLabel(\n",
406
+ " atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
407
+ " {{\n",
408
+ " position: atom, \n",
409
+ " backgroundColor: 'mintcream', \n",
410
+ " fontColor: 'black',\n",
411
+ " fontSize: 18,\n",
412
+ " padding: 4\n",
413
+ " }}\n",
414
+ " );\n",
415
+ " }}\n",
416
+ " }},\n",
417
+ " function(atom, viewer) {{\n",
418
+ " if (atom.label) {{\n",
419
+ " viewer.removeLabel(atom.label);\n",
420
+ " delete atom.label;\n",
421
+ " }}\n",
422
+ " }}\n",
423
  " );\n",
424
  " \n",
425
  " viewer.zoomTo();\n",
 
435
  " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
436
  "\n",
437
  "# Gradio UI\n",
438
+ "with gr.Blocks(css=\"\"\"\n",
439
+ " /* Customize Gradio button colors */\n",
440
+ " #visualize-btn, #predict-btn {\n",
441
+ " background-color: #FF7300; /* Deep orange */\n",
442
+ " color: white;\n",
443
+ " border-radius: 5px;\n",
444
+ " padding: 10px;\n",
445
+ " font-weight: bold;\n",
446
+ " }\n",
447
+ " #visualize-btn:hover, #predict-btn:hover {\n",
448
+ " background-color: #CC5C00; /* Darkened orange on hover */\n",
449
+ " }\n",
450
+ "\"\"\") as demo:\n",
451
+ " gr.Markdown(\"# Protein Binding Site Prediction\")\n",
452
  " \n",
453
+ " # Mode selection\n",
454
+ " mode = gr.Radio(\n",
455
+ " choices=[\"PDB ID\", \"Upload File\"],\n",
456
+ " value=\"PDB ID\",\n",
457
+ " label=\"Input Mode\",\n",
458
+ " info=\"Choose whether to input a PDB ID or upload a PDB/CIF file.\"\n",
459
+ " )\n",
460
+ "\n",
461
+ " # Input components based on mode\n",
462
+ " pdb_input = gr.Textbox(value=\"2F6V\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
463
+ " pdb_file = gr.File(label=\"Upload PDB/CIF File\", visible=False)\n",
464
+ " visualize_btn = gr.Button(\"Visualize Structure\", elem_id=\"visualize-btn\")\n",
465
+ "\n",
466
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n",
467
+ " {\n",
468
+ " \"model\": 0,\n",
469
+ " \"style\": \"cartoon\",\n",
470
+ " \"color\": \"whiteCarbon\",\n",
471
+ " \"residue_range\": \"\",\n",
472
+ " \"around\": 0,\n",
473
+ " \"byres\": False,\n",
474
+ " }\n",
475
+ " ])\n",
476
+ "\n",
477
+ " with gr.Row():\n",
478
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID (protein)\", placeholder=\"Enter Chain ID here...\",\n",
479
+ " info=\"Choose in which chain to predict binding sites.\")\n",
480
+ " prediction_btn = gr.Button(\"Predict Binding Site\", elem_id=\"predict-btn\")\n",
481
+ "\n",
482
  " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
483
+ " explanation_vis = gr.Markdown(\"\"\"\n",
484
+ " Score dependent colorcoding:\n",
485
+ " - 0.0-0.2: white \n",
486
+ " - 0.2–0.4: light orange \n",
487
+ " - 0.4–0.6: orange\n",
488
+ " - 0.6–0.8: orangered\n",
489
+ " - 0.8–1.0: red\n",
490
+ " \"\"\")\n",
491
+ " predictions_output = gr.Textbox(label=\"Visualize Prediction with PyMol\")\n",
492
+ " gr.Markdown(\"### Download:\\n- List of predicted binding site residues\\n- PDB with score in beta factor column\")\n",
493
+ " download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n",
494
  " \n",
495
+ " def process_interface(mode, pdb_id, pdb_file, chain_id):\n",
496
+ " if mode == \"PDB ID\":\n",
497
+ " return process_pdb(pdb_id, chain_id)\n",
498
+ " elif mode == \"Upload File\":\n",
499
+ " _, ext = os.path.splitext(pdb_file.name)\n",
500
+ " file_path = os.path.join('./', f\"{_}{ext}\")\n",
501
+ " if ext == '.cif':\n",
502
+ " pdb_path = convert_cif_to_pdb(file_path)\n",
503
+ " else:\n",
504
+ " pdb_path= file_path\n",
505
+ " return process_pdb(pdb_path, chain_id)\n",
506
+ " else:\n",
507
+ " return \"Error: Invalid mode selected\", None, None\n",
508
+ "\n",
509
+ " def fetch_interface(mode, pdb_id, pdb_file):\n",
510
+ " if mode == \"PDB ID\":\n",
511
+ " return fetch_pdb(pdb_id)\n",
512
+ " elif mode == \"Upload File\":\n",
513
+ " _, ext = os.path.splitext(pdb_file.name)\n",
514
+ " file_path = os.path.join('./', f\"{_}{ext}\")\n",
515
+ " #print(ext)\n",
516
+ " if ext == '.cif':\n",
517
+ " pdb_path = convert_cif_to_pdb(file_path)\n",
518
+ " else:\n",
519
+ " pdb_path= file_path\n",
520
+ " #print(pdb_path)\n",
521
+ " return pdb_path\n",
522
+ " else:\n",
523
+ " return \"Error: Invalid mode selected\"\n",
524
+ "\n",
525
+ " def toggle_mode(selected_mode):\n",
526
+ " if selected_mode == \"PDB ID\":\n",
527
+ " return gr.update(visible=True), gr.update(visible=False)\n",
528
+ " else:\n",
529
+ " return gr.update(visible=False), gr.update(visible=True)\n",
530
+ "\n",
531
+ " mode.change(\n",
532
+ " toggle_mode,\n",
533
+ " inputs=[mode],\n",
534
+ " outputs=[pdb_input, pdb_file]\n",
535
+ " )\n",
536
+ "\n",
537
+ " prediction_btn.click(\n",
538
+ " process_interface, \n",
539
+ " inputs=[mode, pdb_input, pdb_file, segment_input], \n",
540
+ " outputs=[predictions_output, molecule_output, download_output]\n",
541
+ " )\n",
542
+ "\n",
543
  " visualize_btn.click(\n",
544
+ " fetch_interface, \n",
545
+ " inputs=[mode, pdb_input, pdb_file], \n",
546
+ " outputs=molecule_output2\n",
547
  " )\n",
548
+ "\n",
 
 
549
  " gr.Markdown(\"## Examples\")\n",
550
  " gr.Examples(\n",
551
  " examples=[\n",
552
+ " [\"7RPZ\", \"A\"],\n",
553
+ " [\"2IWI\", \"B\"],\n",
554
+ " [\"7LCJ\", \"R\"]\n",
555
  " ],\n",
556
  " inputs=[pdb_input, segment_input],\n",
557
  " outputs=[predictions_output, molecule_output, download_output]\n",
558
  " )\n",
559
  "\n",
560
+ "demo.launch(share=True)"
561
  ]
562
  },
563
  {
564
  "cell_type": "code",
565
  "execution_count": null,
566
+ "id": "440c87ed-45c9-4501-b208-409cbfd7858b",
567
  "metadata": {},
568
  "outputs": [],
569
  "source": []
570
  },
571
  {
572
  "cell_type": "code",
573
+ "execution_count": 21,
574
+ "id": "d70c40b9-5d5a-4795-b2a2-149c4a57d16e",
575
  "metadata": {},
576
  "outputs": [
577
+ {
578
+ "name": "stderr",
579
+ "output_type": "stream",
580
+ "text": [
581
+ "/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py:441: UserWarning: Examples will be cached but not all input components have example values. This may result in an exception being thrown by your function. If you do get an error while caching examples, make sure all of your inputs have example values for all of your examples or you provide default values for those particular parameters in your function.\n",
582
+ " warnings.warn(\n",
583
+ "INFO:__main__:Using cached structure: ./7rpz.cif\n",
584
+ "INFO:__main__:Using cached structure: ./2iwi.cif\n",
585
+ "INFO:__main__:Using cached structure: ./2f6v.cif\n",
586
+ "INFO:httpx:HTTP Request: GET http://127.0.0.1:7862/gradio_api/startup-events \"HTTP/1.1 200 OK\"\n"
587
+ ]
588
+ },
589
+ {
590
+ "name": "stdout",
591
+ "output_type": "stream",
592
+ "text": [
593
+ "* Running on local URL: http://127.0.0.1:7862\n",
594
+ "Caching examples at: '/home/frohlkin/Projects/LargeLanguageModels/Publication/test_webpage/.gradio/cached_examples/148'\n"
595
+ ]
596
+ },
597
+ {
598
+ "name": "stderr",
599
+ "output_type": "stream",
600
+ "text": [
601
+ "INFO:httpx:HTTP Request: HEAD http://127.0.0.1:7862/ \"HTTP/1.1 200 OK\"\n",
602
+ "INFO:httpx:HTTP Request: GET https://api.gradio.app/pkg-version \"HTTP/1.1 200 OK\"\n",
603
+ "INFO:httpx:HTTP Request: GET https://api.gradio.app/v3/tunnel-request \"HTTP/1.1 200 OK\"\n"
604
+ ]
605
+ },
606
  {
607
  "name": "stdout",
608
  "output_type": "stream",
609
  "text": [
610
+ "* Running on public URL: https://de785d7cce806497e9.gradio.live\n",
611
  "\n",
612
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
613
+ ]
614
+ },
615
+ {
616
+ "name": "stderr",
617
+ "output_type": "stream",
618
+ "text": [
619
+ "INFO:httpx:HTTP Request: HEAD https://de785d7cce806497e9.gradio.live \"HTTP/1.1 200 OK\"\n"
620
  ]
621
  },
622
  {
623
  "data": {
624
  "text/html": [
625
+ "<div><iframe src=\"https://de785d7cce806497e9.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
626
  ],
627
  "text/plain": [
628
  "<IPython.core.display.HTML object>"
 
632
  "output_type": "display_data"
633
  },
634
  {
635
+ "name": "stderr",
636
+ "output_type": "stream",
637
+ "text": [
638
+ "Traceback (most recent call last):\n",
639
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/routes.py\", line 990, in predict\n",
640
+ " output = await route_utils.call_process_api(\n",
641
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
642
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/route_utils.py\", line 322, in call_process_api\n",
643
+ " output = await app.get_blocks().process_api(\n",
644
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
645
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 2047, in process_api\n",
646
+ " result = await self.call_function(\n",
647
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
648
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 1594, in call_function\n",
649
+ " prediction = await anyio.to_thread.run_sync( # type: ignore\n",
650
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
651
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
652
+ " return await get_async_backend().run_sync_in_worker_thread(\n",
653
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
654
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 2405, in run_sync_in_worker_thread\n",
655
+ " return await future\n",
656
+ " ^^^^^^^^^^^^\n",
657
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 914, in run\n",
658
+ " result = context.run(func, *args)\n",
659
+ " ^^^^^^^^^^^^^^^^^^^^^^^^\n",
660
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/utils.py\", line 869, in wrapper\n",
661
+ " response = f(*args, **kwargs)\n",
662
+ " ^^^^^^^^^^^^^^^^^^\n",
663
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 355, in load_example_with_output\n",
664
+ " ) + self.load_from_cache(example_id)\n",
665
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
666
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 579, in load_from_cache\n",
667
+ " output.append(component.read_from_flag(value_to_use))\n",
668
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
669
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/components/base.py\", line 366, in read_from_flag\n",
670
+ " return self.data_model.from_json(json.loads(payload))\n",
671
+ " ^^^^^^^^^^^^^^^^^^^\n",
672
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/__init__.py\", line 346, in loads\n",
673
+ " return _default_decoder.decode(s)\n",
674
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
675
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 337, in decode\n",
676
+ " obj, end = self.raw_decode(s, idx=_w(s, 0).end())\n",
677
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
678
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 355, in raw_decode\n",
679
+ " raise JSONDecodeError(\"Expecting value\", s, err.value) from None\n",
680
+ "json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)\n",
681
+ "Traceback (most recent call last):\n",
682
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/routes.py\", line 990, in predict\n",
683
+ " output = await route_utils.call_process_api(\n",
684
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
685
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/route_utils.py\", line 322, in call_process_api\n",
686
+ " output = await app.get_blocks().process_api(\n",
687
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
688
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 2047, in process_api\n",
689
+ " result = await self.call_function(\n",
690
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
691
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 1594, in call_function\n",
692
+ " prediction = await anyio.to_thread.run_sync( # type: ignore\n",
693
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
694
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
695
+ " return await get_async_backend().run_sync_in_worker_thread(\n",
696
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
697
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 2405, in run_sync_in_worker_thread\n",
698
+ " return await future\n",
699
+ " ^^^^^^^^^^^^\n",
700
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 914, in run\n",
701
+ " result = context.run(func, *args)\n",
702
+ " ^^^^^^^^^^^^^^^^^^^^^^^^\n",
703
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/utils.py\", line 869, in wrapper\n",
704
+ " response = f(*args, **kwargs)\n",
705
+ " ^^^^^^^^^^^^^^^^^^\n",
706
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 355, in load_example_with_output\n",
707
+ " ) + self.load_from_cache(example_id)\n",
708
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
709
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 579, in load_from_cache\n",
710
+ " output.append(component.read_from_flag(value_to_use))\n",
711
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
712
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/components/base.py\", line 366, in read_from_flag\n",
713
+ " return self.data_model.from_json(json.loads(payload))\n",
714
+ " ^^^^^^^^^^^^^^^^^^^\n",
715
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/__init__.py\", line 346, in loads\n",
716
+ " return _default_decoder.decode(s)\n",
717
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
718
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 337, in decode\n",
719
+ " obj, end = self.raw_decode(s, idx=_w(s, 0).end())\n",
720
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
721
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 355, in raw_decode\n",
722
+ " raise JSONDecodeError(\"Expecting value\", s, err.value) from None\n",
723
+ "json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)\n",
724
+ "Traceback (most recent call last):\n",
725
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/routes.py\", line 990, in predict\n",
726
+ " output = await route_utils.call_process_api(\n",
727
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
728
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/route_utils.py\", line 322, in call_process_api\n",
729
+ " output = await app.get_blocks().process_api(\n",
730
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
731
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 2047, in process_api\n",
732
+ " result = await self.call_function(\n",
733
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
734
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 1594, in call_function\n",
735
+ " prediction = await anyio.to_thread.run_sync( # type: ignore\n",
736
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
737
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
738
+ " return await get_async_backend().run_sync_in_worker_thread(\n",
739
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
740
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 2405, in run_sync_in_worker_thread\n",
741
+ " return await future\n",
742
+ " ^^^^^^^^^^^^\n",
743
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 914, in run\n",
744
+ " result = context.run(func, *args)\n",
745
+ " ^^^^^^^^^^^^^^^^^^^^^^^^\n",
746
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/utils.py\", line 869, in wrapper\n",
747
+ " response = f(*args, **kwargs)\n",
748
+ " ^^^^^^^^^^^^^^^^^^\n",
749
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 355, in load_example_with_output\n",
750
+ " ) + self.load_from_cache(example_id)\n",
751
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
752
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 579, in load_from_cache\n",
753
+ " output.append(component.read_from_flag(value_to_use))\n",
754
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
755
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/components/base.py\", line 366, in read_from_flag\n",
756
+ " return self.data_model.from_json(json.loads(payload))\n",
757
+ " ^^^^^^^^^^^^^^^^^^^\n",
758
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/__init__.py\", line 346, in loads\n",
759
+ " return _default_decoder.decode(s)\n",
760
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
761
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 337, in decode\n",
762
+ " obj, end = self.raw_decode(s, idx=_w(s, 0).end())\n",
763
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
764
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 355, in raw_decode\n",
765
+ " raise JSONDecodeError(\"Expecting value\", s, err.value) from None\n",
766
+ "json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)\n"
767
+ ]
768
  }
769
  ],
770
  "source": [
771
+ "from datetime import datetime\n",
772
  "import gradio as gr\n",
773
  "import requests\n",
774
+ "from Bio.PDB import PDBParser, MMCIFParser, PDBIO, Select, Structure\n",
775
+ "from Bio.PDB.Polypeptide import is_aa\n",
776
+ "from Bio.SeqUtils import seq1\n",
777
+ "from typing import Optional, Tuple, Dict, List\n",
778
  "import numpy as np\n",
779
  "import os\n",
780
  "from gradio_molecule3d import Molecule3D\n",
781
+ "import torch\n",
782
+ "import torch.nn as nn\n",
783
+ "import torch.nn.functional as F\n",
784
+ "from torch.utils.data import DataLoader\n",
785
+ "import re\n",
786
+ "import pandas as pd\n",
787
+ "import copy\n",
788
+ "from scipy.special import expit\n",
789
+ "import logging\n",
790
+ "import tempfile\n",
791
  "\n",
792
+ "# Set up logging\n",
793
+ "logging.basicConfig(level=logging.INFO)\n",
794
+ "logger = logging.getLogger(__name__)\n",
 
795
  "\n",
796
+ "class StructureError(Exception):\n",
797
+ " \"\"\"Custom exception for structure-related errors\"\"\"\n",
798
+ " pass\n",
799
+ "\n",
800
+ "def normalize_scores(scores: np.ndarray) -> np.ndarray:\n",
801
+ " \"\"\"Normalize scores to range [0,1]\"\"\"\n",
802
+ " min_score = np.min(scores)\n",
803
+ " max_score = np.max(scores)\n",
804
+ " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
805
+ "\n",
806
+ "def read_mol(pdb_path: str) -> str:\n",
807
+ " \"\"\"Read molecular structure file and return its content\"\"\"\n",
808
+ " try:\n",
809
+ " with open(pdb_path, 'r') as f:\n",
810
+ " return f.read()\n",
811
+ " except Exception as e:\n",
812
+ " raise IOError(f\"Failed to read structure file: {e}\")\n",
813
+ "\n",
814
+ "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n",
815
+ " \"\"\"Fetch structure file, trying multiple formats and sources\"\"\"\n",
816
+ " try:\n",
817
+ " # First try local cache\n",
818
+ " for ext in ['.cif', '.pdb']:\n",
819
+ " local_path = os.path.join(output_dir, f\"{pdb_id.lower()}{ext}\")\n",
820
+ " if os.path.exists(local_path):\n",
821
+ " logger.info(f\"Using cached structure: {local_path}\")\n",
822
+ " return local_path\n",
823
+ "\n",
824
+ " # Try different download sources\n",
825
+ " sources = [\n",
826
+ " f\"https://files.rcsb.org/download/{pdb_id.upper()}.cif\",\n",
827
+ " f\"https://files.rcsb.org/download/{pdb_id.upper()}.pdb\",\n",
828
+ " f\"https://files.rcsb.org/download/{pdb_id.lower()}.cif\",\n",
829
+ " f\"https://files.rcsb.org/download/{pdb_id.lower()}.pdb\"\n",
830
+ " ]\n",
831
+ "\n",
832
+ " for url in sources:\n",
833
+ " try:\n",
834
+ " response = requests.get(url, timeout=10)\n",
835
+ " if response.status_code == 200:\n",
836
+ " ext = '.cif' if 'cif' in url else '.pdb'\n",
837
+ " file_path = os.path.join(output_dir, f\"{pdb_id.lower()}{ext}\")\n",
838
+ " with open(file_path, 'wb') as f:\n",
839
+ " f.write(response.content)\n",
840
+ " logger.info(f\"Successfully downloaded: {url}\")\n",
841
+ " return file_path\n",
842
+ " except Exception as e:\n",
843
+ " logger.warning(f\"Failed to download from {url}: {e}\")\n",
844
+ " continue\n",
845
+ "\n",
846
+ " raise StructureError(f\"Failed to fetch structure for PDB ID: {pdb_id}\")\n",
847
+ " except Exception as e:\n",
848
+ " raise StructureError(f\"Error fetching structure: {e}\")\n",
849
+ "\n",
850
+ "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n",
851
+ " \"\"\"Convert CIF to PDB format with error handling\"\"\"\n",
852
+ " try:\n",
853
+ " pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n",
854
+ " parser = MMCIFParser(QUIET=True)\n",
855
+ " structure = parser.get_structure('protein', cif_path)\n",
856
+ " io = PDBIO()\n",
857
+ " io.set_structure(structure)\n",
858
+ " io.save(pdb_path)\n",
859
  " return pdb_path\n",
860
+ " except Exception as e:\n",
861
+ " raise StructureError(f\"Failed to convert CIF to PDB: {e}\")\n",
862
+ "\n",
863
+ "def find_valid_chain(structure: Structure.Structure) -> Optional[str]:\n",
864
+ " \"\"\"Find the first valid protein chain in the structure\"\"\"\n",
865
+ " for model in structure:\n",
866
+ " for chain in model:\n",
867
+ " protein_residues = [res for res in chain if is_aa(res)]\n",
868
+ " if len(protein_residues) > 0:\n",
869
+ " return chain.id\n",
870
+ " return None\n",
871
+ "\n",
872
+ "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n",
873
+ " \"\"\"Create PDB file with selected chain and prediction scores in B-factor column\"\"\"\n",
874
+ " class ResidueSelector(Select):\n",
875
+ " def __init__(self, chain_id, selected_residues, scores_dict):\n",
876
+ " self.chain_id = chain_id\n",
877
+ " self.selected_residues = selected_residues\n",
878
+ " self.scores_dict = scores_dict\n",
879
+ " \n",
880
+ " def accept_chain(self, chain):\n",
881
+ " return chain.id == self.chain_id\n",
882
+ " \n",
883
+ " def accept_residue(self, residue):\n",
884
+ " return residue.id[1] in self.selected_residues\n",
885
+ "\n",
886
+ " def accept_atom(self, atom):\n",
887
+ " if atom.parent.id[1] in self.scores_dict:\n",
888
+ " atom.bfactor = np.absolute(1-self.scores_dict[atom.parent.id[1]]) * 100\n",
889
+ " return True\n",
890
  "\n",
 
 
 
 
 
 
 
891
  " try:\n",
892
+ " parser = PDBParser(QUIET=True)\n",
893
+ " structure = parser.get_structure('protein', input_pdb)\n",
894
+ " output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb\"\n",
895
+ " scores_dict = {resi: score for resi, score in residue_scores}\n",
896
+ " \n",
897
+ " io = PDBIO()\n",
898
+ " selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n",
899
+ " io.set_structure(structure[0])\n",
900
+ " io.save(output_pdb, selector)\n",
901
+ " \n",
902
+ " return output_pdb\n",
903
+ " except Exception as e:\n",
904
+ " raise StructureError(f\"Failed to create chain-specific PDB: {e}\")\n",
905
+ "\n",
906
+ "def process_pdb(pdb_id_or_file: str, segment: str) -> Tuple[str, str, List[str]]:\n",
907
+ " \"\"\"Process PDB/CIF file and generate visualizations and predictions\"\"\"\n",
908
+ " try:\n",
909
+ " # Handle input\n",
910
+ " if pdb_id_or_file.endswith(('.pdb', '.cif')):\n",
911
+ " pdb_path = pdb_id_or_file\n",
912
+ " pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n",
913
+ " else:\n",
914
+ " pdb_id = pdb_id_or_file\n",
915
+ " pdb_path = fetch_structure(pdb_id)\n",
916
+ "\n",
917
+ " if not pdb_path:\n",
918
+ " raise StructureError(\"Failed to obtain structure file\")\n",
919
+ "\n",
920
+ " # Parse structure\n",
921
+ " parser = MMCIFParser(QUIET=True) if pdb_path.endswith('.cif') else PDBParser(QUIET=True)\n",
922
+ " structure = parser.get_structure('protein', pdb_path)\n",
923
+ "\n",
924
+ " # Handle chain selection\n",
925
+ " if segment == 'auto' or not segment:\n",
926
+ " segment = find_valid_chain(structure)\n",
927
+ " if not segment:\n",
928
+ " raise StructureError(\"No valid protein chains found in structure\")\n",
929
+ " \n",
930
+ " try:\n",
931
+ " chain = structure[0][segment]\n",
932
+ " except KeyError:\n",
933
+ " valid_chain = find_valid_chain(structure)\n",
934
+ " if valid_chain:\n",
935
+ " chain = structure[0][valid_chain]\n",
936
+ " segment = valid_chain\n",
937
+ " logger.info(f\"Using alternative chain {segment}\")\n",
938
+ " else:\n",
939
+ " raise StructureError(f\"Invalid chain ID '{segment}'. Structure has no valid protein chains.\")\n",
940
+ "\n",
941
+ " # Process chain\n",
942
+ " protein_residues = [res for res in chain if is_aa(res)]\n",
943
+ " if not protein_residues:\n",
944
+ " raise StructureError(f\"No amino acid residues found in chain {segment}\")\n",
945
+ "\n",
946
+ " sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n",
947
+ " sequence_id = [res.id[1] for res in protein_residues]\n",
948
+ " \n",
949
+ " # Generate predictions (currently random)\n",
950
+ " scores = np.random.rand(len(sequence))\n",
951
+ " normalized_scores = normalize_scores(scores)\n",
952
+ " residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n",
953
+ "\n",
954
+ " # Generate outputs\n",
955
+ " result_str = generate_results_string(pdb_id, segment, protein_residues, normalized_scores, sequence)\n",
956
+ " scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n",
957
+ " mol_vis = molecule(pdb_path, residue_scores, segment)\n",
958
+ " pymol_commands = generate_pymol_commands(pdb_id, segment, residue_scores, pdb_path)\n",
959
+ "\n",
960
+ " # Save results\n",
961
+ " prediction_file = f\"{pdb_id}_binding_site_residues.txt\"\n",
962
+ " with open(prediction_file, \"w\") as f:\n",
963
+ " f.write(result_str)\n",
964
+ "\n",
965
+ " return pymol_commands, mol_vis, [prediction_file, scored_pdb]\n",
966
+ "\n",
967
+ " except StructureError as e:\n",
968
+ " return str(e), None, None\n",
969
+ " except Exception as e:\n",
970
+ " return f\"An unexpected error occurred: {str(e)}\", None, None\n",
971
+ "\n",
972
+ "def generate_results_string(pdb_id: str, segment: str, protein_residues: list, \n",
973
+ " normalized_scores: np.ndarray, sequence: str) -> str:\n",
974
+ " \"\"\"Generate formatted results string with predictions\"\"\"\n",
975
+ " score_brackets = {\n",
976
+ " \"0.0-0.2\": (0.0, 0.2),\n",
977
+ " \"0.2-0.4\": (0.2, 0.4),\n",
978
+ " \"0.4-0.6\": (0.4, 0.6),\n",
979
+ " \"0.6-0.8\": (0.6, 0.8),\n",
980
+ " \"0.8-1.0\": (0.8, 1.0)\n",
981
+ " }\n",
982
  " \n",
983
+ " residues_by_bracket = {bracket: [] for bracket in score_brackets}\n",
984
+ " \n",
985
+ " # Categorize residues\n",
986
+ " for i, score in enumerate(normalized_scores):\n",
987
+ " for bracket, (lower, upper) in score_brackets.items():\n",
988
+ " if lower <= score < upper:\n",
989
+ " residues_by_bracket[bracket].append(protein_residues[i])\n",
990
+ " break\n",
991
+ " \n",
992
+ " # Format results\n",
993
+ " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
994
+ " result_str = f\"\"\"Prediction Results\n",
995
+ "========================\n",
996
+ "PDB: {pdb_id}\n",
997
+ "Chain: {segment}\n",
998
+ "Date: {current_time}\n",
999
  "\n",
1000
+ "Analysis by Score Brackets\n",
1001
+ "========================\n",
1002
+ "\"\"\"\n",
1003
  " \n",
1004
+ " for bracket, residues in residues_by_bracket.items():\n",
1005
+ " if residues: # Only show brackets with residues\n",
1006
+ " result_str += f\"\\nBracket {bracket}:\\n\"\n",
1007
+ " result_str += \"ResName ResNum Code Score\\n\"\n",
1008
+ " result_str += \"-\" * 30 + \"\\n\"\n",
1009
+ " result_str += \"\\n\".join([\n",
1010
+ " f\"{res.resname:6} {res.id[1]:6} {sequence[i]:4} {normalized_scores[i]:6.2f}\" \n",
1011
+ " for i, res in enumerate(protein_residues) if res in residues\n",
1012
+ " ])\n",
1013
+ " result_str += \"\\n\"\n",
1014
+ " \n",
1015
+ " return result_str\n",
1016
+ "\n",
1017
+ "def generate_pymol_commands(pdb_id: str, segment: str, residue_scores: list, pdb_path: str) -> str:\n",
1018
+ " \"\"\"Generate PyMOL visualization commands\"\"\"\n",
1019
+ " # Group residues by score ranges\n",
1020
+ " score_groups = {\n",
1021
+ " \"very_low\": [], \"low\": [], \"medium\": [], \"high\": [], \"very_high\": []\n",
1022
+ " }\n",
1023
  " \n",
1024
+ " for resi, score in residue_scores:\n",
1025
+ " if score <= 0.2:\n",
1026
+ " score_groups[\"very_low\"].append(str(resi))\n",
1027
+ " elif score <= 0.4:\n",
1028
+ " score_groups[\"low\"].append(str(resi))\n",
1029
+ " elif score <= 0.6:\n",
1030
+ " score_groups[\"medium\"].append(str(resi))\n",
1031
+ " elif score <= 0.8:\n",
1032
+ " score_groups[\"high\"].append(str(resi))\n",
1033
+ " else:\n",
1034
+ " score_groups[\"very_high\"].append(str(resi))\n",
1035
+ "\n",
1036
+ " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
1037
+ " commands = f\"\"\"# PyMOL Script for {pdb_id} Chain {segment}\n",
1038
+ "# Generated: {current_time}\n",
1039
+ "\n",
1040
+ "# Load structure and set initial display\n",
1041
+ "load {os.path.abspath(pdb_path)}, protein\n",
1042
+ "bg_color white\n",
1043
+ "hide everything\n",
1044
+ "show cartoon, chain {segment}\n",
1045
+ "color white, chain {segment}\n",
1046
+ "\n",
1047
+ "# Create selection groups by score\n",
1048
+ "\"\"\"\n",
1049
+ " \n",
1050
+ " color_scheme = {\n",
1051
+ " \"very_low\": \"white\",\n",
1052
+ " \"low\": \"lightorange\",\n",
1053
+ " \"medium\": \"orange\",\n",
1054
+ " \"high\": \"orangered\",\n",
1055
+ " \"very_high\": \"red\"\n",
1056
+ " }\n",
1057
+ " \n",
1058
+ " for group, residues in score_groups.items():\n",
1059
+ " if residues:\n",
1060
+ " resi_str = \"+\".join(residues)\n",
1061
+ " commands += f\"\"\"\n",
1062
+ "# {group.replace('_', ' ').title()} scoring residues\n",
1063
+ "select {group}, chain {segment} and resi {resi_str}\n",
1064
+ "show sticks, {group}\n",
1065
+ "color {color_scheme[group]}, {group}\"\"\"\n",
1066
+ " \n",
1067
+ " commands += \"\"\"\n",
1068
+ "\n",
1069
+ "# Center and zoom\n",
1070
+ "center chain {}\n",
1071
+ "zoom chain {}\n",
1072
+ "\"\"\"\n",
1073
+ "\n",
1074
+ " return commands\n",
1075
+ "\n",
1076
+ "def molecule(input_pdb: str, residue_scores: list = None, segment: str = 'A') -> str:\n",
1077
+ " \"\"\"Generate interactive 3D molecule visualization\"\"\"\n",
1078
+ " try:\n",
1079
+ " mol = read_mol(input_pdb)\n",
1080
+ " except Exception as e:\n",
1081
+ " return f'<div class=\"error\">Error loading structure: {str(e)}</div>'\n",
1082
+ "\n",
1083
+ " # Prepare residue groups for visualization\n",
1084
+ " vis_groups = {\n",
1085
+ " \"class1\": [], # 0.0-0.2\n",
1086
+ " \"class2\": [], # 0.2-0.4\n",
1087
+ " \"class3\": [], # 0.4-0.6\n",
1088
+ " \"class4\": [], # 0.6-0.8\n",
1089
+ " \"class5\": [] # 0.8-1.0\n",
1090
+ " }\n",
1091
+ "\n",
1092
+ " if residue_scores:\n",
1093
+ " for resi, score in residue_scores:\n",
1094
+ " if score <= 0.2:\n",
1095
+ " vis_groups[\"class1\"].append(resi)\n",
1096
+ " elif score <= 0.4:\n",
1097
+ " vis_groups[\"class2\"].append(resi)\n",
1098
+ " elif score <= 0.6:\n",
1099
+ " vis_groups[\"class3\"].append(resi)\n",
1100
+ " elif score <= 0.8:\n",
1101
+ " vis_groups[\"class4\"].append(resi)\n",
1102
+ " else:\n",
1103
+ " vis_groups[\"class5\"].append(resi)\n",
1104
+ "\n",
1105
+ " # Generate visualization script\n",
1106
+ " vis_script = f\"\"\"\n",
1107
+ " // Base model setup\n",
1108
+ " let chainModel = viewer.addModel(pdb, \"pdb\");\n",
1109
+ " chainModel.setStyle({{}}, {{}});\n",
1110
+ " chainModel.setStyle(\n",
1111
+ " {{\"chain\": \"{segment}\"}}, \n",
1112
+ " {{\"cartoon\": {{\"color\": \"white\"}}}}\n",
1113
+ " );\n",
1114
+ " \"\"\"\n",
1115
+ "\n",
1116
+ " # Color schemes for different score ranges\n",
1117
+ " color_schemes = {\n",
1118
+ " \"class1\": {\"color\": \"0xFFFFFF\", \"opacity\": 0.5}, # White\n",
1119
+ " \"class2\": {\"color\": \"0xFFD580\", \"opacity\": 0.7}, # Light orange\n",
1120
+ " \"class3\": {\"color\": \"0xFFA500\", \"opacity\": 1.0}, # Orange\n",
1121
+ " \"class4\": {\"color\": \"0xFF4500\", \"opacity\": 1.0}, # Orange red\n",
1122
+ " \"class5\": {\"color\": \"0xFF0000\", \"opacity\": 1.0} # Red\n",
1123
+ " }\n",
1124
+ "\n",
1125
+ " # Add visualization for each group\n",
1126
+ " for group, residues in vis_groups.items():\n",
1127
+ " if residues:\n",
1128
+ " color_scheme = color_schemes[group]\n",
1129
+ " vis_script += f\"\"\"\n",
1130
+ " let {group}Model = viewer.addModel(pdb, \"pdb\");\n",
1131
+ " {group}Model.setStyle({{}}, {{}});\n",
1132
+ " {group}Model.setStyle(\n",
1133
+ " {{\"chain\": \"{segment}\", \"resi\": [{\", \".join(map(str, residues))}]}},\n",
1134
+ " {{\"stick\": {{\"color\": \"{color_scheme[\"color\"]}\", \"opacity\": {color_scheme[\"opacity\"]}}}}}\n",
1135
+ " );\n",
1136
+ " \"\"\"\n",
1137
+ "\n",
1138
+ " # Generate full HTML with enhanced controls and information\n",
1139
  " html_content = f\"\"\"\n",
1140
  " <!DOCTYPE html>\n",
1141
  " <html>\n",
 
1147
  " height: 700px;\n",
1148
  " position: relative;\n",
1149
  " }}\n",
1150
+ " .controls {{\n",
1151
+ " position: absolute;\n",
1152
+ " top: 10px;\n",
1153
+ " left: 10px;\n",
1154
+ " background: rgba(255, 255, 255, 0.8);\n",
1155
+ " padding: 10px;\n",
1156
+ " border-radius: 5px;\n",
1157
+ " z-index: 1000;\n",
1158
+ " }}\n",
1159
+ " .legend {{\n",
1160
+ " position: absolute;\n",
1161
+ " bottom: 10px;\n",
1162
+ " right: 10px;\n",
1163
+ " background: rgba(255, 255, 255, 0.8);\n",
1164
+ " padding: 10px;\n",
1165
+ " border-radius: 5px;\n",
1166
+ " z-index: 1000;\n",
1167
+ " }}\n",
1168
+ " .error {{\n",
1169
+ " color: red;\n",
1170
+ " padding: 20px;\n",
1171
+ " text-align: center;\n",
1172
+ " font-weight: bold;\n",
1173
+ " }}\n",
1174
  " </style>\n",
1175
  " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
1176
  " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
1177
  " </head>\n",
1178
  " <body>\n",
1179
+ " <div id=\"container\" class=\"mol-container\">\n",
1180
+ " <div class=\"controls\">\n",
1181
+ " <button onclick=\"toggleStyle('cartoon')\">Toggle Cartoon</button>\n",
1182
+ " <button onclick=\"toggleStyle('stick')\">Toggle Sticks</button>\n",
1183
+ " <button onclick=\"resetView()\">Reset View</button>\n",
1184
+ " <button onclick=\"toggleSpin()\">Toggle Spin</button>\n",
1185
+ " </div>\n",
1186
+ " <div class=\"legend\">\n",
1187
+ " <div><span style=\"color: #FF0000\">■</span> Very High (0.8-1.0)</div>\n",
1188
+ " <div><span style=\"color: #FF4500\">■</span> High (0.6-0.8)</div>\n",
1189
+ " <div><span style=\"color: #FFA500\">■</span> Medium (0.4-0.6)</div>\n",
1190
+ " <div><span style=\"color: #FFD580\">■</span> Low (0.2-0.4)</div>\n",
1191
+ " <div><span style=\"color: #FFFFFF\">■</span> Very Low (0.0-0.2)</div>\n",
1192
+ " </div>\n",
1193
+ " </div>\n",
1194
  " <script>\n",
1195
+ " let pdb = `{mol}`;\n",
1196
+ " let viewer;\n",
1197
+ " let isSpinning = false;\n",
1198
+ "\n",
1199
  " $(document).ready(function () {{\n",
1200
  " let element = $(\"#container\");\n",
1201
  " let config = {{ backgroundColor: \"white\" }};\n",
1202
+ " viewer = $3Dmol.createViewer(element, config);\n",
 
1203
  " \n",
1204
+ " {vis_script}\n",
 
1205
  " \n",
1206
+ " // Enhanced hover functionality\n",
1207
+ " viewer.setHoverable(\n",
1208
+ " {{}}, \n",
1209
+ " true, \n",
1210
+ " function(atom, viewer, event, container) {{\n",
1211
+ " if (!atom.label) {{\n",
1212
+ " atom.label = viewer.addLabel(\n",
1213
+ " `${{atom.resn}}:${{atom.resi}}:${{atom.atom}}`, \n",
1214
+ " {{\n",
1215
+ " position: atom, \n",
1216
+ " backgroundColor: 'mintcream', \n",
1217
+ " fontColor: 'black',\n",
1218
+ " fontSize: 18,\n",
1219
+ " padding: 4\n",
1220
+ " }}\n",
1221
+ " );\n",
1222
+ " }}\n",
1223
+ " }},\n",
1224
+ " function(atom, viewer) {{\n",
1225
+ " if (atom.label) {{\n",
1226
+ " viewer.removeLabel(atom.label);\n",
1227
+ " delete atom.label;\n",
1228
+ " }}\n",
1229
+ " }}\n",
1230
+ " );\n",
1231
  " \n",
1232
  " viewer.zoomTo();\n",
1233
  " viewer.render();\n",
1234
  " viewer.zoom(0.8, 2000);\n",
1235
  " }});\n",
1236
+ "\n",
1237
+ " function toggleStyle(style) {{\n",
1238
+ " let elements = viewer.selectedAtoms({{}});\n",
1239
+ " let currentStyle = elements.style[style];\n",
1240
+ " elements.style[style] = !currentStyle;\n",
1241
+ " viewer.render();\n",
1242
+ " }}\n",
1243
+ "\n",
1244
+ " function resetView() {{\n",
1245
+ " viewer.zoomTo();\n",
1246
+ " viewer.render();\n",
1247
+ " }}\n",
1248
+ "\n",
1249
+ " function toggleSpin() {{\n",
1250
+ " isSpinning = !isSpinning;\n",
1251
+ " viewer.spin(isSpinning);\n",
1252
+ " }}\n",
1253
  " </script>\n",
1254
  " </body>\n",
1255
  " </html>\n",
1256
  " \"\"\"\n",
1257
  " \n",
 
1258
  " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
1259
  "\n",
 
 
 
 
 
 
 
 
 
 
1260
  "# Gradio UI\n",
1261
+ "def create_ui():\n",
1262
+ " with gr.Blocks(title=\"Protein Binding Site Prediction\", theme=gr.themes.Base()) as demo:\n",
1263
+ " gr.Markdown(\"\"\"\n",
1264
+ " # Protein Binding Site Prediction\n",
1265
+ " \n",
1266
+ " This tool helps you visualize and analyze potential binding sites in protein structures.\n",
1267
+ " You can either:\n",
1268
+ " 1. Enter a PDB ID (e.g., \"4BDU\")\n",
1269
+ " 2. Upload your own PDB/CIF file\n",
1270
+ " \n",
1271
+ " The tool will analyze the structure and show predictions using a color gradient from white (low probability) to red (high probability).\n",
1272
+ " \"\"\")\n",
1273
+ " \n",
1274
+ " with gr.Row():\n",
1275
+ " with gr.Column(scale=2):\n",
1276
+ " # Input components\n",
1277
+ " mode = gr.Radio(\n",
1278
+ " choices=[\"PDB ID\", \"Upload File\"],\n",
1279
+ " value=\"PDB ID\",\n",
1280
+ " label=\"Input Mode\",\n",
1281
+ " info=\"Choose whether to input a PDB ID or upload a PDB/CIF file\"\n",
1282
+ " )\n",
1283
+ " \n",
1284
+ " with gr.Group():\n",
1285
+ " pdb_input = gr.Textbox(\n",
1286
+ " value=\"4BDU\",\n",
1287
+ " label=\"PDB ID\",\n",
1288
+ " placeholder=\"Enter PDB ID (e.g., 4BDU)\",\n",
1289
+ " info=\"Enter a valid PDB ID from the Protein Data Bank\"\n",
1290
+ " )\n",
1291
+ " pdb_file = gr.File(\n",
1292
+ " label=\"Upload PDB/CIF File\",\n",
1293
+ " file_types=[\".pdb\", \".cif\"],\n",
1294
+ " visible=False\n",
1295
+ " )\n",
1296
+ " \n",
1297
+ " segment_input = gr.Textbox(\n",
1298
+ " value=\"A\",\n",
1299
+ " label=\"Chain ID\",\n",
1300
+ " placeholder=\"Enter Chain ID or leave empty for automatic selection\",\n",
1301
+ " info=\"Specify which protein chain to analyze, or leave empty for automatic selection\"\n",
1302
+ " )\n",
1303
  "\n",
1304
+ " with gr.Column(scale=1):\n",
1305
+ " visualize_btn = gr.Button(\"Visualize Structure\", variant=\"primary\")\n",
1306
+ " prediction_btn = gr.Button(\"Predict Binding Site\", variant=\"secondary\")\n",
1307
+ " \n",
1308
+ " gr.Markdown(\"\"\"\n",
1309
+ " ### Color Legend\n",
1310
+ " - White: Very Low (0.0-0.2)\n",
1311
+ " - Light Orange: Low (0.2-0.4)\n",
1312
+ " - Orange: Medium (0.4-0.6)\n",
1313
+ " - Orange Red: High (0.6-0.8)\n",
1314
+ " - Red: Very High (0.8-1.0)\n",
1315
+ " \"\"\")\n",
1316
  "\n",
1317
+ " with gr.Tab(\"3D Visualization\"):\n",
1318
+ " molecule_output = gr.HTML(label=\"Interactive 3D Structure\")\n",
1319
+ " \n",
1320
+ " with gr.Tab(\"Analysis Results\"):\n",
1321
+ " predictions_output = gr.Textbox(\n",
1322
+ " label=\"PyMOL Visualization Commands\",\n",
1323
+ " info=\"Copy these commands into PyMOL to recreate the visualization\"\n",
1324
+ " )\n",
1325
+ " download_output = gr.File(\n",
1326
+ " label=\"Download Results\",\n",
1327
+ " file_count=\"multiple\"\n",
1328
+ " )\n",
1329
  "\n",
1330
+ " # Error message container\n",
1331
+ " error_output = gr.Markdown(visible=False)\n",
1332
+ "\n",
1333
+ " # Mode change handler\n",
1334
+ " def toggle_mode(selected_mode):\n",
1335
+ " return {\n",
1336
+ " pdb_input: gr.update(visible=selected_mode == \"PDB ID\"),\n",
1337
+ " pdb_file: gr.update(visible=selected_mode == \"Upload File\")\n",
1338
+ " }\n",
1339
+ "\n",
1340
+ " mode.change(\n",
1341
+ " toggle_mode,\n",
1342
+ " inputs=[mode],\n",
1343
+ " outputs=[pdb_input, pdb_file]\n",
1344
+ " )\n",
1345
+ "\n",
1346
+ " # Process handlers\n",
1347
+ " def handle_visualization(mode, pdb_id, pdb_file):\n",
1348
+ " try:\n",
1349
+ " result = fetch_interface(mode, pdb_id, pdb_file)\n",
1350
+ " if isinstance(result, str) and result.startswith(\"Error\"):\n",
1351
+ " return None, gr.update(visible=True, value=f\"```\\n{result}\\n```\")\n",
1352
+ " return result, gr.update(visible=False)\n",
1353
+ " except Exception as e:\n",
1354
+ " return None, gr.update(visible=True, value=f\"```\\nError: {str(e)}\\n```\")\n",
1355
+ "\n",
1356
+ " def handle_prediction(mode, pdb_id, pdb_file, chain_id):\n",
1357
+ " try:\n",
1358
+ " predictions, vis, downloads = process_interface(mode, pdb_id, pdb_file, chain_id)\n",
1359
+ " if isinstance(predictions, str) and \"Error\" in predictions:\n",
1360
+ " return (\n",
1361
+ " None,\n",
1362
+ " None,\n",
1363
+ " None,\n",
1364
+ " gr.update(visible=True, value=f\"```\\n{predictions}\\n```\")\n",
1365
+ " )\n",
1366
+ " return (\n",
1367
+ " predictions,\n",
1368
+ " vis,\n",
1369
+ " downloads,\n",
1370
+ " gr.update(visible=False)\n",
1371
+ " )\n",
1372
+ " except Exception as e:\n",
1373
+ " error_msg = f\"\"\"Error processing structure:\n",
1374
+ "```\n",
1375
+ "{str(e)}\n",
1376
+ "\n",
1377
+ "Troubleshooting tips:\n",
1378
+ "1. Check if the PDB ID is valid\n",
1379
+ "2. Ensure the Chain ID exists in the structure\n",
1380
+ "3. Try leaving Chain ID empty for automatic selection\n",
1381
+ "4. If uploading a file, ensure it's a valid PDB/CIF format\n",
1382
+ "```\"\"\"\n",
1383
+ " return None, None, None, gr.update(visible=True, value=error_msg)\n",
1384
+ "\n",
1385
+ " def fetch_interface(mode, pdb_id, pdb_file):\n",
1386
+ " try:\n",
1387
+ " if mode == \"PDB ID\":\n",
1388
+ " if not pdb_id or len(pdb_id.strip()) != 4:\n",
1389
+ " raise ValueError(\"Please enter a valid 4-character PDB ID\")\n",
1390
+ " return fetch_pdb(pdb_id.strip())\n",
1391
+ " elif mode == \"Upload File\":\n",
1392
+ " if not pdb_file:\n",
1393
+ " raise ValueError(\"Please upload a PDB or CIF file\")\n",
1394
+ " _, ext = os.path.splitext(pdb_file.name)\n",
1395
+ " if ext.lower() not in ['.pdb', '.cif']:\n",
1396
+ " raise ValueError(\"Only .pdb and .cif files are supported\")\n",
1397
+ " \n",
1398
+ " # Create temp directory for file handling\n",
1399
+ " with tempfile.TemporaryDirectory() as temp_dir:\n",
1400
+ " temp_path = os.path.join(temp_dir, os.path.basename(pdb_file.name))\n",
1401
+ " with open(temp_path, 'wb') as f:\n",
1402
+ " f.write(pdb_file.read())\n",
1403
+ " \n",
1404
+ " if ext.lower() == '.cif':\n",
1405
+ " return convert_cif_to_pdb(temp_path)\n",
1406
+ " return temp_path\n",
1407
+ " else:\n",
1408
+ " raise ValueError(\"Invalid mode selected\")\n",
1409
+ " except Exception as e:\n",
1410
+ " return f\"Error: {str(e)}\"\n",
1411
+ "\n",
1412
+ " # Connect event handlers\n",
1413
+ " visualize_btn.click(\n",
1414
+ " handle_visualization,\n",
1415
+ " inputs=[mode, pdb_input, pdb_file],\n",
1416
+ " outputs=[molecule_output, error_output]\n",
1417
+ " )\n",
1418
  "\n",
1419
+ " prediction_btn.click(\n",
1420
+ " handle_prediction,\n",
1421
+ " inputs=[mode, pdb_input, pdb_file, segment_input],\n",
1422
+ " outputs=[predictions_output, molecule_output, download_output, error_output]\n",
1423
+ " )\n",
1424
+ "\n",
1425
+ " # Add examples\n",
1426
+ " gr.Examples(\n",
1427
+ " examples=[\n",
1428
+ " [\"PDB ID\", \"7RPZ\", None, \"A\"],\n",
1429
+ " [\"PDB ID\", \"2IWI\", None, \"B\"],\n",
1430
+ " [\"PDB ID\", \"2F6V\", None, \"A\"]\n",
1431
+ " ],\n",
1432
+ " inputs=[mode, pdb_input, pdb_file, segment_input],\n",
1433
+ " outputs=[predictions_output, molecule_output, download_output, error_output],\n",
1434
+ " fn=handle_prediction,\n",
1435
+ " cache_examples=True\n",
1436
+ " )\n",
1437
+ "\n",
1438
+ " # Add documentation\n",
1439
+ " gr.Markdown(\"\"\"\n",
1440
+ " ## Usage Instructions\n",
1441
+ " \n",
1442
+ " 1. **Input Structure:**\n",
1443
+ " - Enter a PDB ID (e.g., \"4BDU\") or upload your own structure file\n",
1444
+ " - The tool supports both PDB (.pdb) and mmCIF (.cif) formats\n",
1445
+ " \n",
1446
+ " 2. **Select Chain:**\n",
1447
+ " - Enter a specific chain ID (e.g., \"A\")\n",
1448
+ " - Leave empty for automatic selection of the first valid protein chain\n",
1449
+ " \n",
1450
+ " 3. **Analyze:**\n",
1451
+ " - Click \"Visualize Structure\" to view the 3D structure\n",
1452
+ " - Click \"Predict Binding Site\" to perform binding site analysis\n",
1453
+ " \n",
1454
+ " 4. **Results:**\n",
1455
+ " - Interactive 3D visualization with color-coded predictions\n",
1456
+ " - PyMOL commands for external visualization\n",
1457
+ " - Downloadable results files\n",
1458
+ " \n",
1459
+ " ## Troubleshooting\n",
1460
+ " \n",
1461
+ " If you encounter issues:\n",
1462
+ " 1. Ensure your PDB ID is valid and exists in the PDB database\n",
1463
+ " 2. Check that your uploaded file is a valid PDB/CIF format\n",
1464
+ " 3. Try automatic chain selection if your specified chain isn't found\n",
1465
+ " 4. Clear your browser cache if visualizations don't load\n",
1466
+ " \"\"\")\n",
1467
+ "\n",
1468
+ " return demo\n",
1469
+ "\n",
1470
+ "if __name__ == \"__main__\":\n",
1471
+ " demo = create_ui()\n",
1472
+ " demo.launch(share=True)"
1473
  ]
1474
  },
1475
  {
1476
  "cell_type": "code",
1477
  "execution_count": null,
1478
+ "id": "9125d1c4-e2ae-4e40-ba36-7ae944512b8e",
1479
+ "metadata": {},
1480
+ "outputs": [],
1481
+ "source": []
1482
+ },
1483
+ {
1484
+ "cell_type": "code",
1485
+ "execution_count": null,
1486
+ "id": "85c0728a-a15b-4118-b920-5f55a2f5f79a",
1487
  "metadata": {},
1488
  "outputs": [],
1489
  "source": []
 
1505
  "name": "python",
1506
  "nbconvert_exporter": "python",
1507
  "pygments_lexer": "ipython3",
1508
+ "version": "3.12.2"
1509
  }
1510
  },
1511
  "nbformat": 4,
app.py CHANGED
@@ -139,30 +139,6 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
139
 
140
  return output_pdb
141
 
142
- def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):
143
- """
144
- Calculate the geometric center of high-scoring residues
145
- """
146
- parser = PDBParser(QUIET=True)
147
- structure = parser.get_structure('protein', pdb_path)
148
-
149
- # Collect coordinates of CA atoms from high-scoring residues
150
- coords = []
151
- for model in structure:
152
- for chain in model:
153
- if chain.id == chain_id:
154
- for residue in chain:
155
- if residue.id[1] in high_score_residues:
156
- if 'CA' in residue: # Use alpha carbon as representative
157
- ca_atom = residue['CA']
158
- coords.append(ca_atom.coord)
159
-
160
- # Calculate geometric center
161
- if coords:
162
- center = np.mean(coords, axis=0)
163
- return center
164
- return None
165
-
166
  def process_pdb(pdb_id_or_file, segment):
167
  # Determine if input is a PDB ID or file path
168
  if pdb_id_or_file.endswith('.pdb'):
@@ -194,7 +170,11 @@ def process_pdb(pdb_id_or_file, segment):
194
  protein_residues = [res for res in chain if is_aa(res)]
195
  sequence = "".join(seq1(res.resname) for res in protein_residues)
196
  sequence_id = [res.id[1] for res in protein_residues]
197
-
 
 
 
 
198
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
199
  with torch.no_grad():
200
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
@@ -300,7 +280,6 @@ def molecule(input_pdb, residue_scores=None, segment='A'):
300
  class4_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.8]
301
  class5_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 1.0]
302
 
303
-
304
  high_score_script = """
305
  // Load the original model and apply white cartoon style
306
  let chainModel = viewer.addModel(pdb, "pdb");
@@ -430,7 +409,19 @@ def molecule(input_pdb, residue_scores=None, segment='A'):
430
  return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
431
 
432
  # Gradio UI
433
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
434
  gr.Markdown("# Protein Binding Site Prediction")
435
 
436
  # Mode selection
@@ -442,9 +433,9 @@ with gr.Blocks() as demo:
442
  )
443
 
444
  # Input components based on mode
445
- pdb_input = gr.Textbox(value="4BDU", label="PDB ID", placeholder="Enter PDB ID here...")
446
  pdb_file = gr.File(label="Upload PDB/CIF File", visible=False)
447
- visualize_btn = gr.Button("Visualize Structure")
448
 
449
  molecule_output2 = Molecule3D(label="Protein Structure", reps=[
450
  {
@@ -458,8 +449,9 @@ with gr.Blocks() as demo:
458
  ])
459
 
460
  with gr.Row():
461
- segment_input = gr.Textbox(value="A", label="Chain ID", placeholder="Enter Chain ID here...")
462
- prediction_btn = gr.Button("Predict Binding Site")
 
463
 
464
  molecule_output = gr.HTML(label="Protein Structure")
465
  explanation_vis = gr.Markdown("""
@@ -533,7 +525,7 @@ with gr.Blocks() as demo:
533
  examples=[
534
  ["7RPZ", "A"],
535
  ["2IWI", "B"],
536
- ["2F6V", "A"]
537
  ],
538
  inputs=[pdb_input, segment_input],
539
  outputs=[predictions_output, molecule_output, download_output]
 
139
 
140
  return output_pdb
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  def process_pdb(pdb_id_or_file, segment):
143
  # Determine if input is a PDB ID or file path
144
  if pdb_id_or_file.endswith('.pdb'):
 
170
  protein_residues = [res for res in chain if is_aa(res)]
171
  sequence = "".join(seq1(res.resname) for res in protein_residues)
172
  sequence_id = [res.id[1] for res in protein_residues]
173
+
174
+ visualized_sequence = "".join(seq1(res.resname) for res in protein_residues)
175
+ if sequence != visualized_sequence:
176
+ raise ValueError("The visualized sequence does not match the prediction sequence")
177
+
178
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
179
  with torch.no_grad():
180
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
 
280
  class4_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.8]
281
  class5_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 1.0]
282
 
 
283
  high_score_script = """
284
  // Load the original model and apply white cartoon style
285
  let chainModel = viewer.addModel(pdb, "pdb");
 
409
  return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
410
 
411
  # Gradio UI
412
+ with gr.Blocks(css="""
413
+ /* Customize Gradio button colors */
414
+ #visualize-btn, #predict-btn {
415
+ background-color: #FF7300; /* Deep orange */
416
+ color: white;
417
+ border-radius: 5px;
418
+ padding: 10px;
419
+ font-weight: bold;
420
+ }
421
+ #visualize-btn:hover, #predict-btn:hover {
422
+ background-color: #CC5C00; /* Darkened orange on hover */
423
+ }
424
+ """) as demo:
425
  gr.Markdown("# Protein Binding Site Prediction")
426
 
427
  # Mode selection
 
433
  )
434
 
435
  # Input components based on mode
436
+ pdb_input = gr.Textbox(value="2F6V", label="PDB ID", placeholder="Enter PDB ID here...")
437
  pdb_file = gr.File(label="Upload PDB/CIF File", visible=False)
438
+ visualize_btn = gr.Button("Visualize Structure", elem_id="visualize-btn")
439
 
440
  molecule_output2 = Molecule3D(label="Protein Structure", reps=[
441
  {
 
449
  ])
450
 
451
  with gr.Row():
452
+ segment_input = gr.Textbox(value="A", label="Chain ID (protein)", placeholder="Enter Chain ID here...",
453
+ info="Choose in which chain to predict binding sites.")
454
+ prediction_btn = gr.Button("Predict Binding Site", elem_id="predict-btn")
455
 
456
  molecule_output = gr.HTML(label="Protein Structure")
457
  explanation_vis = gr.Markdown("""
 
525
  examples=[
526
  ["7RPZ", "A"],
527
  ["2IWI", "B"],
528
+ ["7LCJ", "R"]
529
  ],
530
  inputs=[pdb_input, segment_input],
531
  outputs=[predictions_output, molecule_output, download_output]
app.py.backup ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import gradio as gr
3
+ import requests
4
+ from Bio.PDB import PDBParser, MMCIFParser, PDBIO, Select
5
+ from Bio.PDB.Polypeptide import is_aa
6
+ from Bio.SeqUtils import seq1
7
+ from typing import Optional, Tuple
8
+ import numpy as np
9
+ import os
10
+ from gradio_molecule3d import Molecule3D
11
+
12
+ from model_loader import load_model
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.utils.data import DataLoader
18
+
19
+ import re
20
+ import pandas as pd
21
+ import copy
22
+
23
+ import transformers
24
+ from transformers import AutoTokenizer, DataCollatorForTokenClassification
25
+
26
+ from datasets import Dataset
27
+
28
+ from scipy.special import expit
29
+
30
+
31
+ # Load model and move to device
32
+ checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
33
+ max_length = 1500
34
+ model, tokenizer = load_model(checkpoint, max_length)
35
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
36
+ model.to(device)
37
+ model.eval()
38
+
39
+ def normalize_scores(scores):
40
+ min_score = np.min(scores)
41
+ max_score = np.max(scores)
42
+ return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
43
+
44
+ def read_mol(pdb_path):
45
+ """Read PDB file and return its content as a string"""
46
+ with open(pdb_path, 'r') as f:
47
+ return f.read()
48
+
49
+ def fetch_structure(pdb_id: str, output_dir: str = ".") -> Optional[str]:
50
+ """
51
+ Fetch the structure file for a given PDB ID. Prioritizes CIF files.
52
+ If a structure file already exists locally, it uses that.
53
+ """
54
+ file_path = download_structure(pdb_id, output_dir)
55
+ if file_path:
56
+ return file_path
57
+ else:
58
+ return None
59
+
60
+ def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:
61
+ """
62
+ Attempt to download the structure file in CIF or PDB format.
63
+ Returns the path to the downloaded file, or None if download fails.
64
+ """
65
+ for ext in ['.cif', '.pdb']:
66
+ file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
67
+ if os.path.exists(file_path):
68
+ return file_path
69
+ url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
70
+ try:
71
+ response = requests.get(url, timeout=10)
72
+ if response.status_code == 200:
73
+ with open(file_path, 'wb') as f:
74
+ f.write(response.content)
75
+ return file_path
76
+ except Exception as e:
77
+ print(f"Download error for {pdb_id}{ext}: {e}")
78
+ return None
79
+
80
+ def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
81
+ """
82
+ Convert a CIF file to PDB format using BioPython and return the PDB file path.
83
+ """
84
+ pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))
85
+ parser = MMCIFParser(QUIET=True)
86
+ structure = parser.get_structure('protein', cif_path)
87
+ io = PDBIO()
88
+ io.set_structure(structure)
89
+ io.save(pdb_path)
90
+ return pdb_path
91
+
92
+ def fetch_pdb(pdb_id):
93
+ pdb_path = fetch_structure(pdb_id)
94
+ if not pdb_path:
95
+ return None
96
+ _, ext = os.path.splitext(pdb_path)
97
+ if ext == '.cif':
98
+ pdb_path = convert_cif_to_pdb(pdb_path)
99
+ return pdb_path
100
+
101
+ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:
102
+ """
103
+ Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores
104
+ """
105
+ # Read the original PDB file
106
+ parser = PDBParser(QUIET=True)
107
+ structure = parser.get_structure('protein', input_pdb)
108
+
109
+ # Prepare a new structure with only the specified chain and selected residues
110
+ output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"
111
+
112
+ # Create scores dictionary for easy lookup
113
+ scores_dict = {resi: score for resi, score in residue_scores}
114
+
115
+ # Create a custom Select class
116
+ class ResidueSelector(Select):
117
+ def __init__(self, chain_id, selected_residues, scores_dict):
118
+ self.chain_id = chain_id
119
+ self.selected_residues = selected_residues
120
+ self.scores_dict = scores_dict
121
+
122
+ def accept_chain(self, chain):
123
+ return chain.id == self.chain_id
124
+
125
+ def accept_residue(self, residue):
126
+ return residue.id[1] in self.selected_residues
127
+
128
+ def accept_atom(self, atom):
129
+ if atom.parent.id[1] in self.scores_dict:
130
+ atom.bfactor = np.absolute(1-self.scores_dict[atom.parent.id[1]]) * 100
131
+ return True
132
+
133
+ # Prepare output PDB with selected chain and residues, modified B-factors
134
+ io = PDBIO()
135
+ selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)
136
+
137
+ io.set_structure(structure[0])
138
+ io.save(output_pdb, selector)
139
+
140
+ return output_pdb
141
+
142
+ def calculate_geometric_center(pdb_path: str, high_score_residues: list, chain_id: str):
143
+ """
144
+ Calculate the geometric center of high-scoring residues
145
+ """
146
+ parser = PDBParser(QUIET=True)
147
+ structure = parser.get_structure('protein', pdb_path)
148
+
149
+ # Collect coordinates of CA atoms from high-scoring residues
150
+ coords = []
151
+ for model in structure:
152
+ for chain in model:
153
+ if chain.id == chain_id:
154
+ for residue in chain:
155
+ if residue.id[1] in high_score_residues:
156
+ if 'CA' in residue: # Use alpha carbon as representative
157
+ ca_atom = residue['CA']
158
+ coords.append(ca_atom.coord)
159
+
160
+ # Calculate geometric center
161
+ if coords:
162
+ center = np.mean(coords, axis=0)
163
+ return center
164
+ return None
165
+
166
+ def process_pdb(pdb_id_or_file, segment):
167
+ # Determine if input is a PDB ID or file path
168
+ if pdb_id_or_file.endswith('.pdb'):
169
+ pdb_path = pdb_id_or_file
170
+ pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]
171
+ else:
172
+ pdb_id = pdb_id_or_file
173
+ pdb_path = fetch_pdb(pdb_id)
174
+
175
+ if not pdb_path:
176
+ return "Failed to fetch PDB file", None, None
177
+
178
+ # Determine the file format and choose the appropriate parser
179
+ _, ext = os.path.splitext(pdb_path)
180
+ parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
181
+
182
+ try:
183
+ # Parse the structure file
184
+ structure = parser.get_structure('protein', pdb_path)
185
+ except Exception as e:
186
+ return f"Error parsing structure file: {e}", None, None
187
+
188
+ # Extract the specified chain
189
+ try:
190
+ chain = structure[0][segment]
191
+ except KeyError:
192
+ return "Invalid Chain ID", None, None
193
+
194
+ protein_residues = [res for res in chain if is_aa(res)]
195
+ sequence = "".join(seq1(res.resname) for res in protein_residues)
196
+ sequence_id = [res.id[1] for res in protein_residues]
197
+
198
+ input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
199
+ with torch.no_grad():
200
+ outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
201
+
202
+ # Calculate scores and normalize them
203
+ scores = expit(outputs[:, 1] - outputs[:, 0])
204
+
205
+ normalized_scores = normalize_scores(scores)
206
+
207
+ # Zip residues with scores to track the residue ID and score
208
+ residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]
209
+
210
+
211
+ # Define the score brackets
212
+ score_brackets = {
213
+ "0.0-0.2": (0.0, 0.2),
214
+ "0.2-0.4": (0.2, 0.4),
215
+ "0.4-0.6": (0.4, 0.6),
216
+ "0.6-0.8": (0.6, 0.8),
217
+ "0.8-1.0": (0.8, 1.0)
218
+ }
219
+
220
+ # Initialize a dictionary to store residues by bracket
221
+ residues_by_bracket = {bracket: [] for bracket in score_brackets}
222
+
223
+ # Categorize residues into brackets
224
+ for resi, score in residue_scores:
225
+ for bracket, (lower, upper) in score_brackets.items():
226
+ if lower <= score < upper:
227
+ residues_by_bracket[bracket].append(resi)
228
+ break
229
+
230
+ # Preparing the result string
231
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
232
+ result_str = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\n\n"
233
+ result_str += "Residues by Score Brackets:\n\n"
234
+
235
+ # Add residues for each bracket
236
+ for bracket, residues in residues_by_bracket.items():
237
+ result_str += f"Bracket {bracket}:\n"
238
+ result_str += "Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\n"
239
+ result_str += "\n".join([
240
+ f"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}"
241
+ for i, res in enumerate(protein_residues) if res.id[1] in residues
242
+ ])
243
+ result_str += "\n\n"
244
+
245
+ # Create chain-specific PDB with scores in B-factor
246
+ scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)
247
+
248
+ # Molecule visualization with updated script with color mapping
249
+ mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)
250
+
251
+ # Improved PyMOL command suggestions
252
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
253
+ pymol_commands = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\n\n"
254
+
255
+ pymol_commands += f"""
256
+ # PyMOL Visualization Commands
257
+ load {os.path.abspath(pdb_path)}, protein
258
+ hide everything, all
259
+ show cartoon, chain {segment}
260
+ color white, chain {segment}
261
+ """
262
+
263
+ # Define colors for each score bracket
264
+ bracket_colors = {
265
+ "0.0-0.2": "white",
266
+ "0.2-0.4": "lightorange",
267
+ "0.4-0.6": "orange",
268
+ "0.6-0.8": "orangered",
269
+ "0.8-1.0": "red"
270
+ }
271
+
272
+ # Add PyMOL commands for each score bracket
273
+ for bracket, residues in residues_by_bracket.items():
274
+ if residues: # Only add commands if there are residues in this bracket
275
+ color = bracket_colors[bracket]
276
+ resi_list = '+'.join(map(str, residues))
277
+ pymol_commands += f"""
278
+ select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}
279
+ show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}
280
+ color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}
281
+ """
282
+ # Create prediction and scored PDB files
283
+ prediction_file = f"{pdb_id}_binding_site_residues.txt"
284
+ with open(prediction_file, "w") as f:
285
+ f.write(result_str)
286
+
287
+ return pymol_commands, mol_vis, [prediction_file,scored_pdb]
288
+
289
+ def molecule(input_pdb, residue_scores=None, segment='A'):
290
+ # More granular scoring for visualization
291
+ mol = read_mol(input_pdb) # Read PDB file content
292
+
293
+ # Prepare high-scoring residues script if scores are provided
294
+ high_score_script = ""
295
+ if residue_scores is not None:
296
+ # Filter residues based on their scores
297
+ class1_score_residues = [resi for resi, score in residue_scores if 0.0 < score <= 0.2]
298
+ class2_score_residues = [resi for resi, score in residue_scores if 0.2 < score <= 0.4]
299
+ class3_score_residues = [resi for resi, score in residue_scores if 0.4 < score <= 0.6]
300
+ class4_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.8]
301
+ class5_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 1.0]
302
+
303
+
304
+ high_score_script = """
305
+ // Load the original model and apply white cartoon style
306
+ let chainModel = viewer.addModel(pdb, "pdb");
307
+ chainModel.setStyle({}, {});
308
+ chainModel.setStyle(
309
+ {"chain": "%s"},
310
+ {"cartoon": {"color": "white"}}
311
+ );
312
+
313
+ // Create a new model for high-scoring residues and apply red sticks style
314
+ let class1Model = viewer.addModel(pdb, "pdb");
315
+ class1Model.setStyle({}, {});
316
+ class1Model.setStyle(
317
+ {"chain": "%s", "resi": [%s]},
318
+ {"stick": {"color": "0xFFFFFF", "opacity": 0.5}}
319
+ );
320
+
321
+ // Create a new model for high-scoring residues and apply red sticks style
322
+ let class2Model = viewer.addModel(pdb, "pdb");
323
+ class2Model.setStyle({}, {});
324
+ class2Model.setStyle(
325
+ {"chain": "%s", "resi": [%s]},
326
+ {"stick": {"color": "0xFFD580", "opacity": 0.7}}
327
+ );
328
+
329
+ // Create a new model for high-scoring residues and apply red sticks style
330
+ let class3Model = viewer.addModel(pdb, "pdb");
331
+ class3Model.setStyle({}, {});
332
+ class3Model.setStyle(
333
+ {"chain": "%s", "resi": [%s]},
334
+ {"stick": {"color": "0xFFA500", "opacity": 1}}
335
+ );
336
+
337
+ // Create a new model for high-scoring residues and apply red sticks style
338
+ let class4Model = viewer.addModel(pdb, "pdb");
339
+ class4Model.setStyle({}, {});
340
+ class4Model.setStyle(
341
+ {"chain": "%s", "resi": [%s]},
342
+ {"stick": {"color": "0xFF4500", "opacity": 1}}
343
+ );
344
+
345
+ // Create a new model for high-scoring residues and apply red sticks style
346
+ let class5Model = viewer.addModel(pdb, "pdb");
347
+ class5Model.setStyle({}, {});
348
+ class5Model.setStyle(
349
+ {"chain": "%s", "resi": [%s]},
350
+ {"stick": {"color": "0xFF0000", "alpha": 1}}
351
+ );
352
+
353
+ """ % (
354
+ segment,
355
+ segment,
356
+ ", ".join(str(resi) for resi in class1_score_residues),
357
+ segment,
358
+ ", ".join(str(resi) for resi in class2_score_residues),
359
+ segment,
360
+ ", ".join(str(resi) for resi in class3_score_residues),
361
+ segment,
362
+ ", ".join(str(resi) for resi in class4_score_residues),
363
+ segment,
364
+ ", ".join(str(resi) for resi in class5_score_residues)
365
+ )
366
+
367
+ # Generate the full HTML content
368
+ html_content = f"""
369
+ <!DOCTYPE html>
370
+ <html>
371
+ <head>
372
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
373
+ <style>
374
+ .mol-container {{
375
+ width: 100%;
376
+ height: 700px;
377
+ position: relative;
378
+ }}
379
+ </style>
380
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js"></script>
381
+ <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
382
+ </head>
383
+ <body>
384
+ <div id="container" class="mol-container"></div>
385
+ <script>
386
+ let pdb = `{mol}`; // Use template literal to properly escape PDB content
387
+ $(document).ready(function () {{
388
+ let element = $("#container");
389
+ let config = {{ backgroundColor: "white" }};
390
+ let viewer = $3Dmol.createViewer(element, config);
391
+
392
+ {high_score_script}
393
+
394
+ // Add hover functionality
395
+ viewer.setHoverable(
396
+ {{}},
397
+ true,
398
+ function(atom, viewer, event, container) {{
399
+ if (!atom.label) {{
400
+ atom.label = viewer.addLabel(
401
+ atom.resn + ":" +atom.resi + ":" + atom.atom,
402
+ {{
403
+ position: atom,
404
+ backgroundColor: 'mintcream',
405
+ fontColor: 'black',
406
+ fontSize: 18,
407
+ padding: 4
408
+ }}
409
+ );
410
+ }}
411
+ }},
412
+ function(atom, viewer) {{
413
+ if (atom.label) {{
414
+ viewer.removeLabel(atom.label);
415
+ delete atom.label;
416
+ }}
417
+ }}
418
+ );
419
+
420
+ viewer.zoomTo();
421
+ viewer.render();
422
+ viewer.zoom(0.8, 2000);
423
+ }});
424
+ </script>
425
+ </body>
426
+ </html>
427
+ """
428
+
429
+ # Return the HTML content within an iframe safely encoded for special characters
430
+ return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
431
+
432
+ # Gradio UI
433
+ with gr.Blocks() as demo:
434
+ gr.Markdown("# Protein Binding Site Prediction")
435
+
436
+ # Mode selection
437
+ mode = gr.Radio(
438
+ choices=["PDB ID", "Upload File"],
439
+ value="PDB ID",
440
+ label="Input Mode",
441
+ info="Choose whether to input a PDB ID or upload a PDB/CIF file."
442
+ )
443
+
444
+ # Input components based on mode
445
+ pdb_input = gr.Textbox(value="4BDU", label="PDB ID", placeholder="Enter PDB ID here...")
446
+ pdb_file = gr.File(label="Upload PDB/CIF File", visible=False)
447
+ visualize_btn = gr.Button("Visualize Structure")
448
+
449
+ molecule_output2 = Molecule3D(label="Protein Structure", reps=[
450
+ {
451
+ "model": 0,
452
+ "style": "cartoon",
453
+ "color": "whiteCarbon",
454
+ "residue_range": "",
455
+ "around": 0,
456
+ "byres": False,
457
+ }
458
+ ])
459
+
460
+ with gr.Row():
461
+ segment_input = gr.Textbox(value="A", label="Chain ID", placeholder="Enter Chain ID here...")
462
+ prediction_btn = gr.Button("Predict Binding Site")
463
+
464
+ molecule_output = gr.HTML(label="Protein Structure")
465
+ explanation_vis = gr.Markdown("""
466
+ Score dependent colorcoding:
467
+ - 0.0-0.2: white
468
+ - 0.2–0.4: light orange
469
+ - 0.4–0.6: orange
470
+ - 0.6–0.8: orangered
471
+ - 0.8–1.0: red
472
+ """)
473
+ predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
474
+ gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
475
+ download_output = gr.File(label="Download Files", file_count="multiple")
476
+
477
+ def process_interface(mode, pdb_id, pdb_file, chain_id):
478
+ if mode == "PDB ID":
479
+ return process_pdb(pdb_id, chain_id)
480
+ elif mode == "Upload File":
481
+ _, ext = os.path.splitext(pdb_file.name)
482
+ file_path = os.path.join('./', f"{_}{ext}")
483
+ if ext == '.cif':
484
+ pdb_path = convert_cif_to_pdb(file_path)
485
+ else:
486
+ pdb_path= file_path
487
+ return process_pdb(pdb_path, chain_id)
488
+ else:
489
+ return "Error: Invalid mode selected", None, None
490
+
491
+ def fetch_interface(mode, pdb_id, pdb_file):
492
+ if mode == "PDB ID":
493
+ return fetch_pdb(pdb_id)
494
+ elif mode == "Upload File":
495
+ _, ext = os.path.splitext(pdb_file.name)
496
+ file_path = os.path.join('./', f"{_}{ext}")
497
+ #print(ext)
498
+ if ext == '.cif':
499
+ pdb_path = convert_cif_to_pdb(file_path)
500
+ else:
501
+ pdb_path= file_path
502
+ #print(pdb_path)
503
+ return pdb_path
504
+ else:
505
+ return "Error: Invalid mode selected"
506
+
507
+ def toggle_mode(selected_mode):
508
+ if selected_mode == "PDB ID":
509
+ return gr.update(visible=True), gr.update(visible=False)
510
+ else:
511
+ return gr.update(visible=False), gr.update(visible=True)
512
+
513
+ mode.change(
514
+ toggle_mode,
515
+ inputs=[mode],
516
+ outputs=[pdb_input, pdb_file]
517
+ )
518
+
519
+ prediction_btn.click(
520
+ process_interface,
521
+ inputs=[mode, pdb_input, pdb_file, segment_input],
522
+ outputs=[predictions_output, molecule_output, download_output]
523
+ )
524
+
525
+ visualize_btn.click(
526
+ fetch_interface,
527
+ inputs=[mode, pdb_input, pdb_file],
528
+ outputs=molecule_output2
529
+ )
530
+
531
+ gr.Markdown("## Examples")
532
+ gr.Examples(
533
+ examples=[
534
+ ["7RPZ", "A"],
535
+ ["2IWI", "B"],
536
+ ["2F6V", "A"]
537
+ ],
538
+ inputs=[pdb_input, segment_input],
539
+ outputs=[predictions_output, molecule_output, download_output]
540
+ )
541
+
542
+ demo.launch(share=True)
test.ipynb ADDED
@@ -0,0 +1,1513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 29,
6
+ "id": "e776d9d6-417e-46d4-8061-846c055e1f8a",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "* Running on local URL: http://127.0.0.1:7873\n",
14
+ "* Running on public URL: https://120000a6aa9d78e04c.gradio.live\n",
15
+ "\n",
16
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
17
+ ]
18
+ },
19
+ {
20
+ "data": {
21
+ "text/html": [
22
+ "<div><iframe src=\"https://120000a6aa9d78e04c.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
23
+ ],
24
+ "text/plain": [
25
+ "<IPython.core.display.HTML object>"
26
+ ]
27
+ },
28
+ "metadata": {},
29
+ "output_type": "display_data"
30
+ },
31
+ {
32
+ "data": {
33
+ "text/plain": []
34
+ },
35
+ "execution_count": 29,
36
+ "metadata": {},
37
+ "output_type": "execute_result"
38
+ }
39
+ ],
40
+ "source": [
41
+ "from datetime import datetime\n",
42
+ "import gradio as gr\n",
43
+ "import requests\n",
44
+ "from Bio.PDB import PDBParser, MMCIFParser, PDBIO, Select\n",
45
+ "from Bio.PDB.Polypeptide import is_aa\n",
46
+ "from Bio.SeqUtils import seq1\n",
47
+ "from typing import Optional, Tuple\n",
48
+ "import numpy as np\n",
49
+ "import os\n",
50
+ "from gradio_molecule3d import Molecule3D\n",
51
+ "\n",
52
+ "#from model_loader import load_model\n",
53
+ "\n",
54
+ "import torch\n",
55
+ "import torch.nn as nn\n",
56
+ "import torch.nn.functional as F\n",
57
+ "from torch.utils.data import DataLoader\n",
58
+ "\n",
59
+ "import re\n",
60
+ "import pandas as pd\n",
61
+ "import copy\n",
62
+ "\n",
63
+ "#import transformers\n",
64
+ "#from transformers import AutoTokenizer, DataCollatorForTokenClassification\n",
65
+ "\n",
66
+ "#from datasets import Dataset\n",
67
+ "\n",
68
+ "from scipy.special import expit\n",
69
+ "\n",
70
+ "def normalize_scores(scores):\n",
71
+ " min_score = np.min(scores)\n",
72
+ " max_score = np.max(scores)\n",
73
+ " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
74
+ "\n",
75
+ "def read_mol(pdb_path):\n",
76
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
77
+ " with open(pdb_path, 'r') as f:\n",
78
+ " return f.read()\n",
79
+ "\n",
80
+ "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n",
81
+ " \"\"\"\n",
82
+ " Fetch the structure file for a given PDB ID. Prioritizes CIF files.\n",
83
+ " If a structure file already exists locally, it uses that.\n",
84
+ " \"\"\"\n",
85
+ " file_path = download_structure(pdb_id, output_dir)\n",
86
+ " if file_path:\n",
87
+ " return file_path\n",
88
+ " else:\n",
89
+ " return None\n",
90
+ "\n",
91
+ "def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:\n",
92
+ " \"\"\"\n",
93
+ " Attempt to download the structure file in CIF or PDB format.\n",
94
+ " Returns the path to the downloaded file, or None if download fails.\n",
95
+ " \"\"\"\n",
96
+ " for ext in ['.cif', '.pdb']:\n",
97
+ " file_path = os.path.join(output_dir, f\"{pdb_id}{ext}\")\n",
98
+ " if os.path.exists(file_path):\n",
99
+ " return file_path\n",
100
+ " url = f\"https://files.rcsb.org/download/{pdb_id}{ext}\"\n",
101
+ " try:\n",
102
+ " response = requests.get(url, timeout=10)\n",
103
+ " if response.status_code == 200:\n",
104
+ " with open(file_path, 'wb') as f:\n",
105
+ " f.write(response.content)\n",
106
+ " return file_path\n",
107
+ " except Exception as e:\n",
108
+ " print(f\"Download error for {pdb_id}{ext}: {e}\")\n",
109
+ " return None\n",
110
+ "\n",
111
+ "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n",
112
+ " \"\"\"\n",
113
+ " Convert a CIF file to PDB format using BioPython and return the PDB file path.\n",
114
+ " \"\"\"\n",
115
+ " pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n",
116
+ " parser = MMCIFParser(QUIET=True)\n",
117
+ " structure = parser.get_structure('protein', cif_path)\n",
118
+ " io = PDBIO()\n",
119
+ " io.set_structure(structure)\n",
120
+ " io.save(pdb_path)\n",
121
+ " return pdb_path\n",
122
+ "\n",
123
+ "def fetch_pdb(pdb_id):\n",
124
+ " pdb_path = fetch_structure(pdb_id)\n",
125
+ " if not pdb_path:\n",
126
+ " return None\n",
127
+ " _, ext = os.path.splitext(pdb_path)\n",
128
+ " if ext == '.cif':\n",
129
+ " pdb_path = convert_cif_to_pdb(pdb_path)\n",
130
+ " return pdb_path\n",
131
+ "\n",
132
+ "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n",
133
+ " \"\"\"\n",
134
+ " Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores\n",
135
+ " \"\"\"\n",
136
+ " # Read the original PDB file\n",
137
+ " parser = PDBParser(QUIET=True)\n",
138
+ " structure = parser.get_structure('protein', input_pdb)\n",
139
+ " \n",
140
+ " # Prepare a new structure with only the specified chain and selected residues\n",
141
+ " output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb\"\n",
142
+ " \n",
143
+ " # Create scores dictionary for easy lookup\n",
144
+ " scores_dict = {resi: score for resi, score in residue_scores}\n",
145
+ "\n",
146
+ " # Create a custom Select class\n",
147
+ " class ResidueSelector(Select):\n",
148
+ " def __init__(self, chain_id, selected_residues, scores_dict):\n",
149
+ " self.chain_id = chain_id\n",
150
+ " self.selected_residues = selected_residues\n",
151
+ " self.scores_dict = scores_dict\n",
152
+ " \n",
153
+ " def accept_chain(self, chain):\n",
154
+ " return chain.id == self.chain_id\n",
155
+ " \n",
156
+ " def accept_residue(self, residue):\n",
157
+ " return residue.id[1] in self.selected_residues\n",
158
+ "\n",
159
+ " def accept_atom(self, atom):\n",
160
+ " if atom.parent.id[1] in self.scores_dict:\n",
161
+ " atom.bfactor = np.absolute(1-self.scores_dict[atom.parent.id[1]]) * 100\n",
162
+ " return True\n",
163
+ "\n",
164
+ " # Prepare output PDB with selected chain and residues, modified B-factors\n",
165
+ " io = PDBIO()\n",
166
+ " selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n",
167
+ " \n",
168
+ " io.set_structure(structure[0])\n",
169
+ " io.save(output_pdb, selector)\n",
170
+ " \n",
171
+ " return output_pdb\n",
172
+ "\n",
173
+ "def process_pdb(pdb_id_or_file, segment):\n",
174
+ " # Determine if input is a PDB ID or file path\n",
175
+ " if pdb_id_or_file.endswith('.pdb'):\n",
176
+ " pdb_path = pdb_id_or_file\n",
177
+ " pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n",
178
+ " else:\n",
179
+ " pdb_id = pdb_id_or_file\n",
180
+ " pdb_path = fetch_pdb(pdb_id)\n",
181
+ " \n",
182
+ " if not pdb_path:\n",
183
+ " return \"Failed to fetch PDB file\", None, None\n",
184
+ " \n",
185
+ " # Determine the file format and choose the appropriate parser\n",
186
+ " _, ext = os.path.splitext(pdb_path)\n",
187
+ " parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)\n",
188
+ " \n",
189
+ " try:\n",
190
+ " # Parse the structure file\n",
191
+ " structure = parser.get_structure('protein', pdb_path)\n",
192
+ " except Exception as e:\n",
193
+ " return f\"Error parsing structure file: {e}\", None, None\n",
194
+ " \n",
195
+ " # Extract the specified chain\n",
196
+ " try:\n",
197
+ " chain = structure[0][segment]\n",
198
+ " except KeyError:\n",
199
+ " return \"Invalid Chain ID\", None, None\n",
200
+ " \n",
201
+ " protein_residues = [res for res in chain if is_aa(res)]\n",
202
+ " sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n",
203
+ " sequence_id = [res.id[1] for res in protein_residues]\n",
204
+ "\n",
205
+ " visualized_sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n",
206
+ " if sequence != visualized_sequence:\n",
207
+ " raise ValueError(\"The visualized sequence does not match the prediction sequence\")\n",
208
+ " \n",
209
+ " scores = np.random.rand(len(sequence))\n",
210
+ " normalized_scores = normalize_scores(scores)\n",
211
+ " \n",
212
+ " # Zip residues with scores to track the residue ID and score\n",
213
+ " residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n",
214
+ "\n",
215
+ " \n",
216
+ " # Define the score brackets\n",
217
+ " score_brackets = {\n",
218
+ " \"0.0-0.2\": (0.0, 0.2),\n",
219
+ " \"0.2-0.4\": (0.2, 0.4),\n",
220
+ " \"0.4-0.6\": (0.4, 0.6),\n",
221
+ " \"0.6-0.8\": (0.6, 0.8),\n",
222
+ " \"0.8-1.0\": (0.8, 1.0)\n",
223
+ " }\n",
224
+ " \n",
225
+ " # Initialize a dictionary to store residues by bracket\n",
226
+ " residues_by_bracket = {bracket: [] for bracket in score_brackets}\n",
227
+ " \n",
228
+ " # Categorize residues into brackets\n",
229
+ " for resi, score in residue_scores:\n",
230
+ " for bracket, (lower, upper) in score_brackets.items():\n",
231
+ " if lower <= score < upper:\n",
232
+ " residues_by_bracket[bracket].append(resi)\n",
233
+ " break\n",
234
+ " \n",
235
+ " # Preparing the result string\n",
236
+ " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
237
+ " result_str = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n",
238
+ " result_str += \"Residues by Score Brackets:\\n\\n\"\n",
239
+ " \n",
240
+ " # Add residues for each bracket\n",
241
+ " for bracket, residues in residues_by_bracket.items():\n",
242
+ " result_str += f\"Bracket {bracket}:\\n\"\n",
243
+ " result_str += \"Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\\n\"\n",
244
+ " result_str += \"\\n\".join([\n",
245
+ " f\"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
246
+ " for i, res in enumerate(protein_residues) if res.id[1] in residues\n",
247
+ " ])\n",
248
+ " result_str += \"\\n\\n\"\n",
249
+ "\n",
250
+ " # Create chain-specific PDB with scores in B-factor\n",
251
+ " scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n",
252
+ "\n",
253
+ " # Molecule visualization with updated script with color mapping\n",
254
+ " mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)\n",
255
+ "\n",
256
+ " # Improved PyMOL command suggestions\n",
257
+ " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
258
+ " pymol_commands = f\"Prediction for PDB: {pdb_id}, Chain: {segment}\\nDate: {current_time}\\n\\n\"\n",
259
+ " \n",
260
+ " pymol_commands += f\"\"\"\n",
261
+ " # PyMOL Visualization Commands\n",
262
+ " load {os.path.abspath(pdb_path)}, protein\n",
263
+ " hide everything, all\n",
264
+ " show cartoon, chain {segment}\n",
265
+ " color white, chain {segment}\n",
266
+ " \"\"\"\n",
267
+ " \n",
268
+ " # Define colors for each score bracket\n",
269
+ " bracket_colors = {\n",
270
+ " \"0.0-0.2\": \"white\",\n",
271
+ " \"0.2-0.4\": \"lightorange\",\n",
272
+ " \"0.4-0.6\": \"orange\",\n",
273
+ " \"0.6-0.8\": \"orangered\",\n",
274
+ " \"0.8-1.0\": \"red\"\n",
275
+ " }\n",
276
+ " \n",
277
+ " # Add PyMOL commands for each score bracket\n",
278
+ " for bracket, residues in residues_by_bracket.items():\n",
279
+ " if residues: # Only add commands if there are residues in this bracket\n",
280
+ " color = bracket_colors[bracket]\n",
281
+ " resi_list = '+'.join(map(str, residues))\n",
282
+ " pymol_commands += f\"\"\"\n",
283
+ " select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}\n",
284
+ " show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}\n",
285
+ " color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}\n",
286
+ " \"\"\"\n",
287
+ " \n",
288
+ " # Create prediction and scored PDB files\n",
289
+ " prediction_file = f\"{pdb_id}_binding_site_residues.txt\"\n",
290
+ " with open(prediction_file, \"w\") as f:\n",
291
+ " f.write(result_str)\n",
292
+ " \n",
293
+ " return pymol_commands, mol_vis, [prediction_file,scored_pdb]\n",
294
+ "\n",
295
+ "def molecule(input_pdb, residue_scores=None, segment='A'):\n",
296
+ " # More granular scoring for visualization\n",
297
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
298
+ "\n",
299
+ " # Prepare high-scoring residues script if scores are provided\n",
300
+ " high_score_script = \"\"\n",
301
+ " if residue_scores is not None:\n",
302
+ " # Filter residues based on their scores\n",
303
+ " class1_score_residues = [resi for resi, score in residue_scores if 0.0 < score <= 0.2]\n",
304
+ " class2_score_residues = [resi for resi, score in residue_scores if 0.2 < score <= 0.4]\n",
305
+ " class3_score_residues = [resi for resi, score in residue_scores if 0.4 < score <= 0.6]\n",
306
+ " class4_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.8]\n",
307
+ " class5_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 1.0]\n",
308
+ " \n",
309
+ " high_score_script = \"\"\"\n",
310
+ " // Load the original model and apply white cartoon style\n",
311
+ " let chainModel = viewer.addModel(pdb, \"pdb\");\n",
312
+ " chainModel.setStyle({}, {});\n",
313
+ " chainModel.setStyle(\n",
314
+ " {\"chain\": \"%s\"}, \n",
315
+ " {\"cartoon\": {\"color\": \"white\"}}\n",
316
+ " );\n",
317
+ "\n",
318
+ " // Create a new model for high-scoring residues and apply red sticks style\n",
319
+ " let class1Model = viewer.addModel(pdb, \"pdb\");\n",
320
+ " class1Model.setStyle({}, {});\n",
321
+ " class1Model.setStyle(\n",
322
+ " {\"chain\": \"%s\", \"resi\": [%s]}, \n",
323
+ " {\"stick\": {\"color\": \"0xFFFFFF\", \"opacity\": 0.5}}\n",
324
+ " );\n",
325
+ "\n",
326
+ " // Create a new model for high-scoring residues and apply red sticks style\n",
327
+ " let class2Model = viewer.addModel(pdb, \"pdb\");\n",
328
+ " class2Model.setStyle({}, {});\n",
329
+ " class2Model.setStyle(\n",
330
+ " {\"chain\": \"%s\", \"resi\": [%s]}, \n",
331
+ " {\"stick\": {\"color\": \"0xFFD580\", \"opacity\": 0.7}}\n",
332
+ " );\n",
333
+ "\n",
334
+ " // Create a new model for high-scoring residues and apply red sticks style\n",
335
+ " let class3Model = viewer.addModel(pdb, \"pdb\");\n",
336
+ " class3Model.setStyle({}, {});\n",
337
+ " class3Model.setStyle(\n",
338
+ " {\"chain\": \"%s\", \"resi\": [%s]}, \n",
339
+ " {\"stick\": {\"color\": \"0xFFA500\", \"opacity\": 1}}\n",
340
+ " );\n",
341
+ "\n",
342
+ " // Create a new model for high-scoring residues and apply red sticks style\n",
343
+ " let class4Model = viewer.addModel(pdb, \"pdb\");\n",
344
+ " class4Model.setStyle({}, {});\n",
345
+ " class4Model.setStyle(\n",
346
+ " {\"chain\": \"%s\", \"resi\": [%s]}, \n",
347
+ " {\"stick\": {\"color\": \"0xFF4500\", \"opacity\": 1}}\n",
348
+ " );\n",
349
+ "\n",
350
+ " // Create a new model for high-scoring residues and apply red sticks style\n",
351
+ " let class5Model = viewer.addModel(pdb, \"pdb\");\n",
352
+ " class5Model.setStyle({}, {});\n",
353
+ " class5Model.setStyle(\n",
354
+ " {\"chain\": \"%s\", \"resi\": [%s]}, \n",
355
+ " {\"stick\": {\"color\": \"0xFF0000\", \"alpha\": 1}}\n",
356
+ " );\n",
357
+ "\n",
358
+ " \"\"\" % (\n",
359
+ " segment,\n",
360
+ " segment,\n",
361
+ " \", \".join(str(resi) for resi in class1_score_residues),\n",
362
+ " segment,\n",
363
+ " \", \".join(str(resi) for resi in class2_score_residues),\n",
364
+ " segment,\n",
365
+ " \", \".join(str(resi) for resi in class3_score_residues),\n",
366
+ " segment,\n",
367
+ " \", \".join(str(resi) for resi in class4_score_residues),\n",
368
+ " segment,\n",
369
+ " \", \".join(str(resi) for resi in class5_score_residues)\n",
370
+ " )\n",
371
+ " \n",
372
+ " # Generate the full HTML content\n",
373
+ " html_content = f\"\"\"\n",
374
+ " <!DOCTYPE html>\n",
375
+ " <html>\n",
376
+ " <head> \n",
377
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
378
+ " <style>\n",
379
+ " .mol-container {{\n",
380
+ " width: 100%;\n",
381
+ " height: 700px;\n",
382
+ " position: relative;\n",
383
+ " }}\n",
384
+ " </style>\n",
385
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
386
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
387
+ " </head>\n",
388
+ " <body>\n",
389
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
390
+ " <script>\n",
391
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
392
+ " $(document).ready(function () {{\n",
393
+ " let element = $(\"#container\");\n",
394
+ " let config = {{ backgroundColor: \"white\" }};\n",
395
+ " let viewer = $3Dmol.createViewer(element, config);\n",
396
+ " \n",
397
+ " {high_score_script}\n",
398
+ " \n",
399
+ " // Add hover functionality\n",
400
+ " viewer.setHoverable(\n",
401
+ " {{}}, \n",
402
+ " true, \n",
403
+ " function(atom, viewer, event, container) {{\n",
404
+ " if (!atom.label) {{\n",
405
+ " atom.label = viewer.addLabel(\n",
406
+ " atom.resn + \":\" +atom.resi + \":\" + atom.atom, \n",
407
+ " {{\n",
408
+ " position: atom, \n",
409
+ " backgroundColor: 'mintcream', \n",
410
+ " fontColor: 'black',\n",
411
+ " fontSize: 18,\n",
412
+ " padding: 4\n",
413
+ " }}\n",
414
+ " );\n",
415
+ " }}\n",
416
+ " }},\n",
417
+ " function(atom, viewer) {{\n",
418
+ " if (atom.label) {{\n",
419
+ " viewer.removeLabel(atom.label);\n",
420
+ " delete atom.label;\n",
421
+ " }}\n",
422
+ " }}\n",
423
+ " );\n",
424
+ " \n",
425
+ " viewer.zoomTo();\n",
426
+ " viewer.render();\n",
427
+ " viewer.zoom(0.8, 2000);\n",
428
+ " }});\n",
429
+ " </script>\n",
430
+ " </body>\n",
431
+ " </html>\n",
432
+ " \"\"\"\n",
433
+ " \n",
434
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
435
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
436
+ "\n",
437
+ "# Gradio UI\n",
438
+ "with gr.Blocks(css=\"\"\"\n",
439
+ " /* Customize Gradio button colors */\n",
440
+ " #visualize-btn, #predict-btn {\n",
441
+ " background-color: #FF7300; /* Deep orange */\n",
442
+ " color: white;\n",
443
+ " border-radius: 5px;\n",
444
+ " padding: 10px;\n",
445
+ " font-weight: bold;\n",
446
+ " }\n",
447
+ " #visualize-btn:hover, #predict-btn:hover {\n",
448
+ " background-color: #CC5C00; /* Darkened orange on hover */\n",
449
+ " }\n",
450
+ "\"\"\") as demo:\n",
451
+ " gr.Markdown(\"# Protein Binding Site Prediction\")\n",
452
+ " \n",
453
+ " # Mode selection\n",
454
+ " mode = gr.Radio(\n",
455
+ " choices=[\"PDB ID\", \"Upload File\"],\n",
456
+ " value=\"PDB ID\",\n",
457
+ " label=\"Input Mode\",\n",
458
+ " info=\"Choose whether to input a PDB ID or upload a PDB/CIF file.\"\n",
459
+ " )\n",
460
+ "\n",
461
+ " # Input components based on mode\n",
462
+ " pdb_input = gr.Textbox(value=\"2F6V\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
463
+ " pdb_file = gr.File(label=\"Upload PDB/CIF File\", visible=False)\n",
464
+ " visualize_btn = gr.Button(\"Visualize Structure\", elem_id=\"visualize-btn\")\n",
465
+ "\n",
466
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=[\n",
467
+ " {\n",
468
+ " \"model\": 0,\n",
469
+ " \"style\": \"cartoon\",\n",
470
+ " \"color\": \"whiteCarbon\",\n",
471
+ " \"residue_range\": \"\",\n",
472
+ " \"around\": 0,\n",
473
+ " \"byres\": False,\n",
474
+ " }\n",
475
+ " ])\n",
476
+ "\n",
477
+ " with gr.Row():\n",
478
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID (protein)\", placeholder=\"Enter Chain ID here...\",\n",
479
+ " info=\"Choose in which chain to predict binding sites.\")\n",
480
+ " prediction_btn = gr.Button(\"Predict Binding Site\", elem_id=\"predict-btn\")\n",
481
+ "\n",
482
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
483
+ " explanation_vis = gr.Markdown(\"\"\"\n",
484
+ " Score dependent colorcoding:\n",
485
+ " - 0.0-0.2: white \n",
486
+ " - 0.2–0.4: light orange \n",
487
+ " - 0.4–0.6: orange\n",
488
+ " - 0.6–0.8: orangered\n",
489
+ " - 0.8–1.0: red\n",
490
+ " \"\"\")\n",
491
+ " predictions_output = gr.Textbox(label=\"Visualize Prediction with PyMol\")\n",
492
+ " gr.Markdown(\"### Download:\\n- List of predicted binding site residues\\n- PDB with score in beta factor column\")\n",
493
+ " download_output = gr.File(label=\"Download Files\", file_count=\"multiple\")\n",
494
+ " \n",
495
+ " def process_interface(mode, pdb_id, pdb_file, chain_id):\n",
496
+ " if mode == \"PDB ID\":\n",
497
+ " return process_pdb(pdb_id, chain_id)\n",
498
+ " elif mode == \"Upload File\":\n",
499
+ " _, ext = os.path.splitext(pdb_file.name)\n",
500
+ " file_path = os.path.join('./', f\"{_}{ext}\")\n",
501
+ " if ext == '.cif':\n",
502
+ " pdb_path = convert_cif_to_pdb(file_path)\n",
503
+ " else:\n",
504
+ " pdb_path= file_path\n",
505
+ " return process_pdb(pdb_path, chain_id)\n",
506
+ " else:\n",
507
+ " return \"Error: Invalid mode selected\", None, None\n",
508
+ "\n",
509
+ " def fetch_interface(mode, pdb_id, pdb_file):\n",
510
+ " if mode == \"PDB ID\":\n",
511
+ " return fetch_pdb(pdb_id)\n",
512
+ " elif mode == \"Upload File\":\n",
513
+ " _, ext = os.path.splitext(pdb_file.name)\n",
514
+ " file_path = os.path.join('./', f\"{_}{ext}\")\n",
515
+ " #print(ext)\n",
516
+ " if ext == '.cif':\n",
517
+ " pdb_path = convert_cif_to_pdb(file_path)\n",
518
+ " else:\n",
519
+ " pdb_path= file_path\n",
520
+ " #print(pdb_path)\n",
521
+ " return pdb_path\n",
522
+ " else:\n",
523
+ " return \"Error: Invalid mode selected\"\n",
524
+ "\n",
525
+ " def toggle_mode(selected_mode):\n",
526
+ " if selected_mode == \"PDB ID\":\n",
527
+ " return gr.update(visible=True), gr.update(visible=False)\n",
528
+ " else:\n",
529
+ " return gr.update(visible=False), gr.update(visible=True)\n",
530
+ "\n",
531
+ " mode.change(\n",
532
+ " toggle_mode,\n",
533
+ " inputs=[mode],\n",
534
+ " outputs=[pdb_input, pdb_file]\n",
535
+ " )\n",
536
+ "\n",
537
+ " prediction_btn.click(\n",
538
+ " process_interface, \n",
539
+ " inputs=[mode, pdb_input, pdb_file, segment_input], \n",
540
+ " outputs=[predictions_output, molecule_output, download_output]\n",
541
+ " )\n",
542
+ "\n",
543
+ " visualize_btn.click(\n",
544
+ " fetch_interface, \n",
545
+ " inputs=[mode, pdb_input, pdb_file], \n",
546
+ " outputs=molecule_output2\n",
547
+ " )\n",
548
+ "\n",
549
+ " gr.Markdown(\"## Examples\")\n",
550
+ " gr.Examples(\n",
551
+ " examples=[\n",
552
+ " [\"7RPZ\", \"A\"],\n",
553
+ " [\"2IWI\", \"B\"],\n",
554
+ " [\"7LCJ\", \"R\"]\n",
555
+ " ],\n",
556
+ " inputs=[pdb_input, segment_input],\n",
557
+ " outputs=[predictions_output, molecule_output, download_output]\n",
558
+ " )\n",
559
+ "\n",
560
+ "demo.launch(share=True)"
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "execution_count": null,
566
+ "id": "440c87ed-45c9-4501-b208-409cbfd7858b",
567
+ "metadata": {},
568
+ "outputs": [],
569
+ "source": []
570
+ },
571
+ {
572
+ "cell_type": "code",
573
+ "execution_count": 21,
574
+ "id": "d70c40b9-5d5a-4795-b2a2-149c4a57d16e",
575
+ "metadata": {},
576
+ "outputs": [
577
+ {
578
+ "name": "stderr",
579
+ "output_type": "stream",
580
+ "text": [
581
+ "/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py:441: UserWarning: Examples will be cached but not all input components have example values. This may result in an exception being thrown by your function. If you do get an error while caching examples, make sure all of your inputs have example values for all of your examples or you provide default values for those particular parameters in your function.\n",
582
+ " warnings.warn(\n",
583
+ "INFO:__main__:Using cached structure: ./7rpz.cif\n",
584
+ "INFO:__main__:Using cached structure: ./2iwi.cif\n",
585
+ "INFO:__main__:Using cached structure: ./2f6v.cif\n",
586
+ "INFO:httpx:HTTP Request: GET http://127.0.0.1:7862/gradio_api/startup-events \"HTTP/1.1 200 OK\"\n"
587
+ ]
588
+ },
589
+ {
590
+ "name": "stdout",
591
+ "output_type": "stream",
592
+ "text": [
593
+ "* Running on local URL: http://127.0.0.1:7862\n",
594
+ "Caching examples at: '/home/frohlkin/Projects/LargeLanguageModels/Publication/test_webpage/.gradio/cached_examples/148'\n"
595
+ ]
596
+ },
597
+ {
598
+ "name": "stderr",
599
+ "output_type": "stream",
600
+ "text": [
601
+ "INFO:httpx:HTTP Request: HEAD http://127.0.0.1:7862/ \"HTTP/1.1 200 OK\"\n",
602
+ "INFO:httpx:HTTP Request: GET https://api.gradio.app/pkg-version \"HTTP/1.1 200 OK\"\n",
603
+ "INFO:httpx:HTTP Request: GET https://api.gradio.app/v3/tunnel-request \"HTTP/1.1 200 OK\"\n"
604
+ ]
605
+ },
606
+ {
607
+ "name": "stdout",
608
+ "output_type": "stream",
609
+ "text": [
610
+ "* Running on public URL: https://de785d7cce806497e9.gradio.live\n",
611
+ "\n",
612
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
613
+ ]
614
+ },
615
+ {
616
+ "name": "stderr",
617
+ "output_type": "stream",
618
+ "text": [
619
+ "INFO:httpx:HTTP Request: HEAD https://de785d7cce806497e9.gradio.live \"HTTP/1.1 200 OK\"\n"
620
+ ]
621
+ },
622
+ {
623
+ "data": {
624
+ "text/html": [
625
+ "<div><iframe src=\"https://de785d7cce806497e9.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
626
+ ],
627
+ "text/plain": [
628
+ "<IPython.core.display.HTML object>"
629
+ ]
630
+ },
631
+ "metadata": {},
632
+ "output_type": "display_data"
633
+ },
634
+ {
635
+ "name": "stderr",
636
+ "output_type": "stream",
637
+ "text": [
638
+ "Traceback (most recent call last):\n",
639
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/routes.py\", line 990, in predict\n",
640
+ " output = await route_utils.call_process_api(\n",
641
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
642
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/route_utils.py\", line 322, in call_process_api\n",
643
+ " output = await app.get_blocks().process_api(\n",
644
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
645
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 2047, in process_api\n",
646
+ " result = await self.call_function(\n",
647
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
648
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 1594, in call_function\n",
649
+ " prediction = await anyio.to_thread.run_sync( # type: ignore\n",
650
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
651
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
652
+ " return await get_async_backend().run_sync_in_worker_thread(\n",
653
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
654
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 2405, in run_sync_in_worker_thread\n",
655
+ " return await future\n",
656
+ " ^^^^^^^^^^^^\n",
657
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 914, in run\n",
658
+ " result = context.run(func, *args)\n",
659
+ " ^^^^^^^^^^^^^^^^^^^^^^^^\n",
660
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/utils.py\", line 869, in wrapper\n",
661
+ " response = f(*args, **kwargs)\n",
662
+ " ^^^^^^^^^^^^^^^^^^\n",
663
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 355, in load_example_with_output\n",
664
+ " ) + self.load_from_cache(example_id)\n",
665
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
666
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 579, in load_from_cache\n",
667
+ " output.append(component.read_from_flag(value_to_use))\n",
668
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
669
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/components/base.py\", line 366, in read_from_flag\n",
670
+ " return self.data_model.from_json(json.loads(payload))\n",
671
+ " ^^^^^^^^^^^^^^^^^^^\n",
672
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/__init__.py\", line 346, in loads\n",
673
+ " return _default_decoder.decode(s)\n",
674
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
675
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 337, in decode\n",
676
+ " obj, end = self.raw_decode(s, idx=_w(s, 0).end())\n",
677
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
678
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 355, in raw_decode\n",
679
+ " raise JSONDecodeError(\"Expecting value\", s, err.value) from None\n",
680
+ "json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)\n",
681
+ "Traceback (most recent call last):\n",
682
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/routes.py\", line 990, in predict\n",
683
+ " output = await route_utils.call_process_api(\n",
684
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
685
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/route_utils.py\", line 322, in call_process_api\n",
686
+ " output = await app.get_blocks().process_api(\n",
687
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
688
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 2047, in process_api\n",
689
+ " result = await self.call_function(\n",
690
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
691
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 1594, in call_function\n",
692
+ " prediction = await anyio.to_thread.run_sync( # type: ignore\n",
693
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
694
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
695
+ " return await get_async_backend().run_sync_in_worker_thread(\n",
696
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
697
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 2405, in run_sync_in_worker_thread\n",
698
+ " return await future\n",
699
+ " ^^^^^^^^^^^^\n",
700
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 914, in run\n",
701
+ " result = context.run(func, *args)\n",
702
+ " ^^^^^^^^^^^^^^^^^^^^^^^^\n",
703
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/utils.py\", line 869, in wrapper\n",
704
+ " response = f(*args, **kwargs)\n",
705
+ " ^^^^^^^^^^^^^^^^^^\n",
706
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 355, in load_example_with_output\n",
707
+ " ) + self.load_from_cache(example_id)\n",
708
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
709
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 579, in load_from_cache\n",
710
+ " output.append(component.read_from_flag(value_to_use))\n",
711
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
712
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/components/base.py\", line 366, in read_from_flag\n",
713
+ " return self.data_model.from_json(json.loads(payload))\n",
714
+ " ^^^^^^^^^^^^^^^^^^^\n",
715
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/__init__.py\", line 346, in loads\n",
716
+ " return _default_decoder.decode(s)\n",
717
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
718
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 337, in decode\n",
719
+ " obj, end = self.raw_decode(s, idx=_w(s, 0).end())\n",
720
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
721
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 355, in raw_decode\n",
722
+ " raise JSONDecodeError(\"Expecting value\", s, err.value) from None\n",
723
+ "json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)\n",
724
+ "Traceback (most recent call last):\n",
725
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/routes.py\", line 990, in predict\n",
726
+ " output = await route_utils.call_process_api(\n",
727
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
728
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/route_utils.py\", line 322, in call_process_api\n",
729
+ " output = await app.get_blocks().process_api(\n",
730
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
731
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 2047, in process_api\n",
732
+ " result = await self.call_function(\n",
733
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^\n",
734
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 1594, in call_function\n",
735
+ " prediction = await anyio.to_thread.run_sync( # type: ignore\n",
736
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
737
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
738
+ " return await get_async_backend().run_sync_in_worker_thread(\n",
739
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
740
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 2405, in run_sync_in_worker_thread\n",
741
+ " return await future\n",
742
+ " ^^^^^^^^^^^^\n",
743
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 914, in run\n",
744
+ " result = context.run(func, *args)\n",
745
+ " ^^^^^^^^^^^^^^^^^^^^^^^^\n",
746
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/utils.py\", line 869, in wrapper\n",
747
+ " response = f(*args, **kwargs)\n",
748
+ " ^^^^^^^^^^^^^^^^^^\n",
749
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 355, in load_example_with_output\n",
750
+ " ) + self.load_from_cache(example_id)\n",
751
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
752
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/helpers.py\", line 579, in load_from_cache\n",
753
+ " output.append(component.read_from_flag(value_to_use))\n",
754
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
755
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/components/base.py\", line 366, in read_from_flag\n",
756
+ " return self.data_model.from_json(json.loads(payload))\n",
757
+ " ^^^^^^^^^^^^^^^^^^^\n",
758
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/__init__.py\", line 346, in loads\n",
759
+ " return _default_decoder.decode(s)\n",
760
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
761
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 337, in decode\n",
762
+ " obj, end = self.raw_decode(s, idx=_w(s, 0).end())\n",
763
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
764
+ " File \"/home/frohlkin/anaconda3/envs/LLM/lib/python3.12/json/decoder.py\", line 355, in raw_decode\n",
765
+ " raise JSONDecodeError(\"Expecting value\", s, err.value) from None\n",
766
+ "json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)\n"
767
+ ]
768
+ }
769
+ ],
770
+ "source": [
771
+ "from datetime import datetime\n",
772
+ "import gradio as gr\n",
773
+ "import requests\n",
774
+ "from Bio.PDB import PDBParser, MMCIFParser, PDBIO, Select, Structure\n",
775
+ "from Bio.PDB.Polypeptide import is_aa\n",
776
+ "from Bio.SeqUtils import seq1\n",
777
+ "from typing import Optional, Tuple, Dict, List\n",
778
+ "import numpy as np\n",
779
+ "import os\n",
780
+ "from gradio_molecule3d import Molecule3D\n",
781
+ "import torch\n",
782
+ "import torch.nn as nn\n",
783
+ "import torch.nn.functional as F\n",
784
+ "from torch.utils.data import DataLoader\n",
785
+ "import re\n",
786
+ "import pandas as pd\n",
787
+ "import copy\n",
788
+ "from scipy.special import expit\n",
789
+ "import logging\n",
790
+ "import tempfile\n",
791
+ "\n",
792
+ "# Set up logging\n",
793
+ "logging.basicConfig(level=logging.INFO)\n",
794
+ "logger = logging.getLogger(__name__)\n",
795
+ "\n",
796
+ "class StructureError(Exception):\n",
797
+ " \"\"\"Custom exception for structure-related errors\"\"\"\n",
798
+ " pass\n",
799
+ "\n",
800
+ "def normalize_scores(scores: np.ndarray) -> np.ndarray:\n",
801
+ " \"\"\"Normalize scores to range [0,1]\"\"\"\n",
802
+ " min_score = np.min(scores)\n",
803
+ " max_score = np.max(scores)\n",
804
+ " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
805
+ "\n",
806
+ "def read_mol(pdb_path: str) -> str:\n",
807
+ " \"\"\"Read molecular structure file and return its content\"\"\"\n",
808
+ " try:\n",
809
+ " with open(pdb_path, 'r') as f:\n",
810
+ " return f.read()\n",
811
+ " except Exception as e:\n",
812
+ " raise IOError(f\"Failed to read structure file: {e}\")\n",
813
+ "\n",
814
+ "def fetch_structure(pdb_id: str, output_dir: str = \".\") -> Optional[str]:\n",
815
+ " \"\"\"Fetch structure file, trying multiple formats and sources\"\"\"\n",
816
+ " try:\n",
817
+ " # First try local cache\n",
818
+ " for ext in ['.cif', '.pdb']:\n",
819
+ " local_path = os.path.join(output_dir, f\"{pdb_id.lower()}{ext}\")\n",
820
+ " if os.path.exists(local_path):\n",
821
+ " logger.info(f\"Using cached structure: {local_path}\")\n",
822
+ " return local_path\n",
823
+ "\n",
824
+ " # Try different download sources\n",
825
+ " sources = [\n",
826
+ " f\"https://files.rcsb.org/download/{pdb_id.upper()}.cif\",\n",
827
+ " f\"https://files.rcsb.org/download/{pdb_id.upper()}.pdb\",\n",
828
+ " f\"https://files.rcsb.org/download/{pdb_id.lower()}.cif\",\n",
829
+ " f\"https://files.rcsb.org/download/{pdb_id.lower()}.pdb\"\n",
830
+ " ]\n",
831
+ "\n",
832
+ " for url in sources:\n",
833
+ " try:\n",
834
+ " response = requests.get(url, timeout=10)\n",
835
+ " if response.status_code == 200:\n",
836
+ " ext = '.cif' if 'cif' in url else '.pdb'\n",
837
+ " file_path = os.path.join(output_dir, f\"{pdb_id.lower()}{ext}\")\n",
838
+ " with open(file_path, 'wb') as f:\n",
839
+ " f.write(response.content)\n",
840
+ " logger.info(f\"Successfully downloaded: {url}\")\n",
841
+ " return file_path\n",
842
+ " except Exception as e:\n",
843
+ " logger.warning(f\"Failed to download from {url}: {e}\")\n",
844
+ " continue\n",
845
+ "\n",
846
+ " raise StructureError(f\"Failed to fetch structure for PDB ID: {pdb_id}\")\n",
847
+ " except Exception as e:\n",
848
+ " raise StructureError(f\"Error fetching structure: {e}\")\n",
849
+ "\n",
850
+ "def convert_cif_to_pdb(cif_path: str, output_dir: str = \".\") -> str:\n",
851
+ " \"\"\"Convert CIF to PDB format with error handling\"\"\"\n",
852
+ " try:\n",
853
+ " pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))\n",
854
+ " parser = MMCIFParser(QUIET=True)\n",
855
+ " structure = parser.get_structure('protein', cif_path)\n",
856
+ " io = PDBIO()\n",
857
+ " io.set_structure(structure)\n",
858
+ " io.save(pdb_path)\n",
859
+ " return pdb_path\n",
860
+ " except Exception as e:\n",
861
+ " raise StructureError(f\"Failed to convert CIF to PDB: {e}\")\n",
862
+ "\n",
863
+ "def find_valid_chain(structure: Structure.Structure) -> Optional[str]:\n",
864
+ " \"\"\"Find the first valid protein chain in the structure\"\"\"\n",
865
+ " for model in structure:\n",
866
+ " for chain in model:\n",
867
+ " protein_residues = [res for res in chain if is_aa(res)]\n",
868
+ " if len(protein_residues) > 0:\n",
869
+ " return chain.id\n",
870
+ " return None\n",
871
+ "\n",
872
+ "def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:\n",
873
+ " \"\"\"Create PDB file with selected chain and prediction scores in B-factor column\"\"\"\n",
874
+ " class ResidueSelector(Select):\n",
875
+ " def __init__(self, chain_id, selected_residues, scores_dict):\n",
876
+ " self.chain_id = chain_id\n",
877
+ " self.selected_residues = selected_residues\n",
878
+ " self.scores_dict = scores_dict\n",
879
+ " \n",
880
+ " def accept_chain(self, chain):\n",
881
+ " return chain.id == self.chain_id\n",
882
+ " \n",
883
+ " def accept_residue(self, residue):\n",
884
+ " return residue.id[1] in self.selected_residues\n",
885
+ "\n",
886
+ " def accept_atom(self, atom):\n",
887
+ " if atom.parent.id[1] in self.scores_dict:\n",
888
+ " atom.bfactor = np.absolute(1-self.scores_dict[atom.parent.id[1]]) * 100\n",
889
+ " return True\n",
890
+ "\n",
891
+ " try:\n",
892
+ " parser = PDBParser(QUIET=True)\n",
893
+ " structure = parser.get_structure('protein', input_pdb)\n",
894
+ " output_pdb = f\"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb\"\n",
895
+ " scores_dict = {resi: score for resi, score in residue_scores}\n",
896
+ " \n",
897
+ " io = PDBIO()\n",
898
+ " selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)\n",
899
+ " io.set_structure(structure[0])\n",
900
+ " io.save(output_pdb, selector)\n",
901
+ " \n",
902
+ " return output_pdb\n",
903
+ " except Exception as e:\n",
904
+ " raise StructureError(f\"Failed to create chain-specific PDB: {e}\")\n",
905
+ "\n",
906
+ "def process_pdb(pdb_id_or_file: str, segment: str) -> Tuple[str, str, List[str]]:\n",
907
+ " \"\"\"Process PDB/CIF file and generate visualizations and predictions\"\"\"\n",
908
+ " try:\n",
909
+ " # Handle input\n",
910
+ " if pdb_id_or_file.endswith(('.pdb', '.cif')):\n",
911
+ " pdb_path = pdb_id_or_file\n",
912
+ " pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]\n",
913
+ " else:\n",
914
+ " pdb_id = pdb_id_or_file\n",
915
+ " pdb_path = fetch_structure(pdb_id)\n",
916
+ "\n",
917
+ " if not pdb_path:\n",
918
+ " raise StructureError(\"Failed to obtain structure file\")\n",
919
+ "\n",
920
+ " # Parse structure\n",
921
+ " parser = MMCIFParser(QUIET=True) if pdb_path.endswith('.cif') else PDBParser(QUIET=True)\n",
922
+ " structure = parser.get_structure('protein', pdb_path)\n",
923
+ "\n",
924
+ " # Handle chain selection\n",
925
+ " if segment == 'auto' or not segment:\n",
926
+ " segment = find_valid_chain(structure)\n",
927
+ " if not segment:\n",
928
+ " raise StructureError(\"No valid protein chains found in structure\")\n",
929
+ " \n",
930
+ " try:\n",
931
+ " chain = structure[0][segment]\n",
932
+ " except KeyError:\n",
933
+ " valid_chain = find_valid_chain(structure)\n",
934
+ " if valid_chain:\n",
935
+ " chain = structure[0][valid_chain]\n",
936
+ " segment = valid_chain\n",
937
+ " logger.info(f\"Using alternative chain {segment}\")\n",
938
+ " else:\n",
939
+ " raise StructureError(f\"Invalid chain ID '{segment}'. Structure has no valid protein chains.\")\n",
940
+ "\n",
941
+ " # Process chain\n",
942
+ " protein_residues = [res for res in chain if is_aa(res)]\n",
943
+ " if not protein_residues:\n",
944
+ " raise StructureError(f\"No amino acid residues found in chain {segment}\")\n",
945
+ "\n",
946
+ " sequence = \"\".join(seq1(res.resname) for res in protein_residues)\n",
947
+ " sequence_id = [res.id[1] for res in protein_residues]\n",
948
+ " \n",
949
+ " # Generate predictions (currently random)\n",
950
+ " scores = np.random.rand(len(sequence))\n",
951
+ " normalized_scores = normalize_scores(scores)\n",
952
+ " residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]\n",
953
+ "\n",
954
+ " # Generate outputs\n",
955
+ " result_str = generate_results_string(pdb_id, segment, protein_residues, normalized_scores, sequence)\n",
956
+ " scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)\n",
957
+ " mol_vis = molecule(pdb_path, residue_scores, segment)\n",
958
+ " pymol_commands = generate_pymol_commands(pdb_id, segment, residue_scores, pdb_path)\n",
959
+ "\n",
960
+ " # Save results\n",
961
+ " prediction_file = f\"{pdb_id}_binding_site_residues.txt\"\n",
962
+ " with open(prediction_file, \"w\") as f:\n",
963
+ " f.write(result_str)\n",
964
+ "\n",
965
+ " return pymol_commands, mol_vis, [prediction_file, scored_pdb]\n",
966
+ "\n",
967
+ " except StructureError as e:\n",
968
+ " return str(e), None, None\n",
969
+ " except Exception as e:\n",
970
+ " return f\"An unexpected error occurred: {str(e)}\", None, None\n",
971
+ "\n",
972
+ "def generate_results_string(pdb_id: str, segment: str, protein_residues: list, \n",
973
+ " normalized_scores: np.ndarray, sequence: str) -> str:\n",
974
+ " \"\"\"Generate formatted results string with predictions\"\"\"\n",
975
+ " score_brackets = {\n",
976
+ " \"0.0-0.2\": (0.0, 0.2),\n",
977
+ " \"0.2-0.4\": (0.2, 0.4),\n",
978
+ " \"0.4-0.6\": (0.4, 0.6),\n",
979
+ " \"0.6-0.8\": (0.6, 0.8),\n",
980
+ " \"0.8-1.0\": (0.8, 1.0)\n",
981
+ " }\n",
982
+ " \n",
983
+ " residues_by_bracket = {bracket: [] for bracket in score_brackets}\n",
984
+ " \n",
985
+ " # Categorize residues\n",
986
+ " for i, score in enumerate(normalized_scores):\n",
987
+ " for bracket, (lower, upper) in score_brackets.items():\n",
988
+ " if lower <= score < upper:\n",
989
+ " residues_by_bracket[bracket].append(protein_residues[i])\n",
990
+ " break\n",
991
+ " \n",
992
+ " # Format results\n",
993
+ " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
994
+ " result_str = f\"\"\"Prediction Results\n",
995
+ "========================\n",
996
+ "PDB: {pdb_id}\n",
997
+ "Chain: {segment}\n",
998
+ "Date: {current_time}\n",
999
+ "\n",
1000
+ "Analysis by Score Brackets\n",
1001
+ "========================\n",
1002
+ "\"\"\"\n",
1003
+ " \n",
1004
+ " for bracket, residues in residues_by_bracket.items():\n",
1005
+ " if residues: # Only show brackets with residues\n",
1006
+ " result_str += f\"\\nBracket {bracket}:\\n\"\n",
1007
+ " result_str += \"ResName ResNum Code Score\\n\"\n",
1008
+ " result_str += \"-\" * 30 + \"\\n\"\n",
1009
+ " result_str += \"\\n\".join([\n",
1010
+ " f\"{res.resname:6} {res.id[1]:6} {sequence[i]:4} {normalized_scores[i]:6.2f}\" \n",
1011
+ " for i, res in enumerate(protein_residues) if res in residues\n",
1012
+ " ])\n",
1013
+ " result_str += \"\\n\"\n",
1014
+ " \n",
1015
+ " return result_str\n",
1016
+ "\n",
1017
+ "def generate_pymol_commands(pdb_id: str, segment: str, residue_scores: list, pdb_path: str) -> str:\n",
1018
+ " \"\"\"Generate PyMOL visualization commands\"\"\"\n",
1019
+ " # Group residues by score ranges\n",
1020
+ " score_groups = {\n",
1021
+ " \"very_low\": [], \"low\": [], \"medium\": [], \"high\": [], \"very_high\": []\n",
1022
+ " }\n",
1023
+ " \n",
1024
+ " for resi, score in residue_scores:\n",
1025
+ " if score <= 0.2:\n",
1026
+ " score_groups[\"very_low\"].append(str(resi))\n",
1027
+ " elif score <= 0.4:\n",
1028
+ " score_groups[\"low\"].append(str(resi))\n",
1029
+ " elif score <= 0.6:\n",
1030
+ " score_groups[\"medium\"].append(str(resi))\n",
1031
+ " elif score <= 0.8:\n",
1032
+ " score_groups[\"high\"].append(str(resi))\n",
1033
+ " else:\n",
1034
+ " score_groups[\"very_high\"].append(str(resi))\n",
1035
+ "\n",
1036
+ " current_time = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
1037
+ " commands = f\"\"\"# PyMOL Script for {pdb_id} Chain {segment}\n",
1038
+ "# Generated: {current_time}\n",
1039
+ "\n",
1040
+ "# Load structure and set initial display\n",
1041
+ "load {os.path.abspath(pdb_path)}, protein\n",
1042
+ "bg_color white\n",
1043
+ "hide everything\n",
1044
+ "show cartoon, chain {segment}\n",
1045
+ "color white, chain {segment}\n",
1046
+ "\n",
1047
+ "# Create selection groups by score\n",
1048
+ "\"\"\"\n",
1049
+ " \n",
1050
+ " color_scheme = {\n",
1051
+ " \"very_low\": \"white\",\n",
1052
+ " \"low\": \"lightorange\",\n",
1053
+ " \"medium\": \"orange\",\n",
1054
+ " \"high\": \"orangered\",\n",
1055
+ " \"very_high\": \"red\"\n",
1056
+ " }\n",
1057
+ " \n",
1058
+ " for group, residues in score_groups.items():\n",
1059
+ " if residues:\n",
1060
+ " resi_str = \"+\".join(residues)\n",
1061
+ " commands += f\"\"\"\n",
1062
+ "# {group.replace('_', ' ').title()} scoring residues\n",
1063
+ "select {group}, chain {segment} and resi {resi_str}\n",
1064
+ "show sticks, {group}\n",
1065
+ "color {color_scheme[group]}, {group}\"\"\"\n",
1066
+ " \n",
1067
+ " commands += \"\"\"\n",
1068
+ "\n",
1069
+ "# Center and zoom\n",
1070
+ "center chain {}\n",
1071
+ "zoom chain {}\n",
1072
+ "\"\"\"\n",
1073
+ "\n",
1074
+ " return commands\n",
1075
+ "\n",
1076
+ "def molecule(input_pdb: str, residue_scores: list = None, segment: str = 'A') -> str:\n",
1077
+ " \"\"\"Generate interactive 3D molecule visualization\"\"\"\n",
1078
+ " try:\n",
1079
+ " mol = read_mol(input_pdb)\n",
1080
+ " except Exception as e:\n",
1081
+ " return f'<div class=\"error\">Error loading structure: {str(e)}</div>'\n",
1082
+ "\n",
1083
+ " # Prepare residue groups for visualization\n",
1084
+ " vis_groups = {\n",
1085
+ " \"class1\": [], # 0.0-0.2\n",
1086
+ " \"class2\": [], # 0.2-0.4\n",
1087
+ " \"class3\": [], # 0.4-0.6\n",
1088
+ " \"class4\": [], # 0.6-0.8\n",
1089
+ " \"class5\": [] # 0.8-1.0\n",
1090
+ " }\n",
1091
+ "\n",
1092
+ " if residue_scores:\n",
1093
+ " for resi, score in residue_scores:\n",
1094
+ " if score <= 0.2:\n",
1095
+ " vis_groups[\"class1\"].append(resi)\n",
1096
+ " elif score <= 0.4:\n",
1097
+ " vis_groups[\"class2\"].append(resi)\n",
1098
+ " elif score <= 0.6:\n",
1099
+ " vis_groups[\"class3\"].append(resi)\n",
1100
+ " elif score <= 0.8:\n",
1101
+ " vis_groups[\"class4\"].append(resi)\n",
1102
+ " else:\n",
1103
+ " vis_groups[\"class5\"].append(resi)\n",
1104
+ "\n",
1105
+ " # Generate visualization script\n",
1106
+ " vis_script = f\"\"\"\n",
1107
+ " // Base model setup\n",
1108
+ " let chainModel = viewer.addModel(pdb, \"pdb\");\n",
1109
+ " chainModel.setStyle({{}}, {{}});\n",
1110
+ " chainModel.setStyle(\n",
1111
+ " {{\"chain\": \"{segment}\"}}, \n",
1112
+ " {{\"cartoon\": {{\"color\": \"white\"}}}}\n",
1113
+ " );\n",
1114
+ " \"\"\"\n",
1115
+ "\n",
1116
+ " # Color schemes for different score ranges\n",
1117
+ " color_schemes = {\n",
1118
+ " \"class1\": {\"color\": \"0xFFFFFF\", \"opacity\": 0.5}, # White\n",
1119
+ " \"class2\": {\"color\": \"0xFFD580\", \"opacity\": 0.7}, # Light orange\n",
1120
+ " \"class3\": {\"color\": \"0xFFA500\", \"opacity\": 1.0}, # Orange\n",
1121
+ " \"class4\": {\"color\": \"0xFF4500\", \"opacity\": 1.0}, # Orange red\n",
1122
+ " \"class5\": {\"color\": \"0xFF0000\", \"opacity\": 1.0} # Red\n",
1123
+ " }\n",
1124
+ "\n",
1125
+ " # Add visualization for each group\n",
1126
+ " for group, residues in vis_groups.items():\n",
1127
+ " if residues:\n",
1128
+ " color_scheme = color_schemes[group]\n",
1129
+ " vis_script += f\"\"\"\n",
1130
+ " let {group}Model = viewer.addModel(pdb, \"pdb\");\n",
1131
+ " {group}Model.setStyle({{}}, {{}});\n",
1132
+ " {group}Model.setStyle(\n",
1133
+ " {{\"chain\": \"{segment}\", \"resi\": [{\", \".join(map(str, residues))}]}},\n",
1134
+ " {{\"stick\": {{\"color\": \"{color_scheme[\"color\"]}\", \"opacity\": {color_scheme[\"opacity\"]}}}}}\n",
1135
+ " );\n",
1136
+ " \"\"\"\n",
1137
+ "\n",
1138
+ " # Generate full HTML with enhanced controls and information\n",
1139
+ " html_content = f\"\"\"\n",
1140
+ " <!DOCTYPE html>\n",
1141
+ " <html>\n",
1142
+ " <head> \n",
1143
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
1144
+ " <style>\n",
1145
+ " .mol-container {{\n",
1146
+ " width: 100%;\n",
1147
+ " height: 700px;\n",
1148
+ " position: relative;\n",
1149
+ " }}\n",
1150
+ " .controls {{\n",
1151
+ " position: absolute;\n",
1152
+ " top: 10px;\n",
1153
+ " left: 10px;\n",
1154
+ " background: rgba(255, 255, 255, 0.8);\n",
1155
+ " padding: 10px;\n",
1156
+ " border-radius: 5px;\n",
1157
+ " z-index: 1000;\n",
1158
+ " }}\n",
1159
+ " .legend {{\n",
1160
+ " position: absolute;\n",
1161
+ " bottom: 10px;\n",
1162
+ " right: 10px;\n",
1163
+ " background: rgba(255, 255, 255, 0.8);\n",
1164
+ " padding: 10px;\n",
1165
+ " border-radius: 5px;\n",
1166
+ " z-index: 1000;\n",
1167
+ " }}\n",
1168
+ " .error {{\n",
1169
+ " color: red;\n",
1170
+ " padding: 20px;\n",
1171
+ " text-align: center;\n",
1172
+ " font-weight: bold;\n",
1173
+ " }}\n",
1174
+ " </style>\n",
1175
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
1176
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
1177
+ " </head>\n",
1178
+ " <body>\n",
1179
+ " <div id=\"container\" class=\"mol-container\">\n",
1180
+ " <div class=\"controls\">\n",
1181
+ " <button onclick=\"toggleStyle('cartoon')\">Toggle Cartoon</button>\n",
1182
+ " <button onclick=\"toggleStyle('stick')\">Toggle Sticks</button>\n",
1183
+ " <button onclick=\"resetView()\">Reset View</button>\n",
1184
+ " <button onclick=\"toggleSpin()\">Toggle Spin</button>\n",
1185
+ " </div>\n",
1186
+ " <div class=\"legend\">\n",
1187
+ " <div><span style=\"color: #FF0000\">■</span> Very High (0.8-1.0)</div>\n",
1188
+ " <div><span style=\"color: #FF4500\">■</span> High (0.6-0.8)</div>\n",
1189
+ " <div><span style=\"color: #FFA500\">■</span> Medium (0.4-0.6)</div>\n",
1190
+ " <div><span style=\"color: #FFD580\">■</span> Low (0.2-0.4)</div>\n",
1191
+ " <div><span style=\"color: #FFFFFF\">■</span> Very Low (0.0-0.2)</div>\n",
1192
+ " </div>\n",
1193
+ " </div>\n",
1194
+ " <script>\n",
1195
+ " let pdb = `{mol}`;\n",
1196
+ " let viewer;\n",
1197
+ " let isSpinning = false;\n",
1198
+ "\n",
1199
+ " $(document).ready(function () {{\n",
1200
+ " let element = $(\"#container\");\n",
1201
+ " let config = {{ backgroundColor: \"white\" }};\n",
1202
+ " viewer = $3Dmol.createViewer(element, config);\n",
1203
+ " \n",
1204
+ " {vis_script}\n",
1205
+ " \n",
1206
+ " // Enhanced hover functionality\n",
1207
+ " viewer.setHoverable(\n",
1208
+ " {{}}, \n",
1209
+ " true, \n",
1210
+ " function(atom, viewer, event, container) {{\n",
1211
+ " if (!atom.label) {{\n",
1212
+ " atom.label = viewer.addLabel(\n",
1213
+ " `${{atom.resn}}:${{atom.resi}}:${{atom.atom}}`, \n",
1214
+ " {{\n",
1215
+ " position: atom, \n",
1216
+ " backgroundColor: 'mintcream', \n",
1217
+ " fontColor: 'black',\n",
1218
+ " fontSize: 18,\n",
1219
+ " padding: 4\n",
1220
+ " }}\n",
1221
+ " );\n",
1222
+ " }}\n",
1223
+ " }},\n",
1224
+ " function(atom, viewer) {{\n",
1225
+ " if (atom.label) {{\n",
1226
+ " viewer.removeLabel(atom.label);\n",
1227
+ " delete atom.label;\n",
1228
+ " }}\n",
1229
+ " }}\n",
1230
+ " );\n",
1231
+ " \n",
1232
+ " viewer.zoomTo();\n",
1233
+ " viewer.render();\n",
1234
+ " viewer.zoom(0.8, 2000);\n",
1235
+ " }});\n",
1236
+ "\n",
1237
+ " function toggleStyle(style) {{\n",
1238
+ " let elements = viewer.selectedAtoms({{}});\n",
1239
+ " let currentStyle = elements.style[style];\n",
1240
+ " elements.style[style] = !currentStyle;\n",
1241
+ " viewer.render();\n",
1242
+ " }}\n",
1243
+ "\n",
1244
+ " function resetView() {{\n",
1245
+ " viewer.zoomTo();\n",
1246
+ " viewer.render();\n",
1247
+ " }}\n",
1248
+ "\n",
1249
+ " function toggleSpin() {{\n",
1250
+ " isSpinning = !isSpinning;\n",
1251
+ " viewer.spin(isSpinning);\n",
1252
+ " }}\n",
1253
+ " </script>\n",
1254
+ " </body>\n",
1255
+ " </html>\n",
1256
+ " \"\"\"\n",
1257
+ " \n",
1258
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
1259
+ "\n",
1260
+ "# Gradio UI\n",
1261
+ "def create_ui():\n",
1262
+ " with gr.Blocks(title=\"Protein Binding Site Prediction\", theme=gr.themes.Base()) as demo:\n",
1263
+ " gr.Markdown(\"\"\"\n",
1264
+ " # Protein Binding Site Prediction\n",
1265
+ " \n",
1266
+ " This tool helps you visualize and analyze potential binding sites in protein structures.\n",
1267
+ " You can either:\n",
1268
+ " 1. Enter a PDB ID (e.g., \"4BDU\")\n",
1269
+ " 2. Upload your own PDB/CIF file\n",
1270
+ " \n",
1271
+ " The tool will analyze the structure and show predictions using a color gradient from white (low probability) to red (high probability).\n",
1272
+ " \"\"\")\n",
1273
+ " \n",
1274
+ " with gr.Row():\n",
1275
+ " with gr.Column(scale=2):\n",
1276
+ " # Input components\n",
1277
+ " mode = gr.Radio(\n",
1278
+ " choices=[\"PDB ID\", \"Upload File\"],\n",
1279
+ " value=\"PDB ID\",\n",
1280
+ " label=\"Input Mode\",\n",
1281
+ " info=\"Choose whether to input a PDB ID or upload a PDB/CIF file\"\n",
1282
+ " )\n",
1283
+ " \n",
1284
+ " with gr.Group():\n",
1285
+ " pdb_input = gr.Textbox(\n",
1286
+ " value=\"4BDU\",\n",
1287
+ " label=\"PDB ID\",\n",
1288
+ " placeholder=\"Enter PDB ID (e.g., 4BDU)\",\n",
1289
+ " info=\"Enter a valid PDB ID from the Protein Data Bank\"\n",
1290
+ " )\n",
1291
+ " pdb_file = gr.File(\n",
1292
+ " label=\"Upload PDB/CIF File\",\n",
1293
+ " file_types=[\".pdb\", \".cif\"],\n",
1294
+ " visible=False\n",
1295
+ " )\n",
1296
+ " \n",
1297
+ " segment_input = gr.Textbox(\n",
1298
+ " value=\"A\",\n",
1299
+ " label=\"Chain ID\",\n",
1300
+ " placeholder=\"Enter Chain ID or leave empty for automatic selection\",\n",
1301
+ " info=\"Specify which protein chain to analyze, or leave empty for automatic selection\"\n",
1302
+ " )\n",
1303
+ "\n",
1304
+ " with gr.Column(scale=1):\n",
1305
+ " visualize_btn = gr.Button(\"Visualize Structure\", variant=\"primary\")\n",
1306
+ " prediction_btn = gr.Button(\"Predict Binding Site\", variant=\"secondary\")\n",
1307
+ " \n",
1308
+ " gr.Markdown(\"\"\"\n",
1309
+ " ### Color Legend\n",
1310
+ " - White: Very Low (0.0-0.2)\n",
1311
+ " - Light Orange: Low (0.2-0.4)\n",
1312
+ " - Orange: Medium (0.4-0.6)\n",
1313
+ " - Orange Red: High (0.6-0.8)\n",
1314
+ " - Red: Very High (0.8-1.0)\n",
1315
+ " \"\"\")\n",
1316
+ "\n",
1317
+ " with gr.Tab(\"3D Visualization\"):\n",
1318
+ " molecule_output = gr.HTML(label=\"Interactive 3D Structure\")\n",
1319
+ " \n",
1320
+ " with gr.Tab(\"Analysis Results\"):\n",
1321
+ " predictions_output = gr.Textbox(\n",
1322
+ " label=\"PyMOL Visualization Commands\",\n",
1323
+ " info=\"Copy these commands into PyMOL to recreate the visualization\"\n",
1324
+ " )\n",
1325
+ " download_output = gr.File(\n",
1326
+ " label=\"Download Results\",\n",
1327
+ " file_count=\"multiple\"\n",
1328
+ " )\n",
1329
+ "\n",
1330
+ " # Error message container\n",
1331
+ " error_output = gr.Markdown(visible=False)\n",
1332
+ "\n",
1333
+ " # Mode change handler\n",
1334
+ " def toggle_mode(selected_mode):\n",
1335
+ " return {\n",
1336
+ " pdb_input: gr.update(visible=selected_mode == \"PDB ID\"),\n",
1337
+ " pdb_file: gr.update(visible=selected_mode == \"Upload File\")\n",
1338
+ " }\n",
1339
+ "\n",
1340
+ " mode.change(\n",
1341
+ " toggle_mode,\n",
1342
+ " inputs=[mode],\n",
1343
+ " outputs=[pdb_input, pdb_file]\n",
1344
+ " )\n",
1345
+ "\n",
1346
+ " # Process handlers\n",
1347
+ " def handle_visualization(mode, pdb_id, pdb_file):\n",
1348
+ " try:\n",
1349
+ " result = fetch_interface(mode, pdb_id, pdb_file)\n",
1350
+ " if isinstance(result, str) and result.startswith(\"Error\"):\n",
1351
+ " return None, gr.update(visible=True, value=f\"```\\n{result}\\n```\")\n",
1352
+ " return result, gr.update(visible=False)\n",
1353
+ " except Exception as e:\n",
1354
+ " return None, gr.update(visible=True, value=f\"```\\nError: {str(e)}\\n```\")\n",
1355
+ "\n",
1356
+ " def handle_prediction(mode, pdb_id, pdb_file, chain_id):\n",
1357
+ " try:\n",
1358
+ " predictions, vis, downloads = process_interface(mode, pdb_id, pdb_file, chain_id)\n",
1359
+ " if isinstance(predictions, str) and \"Error\" in predictions:\n",
1360
+ " return (\n",
1361
+ " None,\n",
1362
+ " None,\n",
1363
+ " None,\n",
1364
+ " gr.update(visible=True, value=f\"```\\n{predictions}\\n```\")\n",
1365
+ " )\n",
1366
+ " return (\n",
1367
+ " predictions,\n",
1368
+ " vis,\n",
1369
+ " downloads,\n",
1370
+ " gr.update(visible=False)\n",
1371
+ " )\n",
1372
+ " except Exception as e:\n",
1373
+ " error_msg = f\"\"\"Error processing structure:\n",
1374
+ "```\n",
1375
+ "{str(e)}\n",
1376
+ "\n",
1377
+ "Troubleshooting tips:\n",
1378
+ "1. Check if the PDB ID is valid\n",
1379
+ "2. Ensure the Chain ID exists in the structure\n",
1380
+ "3. Try leaving Chain ID empty for automatic selection\n",
1381
+ "4. If uploading a file, ensure it's a valid PDB/CIF format\n",
1382
+ "```\"\"\"\n",
1383
+ " return None, None, None, gr.update(visible=True, value=error_msg)\n",
1384
+ "\n",
1385
+ " def fetch_interface(mode, pdb_id, pdb_file):\n",
1386
+ " try:\n",
1387
+ " if mode == \"PDB ID\":\n",
1388
+ " if not pdb_id or len(pdb_id.strip()) != 4:\n",
1389
+ " raise ValueError(\"Please enter a valid 4-character PDB ID\")\n",
1390
+ " return fetch_pdb(pdb_id.strip())\n",
1391
+ " elif mode == \"Upload File\":\n",
1392
+ " if not pdb_file:\n",
1393
+ " raise ValueError(\"Please upload a PDB or CIF file\")\n",
1394
+ " _, ext = os.path.splitext(pdb_file.name)\n",
1395
+ " if ext.lower() not in ['.pdb', '.cif']:\n",
1396
+ " raise ValueError(\"Only .pdb and .cif files are supported\")\n",
1397
+ " \n",
1398
+ " # Create temp directory for file handling\n",
1399
+ " with tempfile.TemporaryDirectory() as temp_dir:\n",
1400
+ " temp_path = os.path.join(temp_dir, os.path.basename(pdb_file.name))\n",
1401
+ " with open(temp_path, 'wb') as f:\n",
1402
+ " f.write(pdb_file.read())\n",
1403
+ " \n",
1404
+ " if ext.lower() == '.cif':\n",
1405
+ " return convert_cif_to_pdb(temp_path)\n",
1406
+ " return temp_path\n",
1407
+ " else:\n",
1408
+ " raise ValueError(\"Invalid mode selected\")\n",
1409
+ " except Exception as e:\n",
1410
+ " return f\"Error: {str(e)}\"\n",
1411
+ "\n",
1412
+ " # Connect event handlers\n",
1413
+ " visualize_btn.click(\n",
1414
+ " handle_visualization,\n",
1415
+ " inputs=[mode, pdb_input, pdb_file],\n",
1416
+ " outputs=[molecule_output, error_output]\n",
1417
+ " )\n",
1418
+ "\n",
1419
+ " prediction_btn.click(\n",
1420
+ " handle_prediction,\n",
1421
+ " inputs=[mode, pdb_input, pdb_file, segment_input],\n",
1422
+ " outputs=[predictions_output, molecule_output, download_output, error_output]\n",
1423
+ " )\n",
1424
+ "\n",
1425
+ " # Add examples\n",
1426
+ " gr.Examples(\n",
1427
+ " examples=[\n",
1428
+ " [\"PDB ID\", \"7RPZ\", None, \"A\"],\n",
1429
+ " [\"PDB ID\", \"2IWI\", None, \"B\"],\n",
1430
+ " [\"PDB ID\", \"2F6V\", None, \"A\"]\n",
1431
+ " ],\n",
1432
+ " inputs=[mode, pdb_input, pdb_file, segment_input],\n",
1433
+ " outputs=[predictions_output, molecule_output, download_output, error_output],\n",
1434
+ " fn=handle_prediction,\n",
1435
+ " cache_examples=True\n",
1436
+ " )\n",
1437
+ "\n",
1438
+ " # Add documentation\n",
1439
+ " gr.Markdown(\"\"\"\n",
1440
+ " ## Usage Instructions\n",
1441
+ " \n",
1442
+ " 1. **Input Structure:**\n",
1443
+ " - Enter a PDB ID (e.g., \"4BDU\") or upload your own structure file\n",
1444
+ " - The tool supports both PDB (.pdb) and mmCIF (.cif) formats\n",
1445
+ " \n",
1446
+ " 2. **Select Chain:**\n",
1447
+ " - Enter a specific chain ID (e.g., \"A\")\n",
1448
+ " - Leave empty for automatic selection of the first valid protein chain\n",
1449
+ " \n",
1450
+ " 3. **Analyze:**\n",
1451
+ " - Click \"Visualize Structure\" to view the 3D structure\n",
1452
+ " - Click \"Predict Binding Site\" to perform binding site analysis\n",
1453
+ " \n",
1454
+ " 4. **Results:**\n",
1455
+ " - Interactive 3D visualization with color-coded predictions\n",
1456
+ " - PyMOL commands for external visualization\n",
1457
+ " - Downloadable results files\n",
1458
+ " \n",
1459
+ " ## Troubleshooting\n",
1460
+ " \n",
1461
+ " If you encounter issues:\n",
1462
+ " 1. Ensure your PDB ID is valid and exists in the PDB database\n",
1463
+ " 2. Check that your uploaded file is a valid PDB/CIF format\n",
1464
+ " 3. Try automatic chain selection if your specified chain isn't found\n",
1465
+ " 4. Clear your browser cache if visualizations don't load\n",
1466
+ " \"\"\")\n",
1467
+ "\n",
1468
+ " return demo\n",
1469
+ "\n",
1470
+ "if __name__ == \"__main__\":\n",
1471
+ " demo = create_ui()\n",
1472
+ " demo.launch(share=True)"
1473
+ ]
1474
+ },
1475
+ {
1476
+ "cell_type": "code",
1477
+ "execution_count": null,
1478
+ "id": "9125d1c4-e2ae-4e40-ba36-7ae944512b8e",
1479
+ "metadata": {},
1480
+ "outputs": [],
1481
+ "source": []
1482
+ },
1483
+ {
1484
+ "cell_type": "code",
1485
+ "execution_count": null,
1486
+ "id": "85c0728a-a15b-4118-b920-5f55a2f5f79a",
1487
+ "metadata": {},
1488
+ "outputs": [],
1489
+ "source": []
1490
+ }
1491
+ ],
1492
+ "metadata": {
1493
+ "kernelspec": {
1494
+ "display_name": "Python (LLM)",
1495
+ "language": "python",
1496
+ "name": "llm"
1497
+ },
1498
+ "language_info": {
1499
+ "codemirror_mode": {
1500
+ "name": "ipython",
1501
+ "version": 3
1502
+ },
1503
+ "file_extension": ".py",
1504
+ "mimetype": "text/x-python",
1505
+ "name": "python",
1506
+ "nbconvert_exporter": "python",
1507
+ "pygments_lexer": "ipython3",
1508
+ "version": "3.12.2"
1509
+ }
1510
+ },
1511
+ "nbformat": 4,
1512
+ "nbformat_minor": 5
1513
+ }