ThorbenFroehlking commited on
Commit
c85a5b0
·
1 Parent(s): 8bef2d8
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -1,4 +1,11 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
  from model_loader import load_model
3
 
4
  import torch
@@ -7,8 +14,6 @@ import torch.nn.functional as F
7
  from torch.utils.data import DataLoader
8
 
9
  import re
10
- import numpy as np
11
- import os
12
  import pandas as pd
13
  import copy
14
 
@@ -20,18 +25,6 @@ from datasets import Dataset
20
 
21
  from scipy.special import expit
22
 
23
- import requests
24
-
25
- from gradio_molecule3d import Molecule3D
26
-
27
- # Biopython imports
28
- from Bio.PDB import PDBParser, Select, PDBIO
29
- from Bio.PDB.DSSP import DSSP
30
- from Bio.PDB import PDBList
31
-
32
- from matplotlib import cm # For color mapping
33
- from matplotlib.colors import Normalize
34
-
35
  # Load model and move to device
36
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
37
  max_length = 1500
@@ -40,23 +33,26 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
  model.to(device)
41
  model.eval()
42
 
43
- # Function to fetch a PDB file
 
 
 
 
 
 
 
 
 
44
  def fetch_pdb(pdb_id):
45
  pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
46
- pdb_path = f'pdb_files/{pdb_id}.pdb'
47
- os.makedirs('pdb_files', exist_ok=True)
48
  response = requests.get(pdb_url)
49
  if response.status_code == 200:
50
  with open(pdb_path, 'wb') as f:
51
  f.write(response.content)
52
  return pdb_path
53
- return None
54
-
55
-
56
- def normalize_scores(scores):
57
- min_score = np.min(scores)
58
- max_score = np.max(scores)
59
- return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
60
 
61
  def process_pdb(pdb_id, segment):
62
  pdb_path = fetch_pdb(pdb_id)
@@ -65,7 +61,11 @@ def process_pdb(pdb_id, segment):
65
 
66
  parser = PDBParser(QUIET=1)
67
  structure = parser.get_structure('protein', pdb_path)
68
- chain = structure[0][segment]
 
 
 
 
69
 
70
  # Comprehensive amino acid mapping
71
  aa_dict = {
@@ -77,11 +77,10 @@ def process_pdb(pdb_id, segment):
77
  }
78
 
79
  # Exclude non-amino acid residues
80
- sequence = "".join(
81
- aa_dict[residue.get_resname().strip()]
82
- for residue in chain
83
  if residue.get_resname().strip() in aa_dict
84
- )
85
 
86
  # Prepare input for model prediction
87
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
@@ -91,54 +90,166 @@ def process_pdb(pdb_id, segment):
91
  # Calculate scores and normalize them
92
  scores = expit(outputs[:, 1] - outputs[:, 0])
93
  normalized_scores = normalize_scores(scores)
 
 
 
 
 
94
 
95
- # Prepare the result string, including only amino acid residues
96
- result_str = "\n".join([
97
- f"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}"
98
- for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict
99
- ])
100
-
101
- # Save predictions to file
102
- with open(f"{pdb_id}_predictions.txt", "w") as f:
103
  f.write(result_str)
104
 
105
- return result_str, pdb_path, f"{pdb_id}_predictions.txt"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- reps = [{"model": 0, "style": "cartoon", "color": "spectrum"}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  # Gradio UI
110
  with gr.Blocks() as demo:
111
- gr.Markdown("# Protein Binding Site Prediction")
 
 
 
 
 
112
 
113
  with gr.Row():
114
- pdb_input = gr.Textbox(value="2IWI",
115
- label="PDB ID",
116
- placeholder="Enter PDB ID here...")
117
- segment_input = gr.Textbox(value="A",
118
- label="Chain ID (Segment)",
119
- placeholder="Enter Chain ID here...")
120
- visualize_btn = gr.Button("Visualize Sructure")
121
- prediction_btn = gr.Button("Predict Ligand Binding Site")
122
-
123
- molecule_output = Molecule3D(label="Protein Structure", reps=reps)
124
  predictions_output = gr.Textbox(label="Binding Site Predictions")
125
  download_output = gr.File(label="Download Predictions")
126
-
127
- visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)
128
- prediction_btn.click(
129
- process_pdb,
130
- inputs=[pdb_input, segment_input],
131
- outputs=[predictions_output, molecule_output, download_output]
132
- )
133
-
134
  gr.Markdown("## Examples")
135
  gr.Examples(
136
  examples=[
137
- ["2IWI"],
138
- ["7RPZ"],
139
- ["3TJN"]
140
  ],
141
- inputs=[pdb_input, segment_input],
142
  outputs=[predictions_output, molecule_output, download_output]
143
  )
144
 
 
1
  import gradio as gr
2
+ import requests
3
+ from Bio.PDB import PDBParser
4
+ import numpy as np
5
+ import os
6
+ from gradio_molecule3d import Molecule3D
7
+
8
+
9
  from model_loader import load_model
10
 
11
  import torch
 
14
  from torch.utils.data import DataLoader
15
 
16
  import re
 
 
17
  import pandas as pd
18
  import copy
19
 
 
25
 
26
  from scipy.special import expit
27
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # Load model and move to device
29
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
30
  max_length = 1500
 
33
  model.to(device)
34
  model.eval()
35
 
36
+ def normalize_scores(scores):
37
+ min_score = np.min(scores)
38
+ max_score = np.max(scores)
39
+ return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
40
+
41
+ def read_mol(pdb_path):
42
+ """Read PDB file and return its content as a string"""
43
+ with open(pdb_path, 'r') as f:
44
+ return f.read()
45
+
46
  def fetch_pdb(pdb_id):
47
  pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
48
+ pdb_path = f'{pdb_id}.pdb'
 
49
  response = requests.get(pdb_url)
50
  if response.status_code == 200:
51
  with open(pdb_path, 'wb') as f:
52
  f.write(response.content)
53
  return pdb_path
54
+ else:
55
+ return None
 
 
 
 
 
56
 
57
  def process_pdb(pdb_id, segment):
58
  pdb_path = fetch_pdb(pdb_id)
 
61
 
62
  parser = PDBParser(QUIET=1)
63
  structure = parser.get_structure('protein', pdb_path)
64
+
65
+ try:
66
+ chain = structure[0][segment]
67
+ except KeyError:
68
+ return "Invalid Chain ID", None, None
69
 
70
  # Comprehensive amino acid mapping
71
  aa_dict = {
 
77
  }
78
 
79
  # Exclude non-amino acid residues
80
+ sequence = [
81
+ residue for residue in chain
 
82
  if residue.get_resname().strip() in aa_dict
83
+ ]
84
 
85
  # Prepare input for model prediction
86
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
 
90
  # Calculate scores and normalize them
91
  scores = expit(outputs[:, 1] - outputs[:, 0])
92
  normalized_scores = normalize_scores(scores)
93
+
94
+ result_str = "\n".join(
95
+ f"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}"
96
+ for res, score in zip(sequence, normalized_scores)
97
+ )
98
 
99
+ # Save the predictions to a file
100
+ prediction_file = f"{pdb_id}_predictions.txt"
101
+ with open(prediction_file, "w") as f:
 
 
 
 
 
102
  f.write(result_str)
103
 
104
+ return result_str, molecule(pdb_path, random_scores, segment), prediction_file
105
+
106
+ def molecule(input_pdb, scores=None, segment='A'):
107
+ mol = read_mol(input_pdb) # Read PDB file content
108
+
109
+ # Prepare high-scoring residues script if scores are provided
110
+ high_score_script = ""
111
+ if scores is not None:
112
+ high_score_script = """
113
+ // Reset all styles first
114
+ viewer.getModel(0).setStyle({}, {});
115
+
116
+ // Show only the selected chain
117
+ viewer.getModel(0).setStyle(
118
+ {"chain": "%s"},
119
+ { cartoon: {colorscheme:"whiteCarbon"} }
120
+ );
121
+
122
+ // Highlight high-scoring residues only for the selected chain
123
+ let highScoreResidues = [%s];
124
+ viewer.getModel(0).setStyle(
125
+ {"chain": "%s", "resi": highScoreResidues},
126
+ {"stick": {"color": "red"}}
127
+ );
128
 
129
+ // Highlight high-scoring residues only for the selected chain
130
+ let highScoreResidues2 = [%s];
131
+ viewer.getModel(0).setStyle(
132
+ {"chain": "%s", "resi": highScoreResidues2},
133
+ {"stick": {"color": "orange"}}
134
+ );
135
+ """ % (segment,
136
+ ", ".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),
137
+ segment,
138
+ ", ".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),
139
+ segment)
140
+
141
+ html_content = f"""
142
+ <!DOCTYPE html>
143
+ <html>
144
+ <head>
145
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
146
+ <style>
147
+ .mol-container {{
148
+ width: 100%;
149
+ height: 700px;
150
+ position: relative;
151
+ }}
152
+ </style>
153
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js"></script>
154
+ <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
155
+ </head>
156
+ <body>
157
+ <div id="container" class="mol-container"></div>
158
+ <script>
159
+ let pdb = `{mol}`; // Use template literal to properly escape PDB content
160
+ $(document).ready(function () {{
161
+ let element = $("#container");
162
+ let config = {{ backgroundColor: "white" }};
163
+ let viewer = $3Dmol.createViewer(element, config);
164
+ viewer.addModel(pdb, "pdb");
165
+
166
+ // Reset all styles and show only selected chain
167
+ viewer.getModel(0).setStyle(
168
+ {{"chain": "{segment}"}},
169
+ {{ cartoon: {{ colorscheme:"whiteCarbon" }} }}
170
+ );
171
+
172
+ {high_score_script}
173
+
174
+ // Add hover functionality
175
+ viewer.setHoverable(
176
+ {{}},
177
+ true,
178
+ function(atom, viewer, event, container) {{
179
+ if (!atom.label) {{
180
+ atom.label = viewer.addLabel(
181
+ atom.resn + ":" + atom.atom,
182
+ {{
183
+ position: atom,
184
+ backgroundColor: 'mintcream',
185
+ fontColor: 'black',
186
+ fontSize: 12,
187
+ padding: 2
188
+ }}
189
+ );
190
+ }}
191
+ }},
192
+ function(atom, viewer) {{
193
+ if (atom.label) {{
194
+ viewer.removeLabel(atom.label);
195
+ delete atom.label;
196
+ }}
197
+ }}
198
+ );
199
+
200
+ viewer.zoomTo();
201
+ viewer.render();
202
+ viewer.zoom(0.8, 2000);
203
+ }});
204
+ </script>
205
+ </body>
206
+ </html>
207
+ """
208
+
209
+ # Return the HTML content within an iframe safely encoded for special characters
210
+ return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
211
+
212
+ reps = [
213
+ {
214
+ "model": 0,
215
+ "style": "cartoon",
216
+ "color": "whiteCarbon",
217
+ "residue_range": "",
218
+ "around": 0,
219
+ "byres": False,
220
+ }
221
+ ]
222
 
223
  # Gradio UI
224
  with gr.Blocks() as demo:
225
+ gr.Markdown("# Protein Binding Site Prediction (Random Scores)")
226
+ with gr.Row():
227
+ pdb_input = gr.Textbox(value="2IWI", label="PDB ID", placeholder="Enter PDB ID here...")
228
+ visualize_btn = gr.Button("Visualize Structure")
229
+
230
+ molecule_output2 = Molecule3D(label="Protein Structure", reps=reps)
231
 
232
  with gr.Row():
233
+ pdb_input = gr.Textbox(value="2IWI", label="PDB ID", placeholder="Enter PDB ID here...")
234
+ segment_input = gr.Textbox(value="A", label="Chain ID", placeholder="Enter Chain ID here...")
235
+ prediction_btn = gr.Button("Predict Random Binding Site Scores")
236
+
237
+ molecule_output = gr.HTML(label="Protein Structure")
 
 
 
 
 
238
  predictions_output = gr.Textbox(label="Binding Site Predictions")
239
  download_output = gr.File(label="Download Predictions")
240
+
241
+ visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)
242
+
243
+ prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])
244
+
 
 
 
245
  gr.Markdown("## Examples")
246
  gr.Examples(
247
  examples=[
248
+ ["2IWI", "A"],
249
+ ["7RPZ", "B"],
250
+ ["3TJN", "C"]
251
  ],
252
+ inputs=[pdb_input, segment_input],
253
  outputs=[predictions_output, molecule_output, download_output]
254
  )
255
 
.ipynb_checkpoints/requirements-checkpoint.txt CHANGED
@@ -10,5 +10,4 @@ sentencepiece
10
  huggingface_hub>=0.15.0
11
  requests
12
  gradio_molecule3d
13
- biopython>=1.81
14
- matplotlib
 
10
  huggingface_hub>=0.15.0
11
  requests
12
  gradio_molecule3d
13
+ biopython>=1.81
 
.ipynb_checkpoints/test-checkpoint.ipynb ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 5,
6
+ "id": "d2208d17-47b6-4ff1-b6b6-ba09a9d490c7",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "* Running on local URL: http://127.0.0.1:7864\n",
14
+ "\n",
15
+ "To create a public link, set `share=True` in `launch()`.\n"
16
+ ]
17
+ },
18
+ {
19
+ "data": {
20
+ "text/html": [
21
+ "<div><iframe src=\"http://127.0.0.1:7864/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
22
+ ],
23
+ "text/plain": [
24
+ "<IPython.core.display.HTML object>"
25
+ ]
26
+ },
27
+ "metadata": {},
28
+ "output_type": "display_data"
29
+ },
30
+ {
31
+ "data": {
32
+ "text/plain": []
33
+ },
34
+ "execution_count": 5,
35
+ "metadata": {},
36
+ "output_type": "execute_result"
37
+ }
38
+ ],
39
+ "source": [
40
+ "import gradio as gr\n",
41
+ "import requests\n",
42
+ "from Bio.PDB import PDBParser\n",
43
+ "from gradio_molecule3d import Molecule3D\n",
44
+ "import numpy as np\n",
45
+ "\n",
46
+ "# Function to fetch a PDB file from RCSB PDB\n",
47
+ "def fetch_pdb(pdb_id):\n",
48
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
49
+ " pdb_path = f'{pdb_id}.pdb'\n",
50
+ " response = requests.get(pdb_url)\n",
51
+ " if response.status_code == 200:\n",
52
+ " with open(pdb_path, 'wb') as f:\n",
53
+ " f.write(response.content)\n",
54
+ " return pdb_path\n",
55
+ " else:\n",
56
+ " return None\n",
57
+ "\n",
58
+ "# Function to process the PDB file and return random predictions\n",
59
+ "def process_pdb(pdb_id, segment):\n",
60
+ " pdb_path = fetch_pdb(pdb_id)\n",
61
+ " if not pdb_path:\n",
62
+ " return \"Failed to fetch PDB file\", None, None\n",
63
+ "\n",
64
+ " parser = PDBParser(QUIET=True)\n",
65
+ " structure = parser.get_structure('protein', pdb_path)\n",
66
+ " \n",
67
+ " try:\n",
68
+ " chain = structure[0][segment]\n",
69
+ " except KeyError:\n",
70
+ " return \"Invalid Chain ID\", None, None\n",
71
+ "\n",
72
+ " sequence = [residue.get_resname() for residue in chain if residue.id[0] == ' ']\n",
73
+ " random_scores = np.random.rand(len(sequence))\n",
74
+ "\n",
75
+ " result_str = \"\\n\".join(\n",
76
+ " f\"{seq} {res.id[1]} {score:.2f}\" \n",
77
+ " for seq, res, score in zip(sequence, chain, random_scores)\n",
78
+ " )\n",
79
+ "\n",
80
+ " # Save the predictions to a file\n",
81
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
82
+ " with open(prediction_file, \"w\") as f:\n",
83
+ " f.write(result_str)\n",
84
+ " \n",
85
+ " return result_str, pdb_path, prediction_file\n",
86
+ "\n",
87
+ "#reps = [{\"model\": 0, \"style\": \"cartoon\", \"color\": \"spectrum\"}]\n",
88
+ "\n",
89
+ "reps = [\n",
90
+ " {\n",
91
+ " \"model\": 0,\n",
92
+ " \"style\": \"cartoon\",\n",
93
+ " \"color\": \"whiteCarbon\",\n",
94
+ " \"residue_range\": \"\",\n",
95
+ " \"around\": 0,\n",
96
+ " \"byres\": False,\n",
97
+ " },\n",
98
+ " {\n",
99
+ " \"model\": 0,\n",
100
+ " \"chain\": \"A\",\n",
101
+ " \"resname\": \"HIS\",\n",
102
+ " \"style\": \"stick\",\n",
103
+ " \"color\": \"red\"\n",
104
+ " }\n",
105
+ " ]\n",
106
+ "\n",
107
+ "\n",
108
+ "# Gradio UI\n",
109
+ "with gr.Blocks() as demo:\n",
110
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
111
+ "\n",
112
+ " with gr.Row():\n",
113
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
114
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
115
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
116
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
117
+ "\n",
118
+ " molecule_output = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
119
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
120
+ " download_output = gr.File(label=\"Download Predictions\")\n",
121
+ "\n",
122
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)\n",
123
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
124
+ "\n",
125
+ " gr.Markdown(\"## Examples\")\n",
126
+ " gr.Examples(\n",
127
+ " examples=[\n",
128
+ " [\"2IWI\", \"A\"],\n",
129
+ " [\"7RPZ\", \"B\"],\n",
130
+ " [\"3TJN\", \"C\"]\n",
131
+ " ],\n",
132
+ " inputs=[pdb_input, segment_input],\n",
133
+ " outputs=[predictions_output, molecule_output, download_output]\n",
134
+ " )\n",
135
+ "\n",
136
+ "demo.launch()"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "id": "bd50ff2e-ed03-498e-8af2-73c0fb8ea07e",
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": []
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": 4,
150
+ "id": "a1088e14-f09c-48ff-8632-cc4685306d7c",
151
+ "metadata": {},
152
+ "outputs": [
153
+ {
154
+ "name": "stdout",
155
+ "output_type": "stream",
156
+ "text": [
157
+ "* Running on local URL: http://127.0.0.1:7863\n",
158
+ "\n",
159
+ "To create a public link, set `share=True` in `launch()`.\n"
160
+ ]
161
+ },
162
+ {
163
+ "data": {
164
+ "text/html": [
165
+ "<div><iframe src=\"http://127.0.0.1:7863/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
166
+ ],
167
+ "text/plain": [
168
+ "<IPython.core.display.HTML object>"
169
+ ]
170
+ },
171
+ "metadata": {},
172
+ "output_type": "display_data"
173
+ }
174
+ ],
175
+ "source": [
176
+ "import gradio as gr\n",
177
+ "from gradio_molecule3d import Molecule3D\n",
178
+ "\n",
179
+ "\n",
180
+ "example = Molecule3D().example_value()\n",
181
+ "\n",
182
+ "\n",
183
+ "reps = [\n",
184
+ " {\n",
185
+ " \"model\": 0,\n",
186
+ " \"style\": \"cartoon\",\n",
187
+ " \"color\": \"whiteCarbon\",\n",
188
+ " \"residue_range\": \"\",\n",
189
+ " \"around\": 0,\n",
190
+ " \"byres\": False,\n",
191
+ " },\n",
192
+ " {\n",
193
+ " \"model\": 0,\n",
194
+ " \"chain\": \"A\",\n",
195
+ " \"resname\": \"HIS\",\n",
196
+ " \"style\": \"stick\",\n",
197
+ " \"color\": \"red\"\n",
198
+ " }\n",
199
+ " ]\n",
200
+ "\n",
201
+ "\n",
202
+ "\n",
203
+ "def predict(x):\n",
204
+ " print(\"predict function\", x)\n",
205
+ " print(x.name)\n",
206
+ " return x\n",
207
+ "\n",
208
+ "with gr.Blocks() as demo:\n",
209
+ " gr.Markdown(\"# Molecule3D\")\n",
210
+ " inp = Molecule3D(label=\"Molecule3D\", reps=reps)\n",
211
+ " out = Molecule3D(label=\"Output\", reps=reps)\n",
212
+ "\n",
213
+ " btn = gr.Button(\"Predict\")\n",
214
+ " gr.Markdown(\"\"\" \n",
215
+ " You can configure the default rendering of the molecule by adding a list of representations\n",
216
+ " <pre>\n",
217
+ " reps = [\n",
218
+ " {\n",
219
+ " \"model\": 0,\n",
220
+ " \"style\": \"cartoon\",\n",
221
+ " \"color\": \"whiteCarbon\",\n",
222
+ " \"residue_range\": \"\",\n",
223
+ " \"around\": 0,\n",
224
+ " \"byres\": False,\n",
225
+ " },\n",
226
+ " {\n",
227
+ " \"model\": 0,\n",
228
+ " \"chain\": \"A\",\n",
229
+ " \"resname\": \"HIS\",\n",
230
+ " \"style\": \"stick\",\n",
231
+ " \"color\": \"red\"\n",
232
+ " }\n",
233
+ " ]\n",
234
+ " </pre>\n",
235
+ " \"\"\")\n",
236
+ " btn.click(predict, inputs=inp, outputs=out)\n",
237
+ "\n",
238
+ "\n",
239
+ "if __name__ == \"__main__\":\n",
240
+ " demo.launch()"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": null,
246
+ "id": "d27cc368-26a0-42c2-a68a-8833de7bb4a0",
247
+ "metadata": {},
248
+ "outputs": [],
249
+ "source": []
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": 8,
254
+ "id": "cdf7fd26-0464-40d9-9107-71c29dbcaef8",
255
+ "metadata": {},
256
+ "outputs": [
257
+ {
258
+ "name": "stdout",
259
+ "output_type": "stream",
260
+ "text": [
261
+ "* Running on local URL: http://127.0.0.1:7867\n",
262
+ "\n",
263
+ "To create a public link, set `share=True` in `launch()`.\n"
264
+ ]
265
+ },
266
+ {
267
+ "data": {
268
+ "text/html": [
269
+ "<div><iframe src=\"http://127.0.0.1:7867/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
270
+ ],
271
+ "text/plain": [
272
+ "<IPython.core.display.HTML object>"
273
+ ]
274
+ },
275
+ "metadata": {},
276
+ "output_type": "display_data"
277
+ },
278
+ {
279
+ "data": {
280
+ "text/plain": []
281
+ },
282
+ "execution_count": 8,
283
+ "metadata": {},
284
+ "output_type": "execute_result"
285
+ },
286
+ {
287
+ "name": "stderr",
288
+ "output_type": "stream",
289
+ "text": [
290
+ "/var/folders/tm/ym2tckv54b96ws82y3b7cqhh0000gn/T/ipykernel_11794/4072855226.py:39: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.\n",
291
+ " colors = [cm.get_cmap('coolwarm')(score)[:3] for score in normalized_scores]\n",
292
+ "Traceback (most recent call last):\n",
293
+ " File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/queueing.py\", line 622, in process_events\n",
294
+ " response = await route_utils.call_process_api(\n",
295
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
296
+ " File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/route_utils.py\", line 323, in call_process_api\n",
297
+ " output = await app.get_blocks().process_api(\n",
298
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
299
+ " File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 2024, in process_api\n",
300
+ " data = await self.postprocess_data(block_fn, result[\"prediction\"], state)\n",
301
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
302
+ " File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio/blocks.py\", line 1830, in postprocess_data\n",
303
+ " prediction_value = block.postprocess(prediction_value)\n",
304
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
305
+ " File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/site-packages/gradio_molecule3d/molecule3d.py\", line 210, in postprocess\n",
306
+ " orig_name=Path(file).name,\n",
307
+ " ^^^^^^^^^^\n",
308
+ " File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/pathlib.py\", line 1162, in __init__\n",
309
+ " super().__init__(*args)\n",
310
+ " File \"/Users/thorben_froehlking/anaconda3/envs/LLM/lib/python3.12/pathlib.py\", line 373, in __init__\n",
311
+ " raise TypeError(\n",
312
+ "TypeError: argument should be a str or an os.PathLike object where __fspath__ returns a str, not 'dict'\n"
313
+ ]
314
+ }
315
+ ],
316
+ "source": [
317
+ "import gradio as gr\n",
318
+ "import requests\n",
319
+ "from Bio.PDB import PDBParser\n",
320
+ "from gradio_molecule3d import Molecule3D\n",
321
+ "import numpy as np\n",
322
+ "from matplotlib import cm\n",
323
+ "\n",
324
+ "# Function to fetch a PDB file from RCSB PDB\n",
325
+ "def fetch_pdb(pdb_id):\n",
326
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
327
+ " pdb_path = f'{pdb_id}.pdb'\n",
328
+ " response = requests.get(pdb_url)\n",
329
+ " if response.status_code == 200:\n",
330
+ " with open(pdb_path, 'wb') as f:\n",
331
+ " f.write(response.content)\n",
332
+ " return pdb_path\n",
333
+ " else:\n",
334
+ " return None\n",
335
+ "\n",
336
+ "# Function to process the PDB file and return random predictions\n",
337
+ "def process_pdb(pdb_id, segment):\n",
338
+ " pdb_path = fetch_pdb(pdb_id)\n",
339
+ " if not pdb_path:\n",
340
+ " return \"Failed to fetch PDB file\", None, None, None\n",
341
+ "\n",
342
+ " parser = PDBParser(QUIET=True)\n",
343
+ " structure = parser.get_structure('protein', pdb_path)\n",
344
+ "\n",
345
+ " try:\n",
346
+ " chain = structure[0][segment]\n",
347
+ " except KeyError:\n",
348
+ " return \"Invalid Chain ID\", None, None, None\n",
349
+ "\n",
350
+ " sequence = [residue.get_resname() for residue in chain if residue.id[0] == ' ']\n",
351
+ " random_scores = np.random.rand(len(sequence))\n",
352
+ "\n",
353
+ " # Normalize scores for coloring (0 = blue, 1 = red)\n",
354
+ " normalized_scores = (random_scores - np.min(random_scores)) / (np.max(random_scores) - np.min(random_scores))\n",
355
+ " colors = [cm.get_cmap('coolwarm')(score)[:3] for score in normalized_scores]\n",
356
+ " hex_colors = [f'#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}' for r, g, b in colors]\n",
357
+ "\n",
358
+ " # Result string and representation\n",
359
+ " result_str = \"\\n\".join(\n",
360
+ " f\"{seq} {res.id[1]} {score:.2f}\" \n",
361
+ " for seq, res, score in zip(sequence, chain, random_scores)\n",
362
+ " )\n",
363
+ "\n",
364
+ " # Representation for the protein structure\n",
365
+ " reps = [\n",
366
+ " {\n",
367
+ " \"model\": 0,\n",
368
+ " \"style\": \"cartoon\",\n",
369
+ " \"color\": \"whiteCarbon\"\n",
370
+ " }\n",
371
+ " ] + [\n",
372
+ " {\n",
373
+ " \"model\": 0,\n",
374
+ " \"style\": \"cartoon\",\n",
375
+ " \"residue_index\": i,\n",
376
+ " \"color\": color\n",
377
+ " }\n",
378
+ " for i, color in enumerate(hex_colors)\n",
379
+ " ]\n",
380
+ "\n",
381
+ " # Save the predictions to a file\n",
382
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
383
+ " with open(prediction_file, \"w\") as f:\n",
384
+ " f.write(result_str)\n",
385
+ " \n",
386
+ " return result_str, reps, prediction_file\n",
387
+ "\n",
388
+ "# Gradio UI\n",
389
+ "with gr.Blocks() as demo:\n",
390
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
391
+ "\n",
392
+ " with gr.Row():\n",
393
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
394
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
395
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
396
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
397
+ "\n",
398
+ " molecule_output = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
399
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
400
+ " download_output = gr.File(label=\"Download Predictions\")\n",
401
+ "\n",
402
+ " prediction_btn.click(\n",
403
+ " fn=process_pdb,\n",
404
+ " inputs=[pdb_input, segment_input],\n",
405
+ " outputs=[predictions_output, molecule_output, download_output]\n",
406
+ " )\n",
407
+ "\n",
408
+ " gr.Markdown(\"## Examples\")\n",
409
+ " gr.Examples(\n",
410
+ " examples=[\n",
411
+ " [\"2IWI\", \"A\"],\n",
412
+ " [\"7RPZ\", \"B\"],\n",
413
+ " [\"3TJN\", \"C\"]\n",
414
+ " ],\n",
415
+ " inputs=[pdb_input, segment_input],\n",
416
+ " outputs=[predictions_output, molecule_output, download_output]\n",
417
+ " )\n",
418
+ "\n",
419
+ "demo.launch()"
420
+ ]
421
+ },
422
+ {
423
+ "cell_type": "code",
424
+ "execution_count": null,
425
+ "id": "ee215c16-a1fb-450f-bb93-37aaee6fb3f1",
426
+ "metadata": {},
427
+ "outputs": [],
428
+ "source": []
429
+ }
430
+ ],
431
+ "metadata": {
432
+ "kernelspec": {
433
+ "display_name": "Python (LLM)",
434
+ "language": "python",
435
+ "name": "llm"
436
+ },
437
+ "language_info": {
438
+ "codemirror_mode": {
439
+ "name": "ipython",
440
+ "version": 3
441
+ },
442
+ "file_extension": ".py",
443
+ "mimetype": "text/x-python",
444
+ "name": "python",
445
+ "nbconvert_exporter": "python",
446
+ "pygments_lexer": "ipython3",
447
+ "version": "3.12.7"
448
+ }
449
+ },
450
+ "nbformat": 4,
451
+ "nbformat_minor": 5
452
+ }
.ipynb_checkpoints/test2-checkpoint.ipynb ADDED
@@ -0,0 +1,1193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "f3b7f6b0-6685-4a5c-9529-45e0ca905a3b",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "* Running on local URL: http://127.0.0.1:7860\n",
14
+ "\n",
15
+ "To create a public link, set `share=True` in `launch()`.\n"
16
+ ]
17
+ },
18
+ {
19
+ "data": {
20
+ "text/html": [
21
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
22
+ ],
23
+ "text/plain": [
24
+ "<IPython.core.display.HTML object>"
25
+ ]
26
+ },
27
+ "metadata": {},
28
+ "output_type": "display_data"
29
+ },
30
+ {
31
+ "data": {
32
+ "text/plain": []
33
+ },
34
+ "execution_count": 2,
35
+ "metadata": {},
36
+ "output_type": "execute_result"
37
+ }
38
+ ],
39
+ "source": [
40
+ "import gradio as gr\n",
41
+ "import requests\n",
42
+ "from Bio.PDB import PDBParser\n",
43
+ "import numpy as np\n",
44
+ "import os\n",
45
+ "from gradio_molecule3d import Molecule3D\n",
46
+ "\n",
47
+ "def read_mol(pdb_path):\n",
48
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
49
+ " with open(pdb_path, 'r') as f:\n",
50
+ " return f.read()\n",
51
+ "\n",
52
+ "def fetch_pdb(pdb_id):\n",
53
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
54
+ " pdb_path = f'{pdb_id}.pdb'\n",
55
+ " response = requests.get(pdb_url)\n",
56
+ " if response.status_code == 200:\n",
57
+ " with open(pdb_path, 'wb') as f:\n",
58
+ " f.write(response.content)\n",
59
+ " return pdb_path\n",
60
+ " else:\n",
61
+ " return None\n",
62
+ "\n",
63
+ "def process_pdb(pdb_id, segment):\n",
64
+ " pdb_path = fetch_pdb(pdb_id)\n",
65
+ " if not pdb_path:\n",
66
+ " return \"Failed to fetch PDB file\", None, None\n",
67
+ " \n",
68
+ " parser = PDBParser(QUIET=1)\n",
69
+ " structure = parser.get_structure('protein', pdb_path)\n",
70
+ " \n",
71
+ " try:\n",
72
+ " chain = structure[0][segment]\n",
73
+ " except KeyError:\n",
74
+ " return \"Invalid Chain ID\", None, None\n",
75
+ " \n",
76
+ " # Comprehensive amino acid mapping\n",
77
+ " aa_dict = {\n",
78
+ " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
79
+ " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
80
+ " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
81
+ " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
82
+ " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
83
+ " }\n",
84
+ " \n",
85
+ " # Exclude non-amino acid residues\n",
86
+ " sequence = [\n",
87
+ " residue for residue in chain \n",
88
+ " if residue.get_resname().strip() in aa_dict\n",
89
+ " ]\n",
90
+ " \n",
91
+ " random_scores = np.random.rand(len(sequence))\n",
92
+ " result_str = \"\\n\".join(\n",
93
+ " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n",
94
+ " for res, score in zip(sequence, random_scores)\n",
95
+ " )\n",
96
+ " \n",
97
+ " # Save the predictions to a file\n",
98
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
99
+ " with open(prediction_file, \"w\") as f:\n",
100
+ " f.write(result_str)\n",
101
+ " \n",
102
+ " return result_str, molecule(pdb_path, random_scores, segment), prediction_file\n",
103
+ "\n",
104
+ "def molecule(input_pdb, scores=None, segment='A'):\n",
105
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
106
+ " \n",
107
+ " # Prepare high-scoring residues script if scores are provided\n",
108
+ " high_score_script = \"\"\n",
109
+ " if scores is not None:\n",
110
+ " high_score_script = \"\"\"\n",
111
+ " // Reset all styles first\n",
112
+ " viewer.getModel(0).setStyle({}, {});\n",
113
+ " \n",
114
+ " // Show only the selected chain\n",
115
+ " viewer.getModel(0).setStyle(\n",
116
+ " {\"chain\": \"%s\"}, \n",
117
+ " { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
118
+ " );\n",
119
+ " \n",
120
+ " // Highlight high-scoring residues only for the selected chain\n",
121
+ " let highScoreResidues = [%s];\n",
122
+ " viewer.getModel(0).setStyle(\n",
123
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n",
124
+ " {\"stick\": {\"color\": \"red\"}}\n",
125
+ " );\n",
126
+ " \"\"\" % (segment, \n",
127
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n",
128
+ " segment)\n",
129
+ " \n",
130
+ " html_content = f\"\"\"\n",
131
+ " <!DOCTYPE html>\n",
132
+ " <html>\n",
133
+ " <head> \n",
134
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
135
+ " <style>\n",
136
+ " .mol-container {{\n",
137
+ " width: 100%;\n",
138
+ " height: 700px;\n",
139
+ " position: relative;\n",
140
+ " }}\n",
141
+ " </style>\n",
142
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
143
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
144
+ " </head>\n",
145
+ " <body>\n",
146
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
147
+ " <script>\n",
148
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
149
+ " $(document).ready(function () {{\n",
150
+ " let element = $(\"#container\");\n",
151
+ " let config = {{ backgroundColor: \"white\" }};\n",
152
+ " let viewer = $3Dmol.createViewer(element, config);\n",
153
+ " viewer.addModel(pdb, \"pdb\");\n",
154
+ " \n",
155
+ " // Reset all styles and show only selected chain\n",
156
+ " viewer.getModel(0).setStyle(\n",
157
+ " {{\"chain\": \"{segment}\"}}, \n",
158
+ " {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }}\n",
159
+ " );\n",
160
+ " \n",
161
+ " {high_score_script}\n",
162
+ " \n",
163
+ " viewer.zoomTo();\n",
164
+ " viewer.render();\n",
165
+ " viewer.zoom(0.8, 2000);\n",
166
+ " }});\n",
167
+ " </script>\n",
168
+ " </body>\n",
169
+ " </html>\n",
170
+ " \"\"\"\n",
171
+ " \n",
172
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
173
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
174
+ "\n",
175
+ "reps = [\n",
176
+ " {\n",
177
+ " \"model\": 0,\n",
178
+ " \"style\": \"cartoon\",\n",
179
+ " \"color\": \"whiteCarbon\",\n",
180
+ " \"residue_range\": \"\",\n",
181
+ " \"around\": 0,\n",
182
+ " \"byres\": False,\n",
183
+ " }\n",
184
+ " ]\n",
185
+ "# Gradio UI\n",
186
+ "with gr.Blocks() as demo:\n",
187
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
188
+ " with gr.Row():\n",
189
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
190
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
191
+ "\n",
192
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
193
+ "\n",
194
+ " with gr.Row():\n",
195
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
196
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
197
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
198
+ "\n",
199
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
200
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
201
+ " download_output = gr.File(label=\"Download Predictions\")\n",
202
+ " \n",
203
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
204
+ " \n",
205
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
206
+ " \n",
207
+ " gr.Markdown(\"## Examples\")\n",
208
+ " gr.Examples(\n",
209
+ " examples=[\n",
210
+ " [\"2IWI\", \"A\"],\n",
211
+ " [\"7RPZ\", \"B\"],\n",
212
+ " [\"3TJN\", \"C\"]\n",
213
+ " ],\n",
214
+ " inputs=[pdb_input, segment_input],\n",
215
+ " outputs=[predictions_output, molecule_output, download_output]\n",
216
+ " )\n",
217
+ "\n",
218
+ "demo.launch()"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 6,
224
+ "id": "28f8f28c-48d3-4e35-9766-3de9882179b5",
225
+ "metadata": {},
226
+ "outputs": [
227
+ {
228
+ "name": "stdout",
229
+ "output_type": "stream",
230
+ "text": [
231
+ "* Running on local URL: http://127.0.0.1:7864\n",
232
+ "\n",
233
+ "To create a public link, set `share=True` in `launch()`.\n"
234
+ ]
235
+ },
236
+ {
237
+ "data": {
238
+ "text/html": [
239
+ "<div><iframe src=\"http://127.0.0.1:7864/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
240
+ ],
241
+ "text/plain": [
242
+ "<IPython.core.display.HTML object>"
243
+ ]
244
+ },
245
+ "metadata": {},
246
+ "output_type": "display_data"
247
+ },
248
+ {
249
+ "data": {
250
+ "text/plain": []
251
+ },
252
+ "execution_count": 6,
253
+ "metadata": {},
254
+ "output_type": "execute_result"
255
+ }
256
+ ],
257
+ "source": [
258
+ "import gradio as gr\n",
259
+ "import requests\n",
260
+ "from Bio.PDB import PDBParser\n",
261
+ "import numpy as np\n",
262
+ "import os\n",
263
+ "from gradio_molecule3d import Molecule3D\n",
264
+ "\n",
265
+ "def read_mol(pdb_path):\n",
266
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
267
+ " with open(pdb_path, 'r') as f:\n",
268
+ " return f.read()\n",
269
+ "\n",
270
+ "def fetch_pdb(pdb_id):\n",
271
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
272
+ " pdb_path = f'{pdb_id}.pdb'\n",
273
+ " response = requests.get(pdb_url)\n",
274
+ " if response.status_code == 200:\n",
275
+ " with open(pdb_path, 'wb') as f:\n",
276
+ " f.write(response.content)\n",
277
+ " return pdb_path\n",
278
+ " else:\n",
279
+ " return None\n",
280
+ "\n",
281
+ "def process_pdb(pdb_id, segment):\n",
282
+ " pdb_path = fetch_pdb(pdb_id)\n",
283
+ " if not pdb_path:\n",
284
+ " return \"Failed to fetch PDB file\", None, None\n",
285
+ " \n",
286
+ " parser = PDBParser(QUIET=1)\n",
287
+ " structure = parser.get_structure('protein', pdb_path)\n",
288
+ " \n",
289
+ " try:\n",
290
+ " chain = structure[0][segment]\n",
291
+ " except KeyError:\n",
292
+ " return \"Invalid Chain ID\", None, None\n",
293
+ " \n",
294
+ " # Comprehensive amino acid mapping\n",
295
+ " aa_dict = {\n",
296
+ " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
297
+ " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
298
+ " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
299
+ " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
300
+ " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
301
+ " }\n",
302
+ " \n",
303
+ " # Exclude non-amino acid residues\n",
304
+ " sequence = [\n",
305
+ " residue for residue in chain \n",
306
+ " if residue.get_resname().strip() in aa_dict\n",
307
+ " ]\n",
308
+ " \n",
309
+ " random_scores = np.random.rand(len(sequence))\n",
310
+ " result_str = \"\\n\".join(\n",
311
+ " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n",
312
+ " for res, score in zip(sequence, random_scores)\n",
313
+ " )\n",
314
+ " \n",
315
+ " # Save the predictions to a file\n",
316
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
317
+ " with open(prediction_file, \"w\") as f:\n",
318
+ " f.write(result_str)\n",
319
+ " \n",
320
+ " return result_str, molecule(pdb_path, random_scores, segment), prediction_file\n",
321
+ "\n",
322
+ "def molecule(input_pdb, scores=None, segment='A'):\n",
323
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
324
+ " \n",
325
+ " # Prepare high-scoring residues script if scores are provided\n",
326
+ " high_score_script = \"\"\n",
327
+ " if scores is not None:\n",
328
+ " high_score_script = \"\"\"\n",
329
+ " // Reset all styles first\n",
330
+ " viewer.getModel(0).setStyle({}, {});\n",
331
+ " \n",
332
+ " // Show only the selected chain\n",
333
+ " viewer.getModel(0).setStyle(\n",
334
+ " {\"chain\": \"%s\"}, \n",
335
+ " { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
336
+ " );\n",
337
+ " \n",
338
+ " // Highlight high-scoring residues only for the selected chain\n",
339
+ " let highScoreResidues = [%s];\n",
340
+ " viewer.getModel(0).setStyle(\n",
341
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n",
342
+ " {\"stick\": {\"color\": \"red\"}}\n",
343
+ " );\n",
344
+ " \"\"\" % (segment, \n",
345
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n",
346
+ " segment)\n",
347
+ " \n",
348
+ " html_content = f\"\"\"\n",
349
+ " <!DOCTYPE html>\n",
350
+ " <html>\n",
351
+ " <head> \n",
352
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
353
+ " <style>\n",
354
+ " .mol-container {{\n",
355
+ " width: 100%;\n",
356
+ " height: 700px;\n",
357
+ " position: relative;\n",
358
+ " }}\n",
359
+ " </style>\n",
360
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
361
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
362
+ " </head>\n",
363
+ " <body>\n",
364
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
365
+ " <script>\n",
366
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
367
+ " $(document).ready(function () {{\n",
368
+ " let element = $(\"#container\");\n",
369
+ " let config = {{ backgroundColor: \"white\" }};\n",
370
+ " let viewer = $3Dmol.createViewer(element, config);\n",
371
+ " viewer.addModel(pdb, \"pdb\");\n",
372
+ " \n",
373
+ " // Reset all styles and show only selected chain\n",
374
+ " viewer.getModel(0).setStyle(\n",
375
+ " {{\"chain\": \"{segment}\"}}, \n",
376
+ " {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }}\n",
377
+ " );\n",
378
+ " \n",
379
+ " {high_score_script}\n",
380
+ " \n",
381
+ " // Add hover functionality\n",
382
+ " viewer.setHoverable(\n",
383
+ " {{}}, \n",
384
+ " true, \n",
385
+ " function(atom, viewer, event, container) {{\n",
386
+ " if (!atom.label) {{\n",
387
+ " atom.label = viewer.addLabel(\n",
388
+ " atom.resn + \":\" + atom.atom, \n",
389
+ " {{\n",
390
+ " position: atom, \n",
391
+ " backgroundColor: 'mintcream', \n",
392
+ " fontColor: 'black',\n",
393
+ " fontSize: 12,\n",
394
+ " padding: 2\n",
395
+ " }}\n",
396
+ " );\n",
397
+ " }}\n",
398
+ " }},\n",
399
+ " function(atom, viewer) {{\n",
400
+ " if (atom.label) {{\n",
401
+ " viewer.removeLabel(atom.label);\n",
402
+ " delete atom.label;\n",
403
+ " }}\n",
404
+ " }}\n",
405
+ " );\n",
406
+ " \n",
407
+ " viewer.zoomTo();\n",
408
+ " viewer.render();\n",
409
+ " viewer.zoom(0.8, 2000);\n",
410
+ " }});\n",
411
+ " </script>\n",
412
+ " </body>\n",
413
+ " </html>\n",
414
+ " \"\"\"\n",
415
+ " \n",
416
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
417
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
418
+ "\n",
419
+ "reps = [\n",
420
+ " {\n",
421
+ " \"model\": 0,\n",
422
+ " \"style\": \"cartoon\",\n",
423
+ " \"color\": \"whiteCarbon\",\n",
424
+ " \"residue_range\": \"\",\n",
425
+ " \"around\": 0,\n",
426
+ " \"byres\": False,\n",
427
+ " }\n",
428
+ " ]\n",
429
+ "\n",
430
+ "# Gradio UI\n",
431
+ "with gr.Blocks() as demo:\n",
432
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
433
+ " with gr.Row():\n",
434
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
435
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
436
+ "\n",
437
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
438
+ "\n",
439
+ " with gr.Row():\n",
440
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
441
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
442
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
443
+ "\n",
444
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
445
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
446
+ " download_output = gr.File(label=\"Download Predictions\")\n",
447
+ " \n",
448
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
449
+ " \n",
450
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
451
+ " \n",
452
+ " gr.Markdown(\"## Examples\")\n",
453
+ " gr.Examples(\n",
454
+ " examples=[\n",
455
+ " [\"2IWI\", \"A\"],\n",
456
+ " [\"7RPZ\", \"B\"],\n",
457
+ " [\"3TJN\", \"C\"]\n",
458
+ " ],\n",
459
+ " inputs=[pdb_input, segment_input],\n",
460
+ " outputs=[predictions_output, molecule_output, download_output]\n",
461
+ " )\n",
462
+ "\n",
463
+ "demo.launch()"
464
+ ]
465
+ },
466
+ {
467
+ "cell_type": "code",
468
+ "execution_count": null,
469
+ "id": "517a2fe7-419f-4d0b-a9ed-62a22c1c1284",
470
+ "metadata": {},
471
+ "outputs": [],
472
+ "source": []
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "execution_count": 11,
477
+ "id": "d62be1b5-762e-4b69-aed4-e4ba2a44482f",
478
+ "metadata": {},
479
+ "outputs": [
480
+ {
481
+ "name": "stdout",
482
+ "output_type": "stream",
483
+ "text": [
484
+ "* Running on local URL: http://127.0.0.1:7867\n",
485
+ "\n",
486
+ "To create a public link, set `share=True` in `launch()`.\n"
487
+ ]
488
+ },
489
+ {
490
+ "data": {
491
+ "text/html": [
492
+ "<div><iframe src=\"http://127.0.0.1:7867/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
493
+ ],
494
+ "text/plain": [
495
+ "<IPython.core.display.HTML object>"
496
+ ]
497
+ },
498
+ "metadata": {},
499
+ "output_type": "display_data"
500
+ },
501
+ {
502
+ "data": {
503
+ "text/plain": []
504
+ },
505
+ "execution_count": 11,
506
+ "metadata": {},
507
+ "output_type": "execute_result"
508
+ }
509
+ ],
510
+ "source": [
511
+ "import gradio as gr\n",
512
+ "import requests\n",
513
+ "from Bio.PDB import PDBParser\n",
514
+ "import numpy as np\n",
515
+ "import os\n",
516
+ "from gradio_molecule3d import Molecule3D\n",
517
+ "\n",
518
+ "def read_mol(pdb_path):\n",
519
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
520
+ " with open(pdb_path, 'r') as f:\n",
521
+ " return f.read()\n",
522
+ "\n",
523
+ "def fetch_pdb(pdb_id):\n",
524
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
525
+ " pdb_path = f'{pdb_id}.pdb'\n",
526
+ " response = requests.get(pdb_url)\n",
527
+ " if response.status_code == 200:\n",
528
+ " with open(pdb_path, 'wb') as f:\n",
529
+ " f.write(response.content)\n",
530
+ " return pdb_path\n",
531
+ " else:\n",
532
+ " return None\n",
533
+ "\n",
534
+ "def process_pdb(pdb_id, segment):\n",
535
+ " pdb_path = fetch_pdb(pdb_id)\n",
536
+ " if not pdb_path:\n",
537
+ " return \"Failed to fetch PDB file\", None, None\n",
538
+ " \n",
539
+ " parser = PDBParser(QUIET=1)\n",
540
+ " structure = parser.get_structure('protein', pdb_path)\n",
541
+ " \n",
542
+ " try:\n",
543
+ " chain = structure[0][segment]\n",
544
+ " except KeyError:\n",
545
+ " return \"Invalid Chain ID\", None, None\n",
546
+ " \n",
547
+ " # Comprehensive amino acid mapping\n",
548
+ " aa_dict = {\n",
549
+ " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
550
+ " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
551
+ " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
552
+ " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
553
+ " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
554
+ " }\n",
555
+ " \n",
556
+ " # Exclude non-amino acid residues\n",
557
+ " sequence = [\n",
558
+ " residue for residue in chain \n",
559
+ " if residue.get_resname().strip() in aa_dict\n",
560
+ " ]\n",
561
+ " \n",
562
+ " random_scores = np.random.rand(len(sequence))\n",
563
+ " result_str = \"\\n\".join(\n",
564
+ " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n",
565
+ " for res, score in zip(sequence, random_scores)\n",
566
+ " )\n",
567
+ " \n",
568
+ " # Save the predictions to a file\n",
569
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
570
+ " with open(prediction_file, \"w\") as f:\n",
571
+ " f.write(result_str)\n",
572
+ " \n",
573
+ " return result_str, molecule(pdb_path, random_scores, segment), prediction_file\n",
574
+ "\n",
575
+ "def molecule(input_pdb, scores=None, segment='A'):\n",
576
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
577
+ " \n",
578
+ " # Prepare high-scoring residues script if scores are provided\n",
579
+ " high_score_script = \"\"\n",
580
+ " if scores is not None:\n",
581
+ " high_score_script = \"\"\"\n",
582
+ " // Reset all styles first\n",
583
+ " viewer.getModel(0).setStyle({}, {});\n",
584
+ " \n",
585
+ " // Show only the selected chain\n",
586
+ " viewer.getModel(0).setStyle(\n",
587
+ " {\"chain\": \"%s\"}, \n",
588
+ " { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
589
+ " );\n",
590
+ " \n",
591
+ " // Highlight high-scoring residues only for the selected chain\n",
592
+ " let highScoreResidues = [%s];\n",
593
+ " viewer.getModel(0).setStyle(\n",
594
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n",
595
+ " {\"stick\": {\"color\": \"red\"}}\n",
596
+ " );\n",
597
+ "\n",
598
+ " // Highlight medium-scoring residues (between 0.5 and 0.8) only for the selected chain\n",
599
+ " let highScoreResidues2 = [%s];\n",
600
+ " viewer.getModel(0).setStyle(\n",
601
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues2}, \n",
602
+ " {\"stick\": {\"color\": \"orange\"}}\n",
603
+ " );\n",
604
+ " \"\"\" % (segment, \n",
605
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n",
606
+ " segment,\n",
607
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),\n",
608
+ " segment)\n",
609
+ " \n",
610
+ " html_content = f\"\"\"\n",
611
+ " <!DOCTYPE html>\n",
612
+ " <html>\n",
613
+ " <head> \n",
614
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
615
+ " <style>\n",
616
+ " .mol-container {{\n",
617
+ " width: 100%;\n",
618
+ " height: 700px;\n",
619
+ " position: relative;\n",
620
+ " }}\n",
621
+ " </style>\n",
622
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
623
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
624
+ " </head>\n",
625
+ " <body>\n",
626
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
627
+ " <script>\n",
628
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
629
+ " $(document).ready(function () {{\n",
630
+ " let element = $(\"#container\");\n",
631
+ " let config = {{ backgroundColor: \"white\" }};\n",
632
+ " let viewer = $3Dmol.createViewer(element, config);\n",
633
+ " viewer.addModel(pdb, \"pdb\");\n",
634
+ " \n",
635
+ " // Reset all styles and show only selected chain\n",
636
+ " viewer.getModel(0).setStyle(\n",
637
+ " {{\"chain\": \"{segment}\"}}, \n",
638
+ " {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }}\n",
639
+ " );\n",
640
+ " \n",
641
+ " {high_score_script}\n",
642
+ " \n",
643
+ " // Add hover functionality\n",
644
+ " viewer.setHoverable(\n",
645
+ " {{}}, \n",
646
+ " true, \n",
647
+ " function(atom, viewer, event, container) {{\n",
648
+ " if (!atom.label) {{\n",
649
+ " atom.label = viewer.addLabel(\n",
650
+ " atom.resn + \":\" + atom.atom, \n",
651
+ " {{\n",
652
+ " position: atom, \n",
653
+ " backgroundColor: 'mintcream', \n",
654
+ " fontColor: 'black',\n",
655
+ " fontSize: 12,\n",
656
+ " padding: 2\n",
657
+ " }}\n",
658
+ " );\n",
659
+ " }}\n",
660
+ " }},\n",
661
+ " function(atom, viewer) {{\n",
662
+ " if (atom.label) {{\n",
663
+ " viewer.removeLabel(atom.label);\n",
664
+ " delete atom.label;\n",
665
+ " }}\n",
666
+ " }}\n",
667
+ " );\n",
668
+ " \n",
669
+ " viewer.zoomTo();\n",
670
+ " viewer.render();\n",
671
+ " viewer.zoom(0.8, 2000);\n",
672
+ " }});\n",
673
+ " </script>\n",
674
+ " </body>\n",
675
+ " </html>\n",
676
+ " \"\"\"\n",
677
+ " \n",
678
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
679
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
680
+ "\n",
681
+ "reps = [\n",
682
+ " {\n",
683
+ " \"model\": 0,\n",
684
+ " \"style\": \"cartoon\",\n",
685
+ " \"color\": \"whiteCarbon\",\n",
686
+ " \"residue_range\": \"\",\n",
687
+ " \"around\": 0,\n",
688
+ " \"byres\": False,\n",
689
+ " }\n",
690
+ " ]\n",
691
+ "\n",
692
+ "# Gradio UI\n",
693
+ "with gr.Blocks() as demo:\n",
694
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
695
+ " with gr.Row():\n",
696
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
697
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
698
+ "\n",
699
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
700
+ "\n",
701
+ " with gr.Row():\n",
702
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
703
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
704
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
705
+ "\n",
706
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
707
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
708
+ " download_output = gr.File(label=\"Download Predictions\")\n",
709
+ " \n",
710
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
711
+ " \n",
712
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
713
+ " \n",
714
+ " gr.Markdown(\"## Examples\")\n",
715
+ " gr.Examples(\n",
716
+ " examples=[\n",
717
+ " [\"2IWI\", \"A\"],\n",
718
+ " [\"7RPZ\", \"B\"],\n",
719
+ " [\"3TJN\", \"C\"]\n",
720
+ " ],\n",
721
+ " inputs=[pdb_input, segment_input],\n",
722
+ " outputs=[predictions_output, molecule_output, download_output]\n",
723
+ " )\n",
724
+ "\n",
725
+ "demo.launch()"
726
+ ]
727
+ },
728
+ {
729
+ "cell_type": "code",
730
+ "execution_count": null,
731
+ "id": "30f35243-852f-4771-9a4b-5cdd198552b5",
732
+ "metadata": {},
733
+ "outputs": [],
734
+ "source": []
735
+ },
736
+ {
737
+ "cell_type": "code",
738
+ "execution_count": null,
739
+ "id": "5eca6754-4aa1-463f-881a-25d2a0d6bb5b",
740
+ "metadata": {},
741
+ "outputs": [],
742
+ "source": [
743
+ "import gradio as gr\n",
744
+ "import requests\n",
745
+ "from Bio.PDB import PDBParser\n",
746
+ "import numpy as np\n",
747
+ "import os\n",
748
+ "from gradio_molecule3d import Molecule3D\n",
749
+ "\n",
750
+ "\n",
751
+ "from model_loader import load_model\n",
752
+ "\n",
753
+ "import torch\n",
754
+ "import torch.nn as nn\n",
755
+ "import torch.nn.functional as F\n",
756
+ "from torch.utils.data import DataLoader\n",
757
+ "\n",
758
+ "import re\n",
759
+ "import pandas as pd\n",
760
+ "import copy\n",
761
+ "\n",
762
+ "import transformers, datasets\n",
763
+ "from transformers import AutoTokenizer\n",
764
+ "from transformers import DataCollatorForTokenClassification\n",
765
+ "\n",
766
+ "from datasets import Dataset\n",
767
+ "\n",
768
+ "from scipy.special import expit\n",
769
+ "\n",
770
+ "# Load model and move to device\n",
771
+ "checkpoint = 'ThorbenF/prot_t5_xl_uniref50'\n",
772
+ "max_length = 1500\n",
773
+ "model, tokenizer = load_model(checkpoint, max_length)\n",
774
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
775
+ "model.to(device)\n",
776
+ "model.eval()\n",
777
+ "\n",
778
+ "def normalize_scores(scores):\n",
779
+ " min_score = np.min(scores)\n",
780
+ " max_score = np.max(scores)\n",
781
+ " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
782
+ " \n",
783
+ "def read_mol(pdb_path):\n",
784
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
785
+ " with open(pdb_path, 'r') as f:\n",
786
+ " return f.read()\n",
787
+ "\n",
788
+ "def fetch_pdb(pdb_id):\n",
789
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
790
+ " pdb_path = f'{pdb_id}.pdb'\n",
791
+ " response = requests.get(pdb_url)\n",
792
+ " if response.status_code == 200:\n",
793
+ " with open(pdb_path, 'wb') as f:\n",
794
+ " f.write(response.content)\n",
795
+ " return pdb_path\n",
796
+ " else:\n",
797
+ " return None\n",
798
+ "\n",
799
+ "def process_pdb(pdb_id, segment):\n",
800
+ " pdb_path = fetch_pdb(pdb_id)\n",
801
+ " if not pdb_path:\n",
802
+ " return \"Failed to fetch PDB file\", None, None\n",
803
+ " \n",
804
+ " parser = PDBParser(QUIET=1)\n",
805
+ " structure = parser.get_structure('protein', pdb_path)\n",
806
+ " \n",
807
+ " try:\n",
808
+ " chain = structure[0][segment]\n",
809
+ " except KeyError:\n",
810
+ " return \"Invalid Chain ID\", None, None\n",
811
+ " \n",
812
+ " # Comprehensive amino acid mapping\n",
813
+ " aa_dict = {\n",
814
+ " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
815
+ " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
816
+ " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
817
+ " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
818
+ " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
819
+ " }\n",
820
+ " \n",
821
+ " # Exclude non-amino acid residues\n",
822
+ " sequence = [\n",
823
+ " residue for residue in chain \n",
824
+ " if residue.get_resname().strip() in aa_dict\n",
825
+ " ]\n",
826
+ " \n",
827
+ " # Prepare input for model prediction\n",
828
+ " input_ids = tokenizer(\" \".join(aa_dict[res.get_resname().strip()] for res in sequence), return_tensors=\"pt\").input_ids.to(device)\n",
829
+ " with torch.no_grad():\n",
830
+ " outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()\n",
831
+ "\n",
832
+ " # Calculate scores and normalize them\n",
833
+ " scores = expit(outputs[:, 1] - outputs[:, 0])\n",
834
+ " normalized_scores = normalize_scores(scores)\n",
835
+ "\n",
836
+ " result_str = \"\\n\".join(\n",
837
+ " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n",
838
+ " for res, score in zip(sequence, normalized_scores)\n",
839
+ " )\n",
840
+ " \n",
841
+ " # Save the predictions to a file\n",
842
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
843
+ " with open(prediction_file, \"w\") as f:\n",
844
+ " f.write(result_str)\n",
845
+ " \n",
846
+ " return result_str, molecule(pdb_path, normalized_scores, segment), prediction_file\n",
847
+ "\n",
848
+ "def molecule(input_pdb, scores=None, segment='A'):\n",
849
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
850
+ " \n",
851
+ " # Prepare high-scoring residues script if scores are provided\n",
852
+ " high_score_script = \"\"\n",
853
+ " if scores is not None:\n",
854
+ " high_score_script = \"\"\"\n",
855
+ " // Reset all styles first\n",
856
+ " viewer.getModel(0).setStyle({}, {});\n",
857
+ " \n",
858
+ " // Show only the selected chain\n",
859
+ " viewer.getModel(0).setStyle(\n",
860
+ " {\"chain\": \"%s\"}, \n",
861
+ " { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
862
+ " );\n",
863
+ " \n",
864
+ " // Highlight high-scoring residues only for the selected chain\n",
865
+ " let highScoreResidues = [%s];\n",
866
+ " viewer.getModel(0).setStyle(\n",
867
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n",
868
+ " {\"stick\": {\"color\": \"red\"}}\n",
869
+ " );\n",
870
+ "\n",
871
+ " // Highlight medium-scoring residues (between 0.5 and 0.8) only for the selected chain\n",
872
+ " let highScoreResidues2 = [%s];\n",
873
+ " viewer.getModel(0).setStyle(\n",
874
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues2}, \n",
875
+ " {\"stick\": {\"color\": \"orange\"}}\n",
876
+ " );\n",
877
+ " \"\"\" % (segment, \n",
878
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n",
879
+ " segment,\n",
880
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),\n",
881
+ " segment)\n",
882
+ " \n",
883
+ " html_content = f\"\"\"\n",
884
+ " <!DOCTYPE html>\n",
885
+ " <html>\n",
886
+ " <head> \n",
887
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
888
+ " <style>\n",
889
+ " .mol-container {{\n",
890
+ " width: 100%;\n",
891
+ " height: 700px;\n",
892
+ " position: relative;\n",
893
+ " }}\n",
894
+ " </style>\n",
895
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
896
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
897
+ " </head>\n",
898
+ " <body>\n",
899
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
900
+ " <script>\n",
901
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
902
+ " $(document).ready(function () {{\n",
903
+ " let element = $(\"#container\");\n",
904
+ " let config = {{ backgroundColor: \"white\" }};\n",
905
+ " let viewer = $3Dmol.createViewer(element, config);\n",
906
+ " viewer.addModel(pdb, \"pdb\");\n",
907
+ " \n",
908
+ " // Reset all styles and show only selected chain\n",
909
+ " viewer.getModel(0).setStyle(\n",
910
+ " {{\"chain\": \"{segment}\"}}, \n",
911
+ " {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }}\n",
912
+ " );\n",
913
+ " \n",
914
+ " {high_score_script}\n",
915
+ " \n",
916
+ " // Add hover functionality\n",
917
+ " viewer.setHoverable(\n",
918
+ " {{}}, \n",
919
+ " true, \n",
920
+ " function(atom, viewer, event, container) {{\n",
921
+ " if (!atom.label) {{\n",
922
+ " atom.label = viewer.addLabel(\n",
923
+ " atom.resn + \":\" + atom.atom, \n",
924
+ " {{\n",
925
+ " position: atom, \n",
926
+ " backgroundColor: 'mintcream', \n",
927
+ " fontColor: 'black',\n",
928
+ " fontSize: 12,\n",
929
+ " padding: 2\n",
930
+ " }}\n",
931
+ " );\n",
932
+ " }}\n",
933
+ " }},\n",
934
+ " function(atom, viewer) {{\n",
935
+ " if (atom.label) {{\n",
936
+ " viewer.removeLabel(atom.label);\n",
937
+ " delete atom.label;\n",
938
+ " }}\n",
939
+ " }}\n",
940
+ " );\n",
941
+ " \n",
942
+ " viewer.zoomTo();\n",
943
+ " viewer.render();\n",
944
+ " viewer.zoom(0.8, 2000);\n",
945
+ " }});\n",
946
+ " </script>\n",
947
+ " </body>\n",
948
+ " </html>\n",
949
+ " \"\"\"\n",
950
+ " \n",
951
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
952
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
953
+ "\n",
954
+ "reps = [\n",
955
+ " {\n",
956
+ " \"model\": 0,\n",
957
+ " \"style\": \"cartoon\",\n",
958
+ " \"color\": \"whiteCarbon\",\n",
959
+ " \"residue_range\": \"\",\n",
960
+ " \"around\": 0,\n",
961
+ " \"byres\": False,\n",
962
+ " }\n",
963
+ " ]\n",
964
+ "\n",
965
+ "# Gradio UI\n",
966
+ "with gr.Blocks() as demo:\n",
967
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
968
+ " with gr.Row():\n",
969
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
970
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
971
+ "\n",
972
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
973
+ "\n",
974
+ " with gr.Row():\n",
975
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
976
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
977
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
978
+ "\n",
979
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
980
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
981
+ " download_output = gr.File(label=\"Download Predictions\")\n",
982
+ " \n",
983
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
984
+ " \n",
985
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
986
+ " \n",
987
+ " gr.Markdown(\"## Examples\")\n",
988
+ " gr.Examples(\n",
989
+ " examples=[\n",
990
+ " [\"2IWI\", \"A\"],\n",
991
+ " [\"7RPZ\", \"B\"],\n",
992
+ " [\"3TJN\", \"C\"]\n",
993
+ " ],\n",
994
+ " inputs=[pdb_input, segment_input],\n",
995
+ " outputs=[predictions_output, molecule_output, download_output]\n",
996
+ " )\n",
997
+ "\n",
998
+ "demo.launch()"
999
+ ]
1000
+ },
1001
+ {
1002
+ "cell_type": "code",
1003
+ "execution_count": null,
1004
+ "id": "95046d1c-ec7c-4e3e-8a98-1802cb09a25b",
1005
+ "metadata": {},
1006
+ "outputs": [],
1007
+ "source": []
1008
+ },
1009
+ {
1010
+ "cell_type": "code",
1011
+ "execution_count": null,
1012
+ "id": "a37cbe6f-d57f-41e5-8ae1-38258da39d47",
1013
+ "metadata": {},
1014
+ "outputs": [],
1015
+ "source": [
1016
+ "import gradio as gr\n",
1017
+ "from model_loader import load_model\n",
1018
+ "\n",
1019
+ "import torch\n",
1020
+ "import torch.nn as nn\n",
1021
+ "import torch.nn.functional as F\n",
1022
+ "from torch.utils.data import DataLoader\n",
1023
+ "\n",
1024
+ "import re\n",
1025
+ "import numpy as np\n",
1026
+ "import os\n",
1027
+ "import pandas as pd\n",
1028
+ "import copy\n",
1029
+ "\n",
1030
+ "import transformers, datasets\n",
1031
+ "from transformers import AutoTokenizer\n",
1032
+ "from transformers import DataCollatorForTokenClassification\n",
1033
+ "\n",
1034
+ "from datasets import Dataset\n",
1035
+ "\n",
1036
+ "from scipy.special import expit\n",
1037
+ "\n",
1038
+ "import requests\n",
1039
+ "\n",
1040
+ "from gradio_molecule3d import Molecule3D\n",
1041
+ "\n",
1042
+ "# Biopython imports\n",
1043
+ "from Bio.PDB import PDBParser, Select, PDBIO\n",
1044
+ "from Bio.PDB.DSSP import DSSP\n",
1045
+ "from Bio.PDB import PDBList\n",
1046
+ "\n",
1047
+ "from matplotlib import cm # For color mapping\n",
1048
+ "from matplotlib.colors import Normalize\n",
1049
+ "\n",
1050
+ "# Load model and move to device\n",
1051
+ "checkpoint = 'ThorbenF/prot_t5_xl_uniref50'\n",
1052
+ "max_length = 1500\n",
1053
+ "model, tokenizer = load_model(checkpoint, max_length)\n",
1054
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
1055
+ "model.to(device)\n",
1056
+ "model.eval()\n",
1057
+ "\n",
1058
+ "# Function to fetch a PDB file\n",
1059
+ "def fetch_pdb(pdb_id):\n",
1060
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
1061
+ " pdb_path = f'pdb_files/{pdb_id}.pdb'\n",
1062
+ " os.makedirs('pdb_files', exist_ok=True)\n",
1063
+ " response = requests.get(pdb_url)\n",
1064
+ " if response.status_code == 200:\n",
1065
+ " with open(pdb_path, 'wb') as f:\n",
1066
+ " f.write(response.content)\n",
1067
+ " return pdb_path\n",
1068
+ " return None\n",
1069
+ "\n",
1070
+ "\n",
1071
+ "def normalize_scores(scores):\n",
1072
+ " min_score = np.min(scores)\n",
1073
+ " max_score = np.max(scores)\n",
1074
+ " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
1075
+ "\n",
1076
+ "def process_pdb(pdb_id, segment):\n",
1077
+ " pdb_path = fetch_pdb(pdb_id)\n",
1078
+ " if not pdb_path:\n",
1079
+ " return \"Failed to fetch PDB file\", None, None\n",
1080
+ " \n",
1081
+ " parser = PDBParser(QUIET=1)\n",
1082
+ " structure = parser.get_structure('protein', pdb_path)\n",
1083
+ " chain = structure[0][segment]\n",
1084
+ " \n",
1085
+ " # Comprehensive amino acid mapping\n",
1086
+ " aa_dict = {\n",
1087
+ " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
1088
+ " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
1089
+ " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
1090
+ " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
1091
+ " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
1092
+ " }\n",
1093
+ " \n",
1094
+ " # Exclude non-amino acid residues\n",
1095
+ " sequence = \"\".join(\n",
1096
+ " aa_dict[residue.get_resname().strip()] \n",
1097
+ " for residue in chain \n",
1098
+ " if residue.get_resname().strip() in aa_dict\n",
1099
+ " )\n",
1100
+ " \n",
1101
+ " # Prepare input for model prediction\n",
1102
+ " input_ids = tokenizer(\" \".join(sequence), return_tensors=\"pt\").input_ids.to(device)\n",
1103
+ " with torch.no_grad():\n",
1104
+ " outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()\n",
1105
+ "\n",
1106
+ " # Calculate scores and normalize them\n",
1107
+ " scores = expit(outputs[:, 1] - outputs[:, 0])\n",
1108
+ " normalized_scores = normalize_scores(scores)\n",
1109
+ " \n",
1110
+ " # Prepare the result string, including only amino acid residues\n",
1111
+ " result_str = \"\\n\".join([\n",
1112
+ " f\"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
1113
+ " for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict\n",
1114
+ " ])\n",
1115
+ " \n",
1116
+ " # Save predictions to file\n",
1117
+ " with open(f\"{pdb_id}_predictions.txt\", \"w\") as f:\n",
1118
+ " f.write(result_str)\n",
1119
+ " \n",
1120
+ " return result_str, pdb_path, f\"{pdb_id}_predictions.txt\"\n",
1121
+ "\n",
1122
+ "reps = [{\"model\": 0, \"style\": \"cartoon\", \"color\": \"spectrum\"}]\n",
1123
+ "\n",
1124
+ "# Gradio UI\n",
1125
+ "with gr.Blocks() as demo:\n",
1126
+ " gr.Markdown(\"# Protein Binding Site Prediction\")\n",
1127
+ "\n",
1128
+ " with gr.Row():\n",
1129
+ " pdb_input = gr.Textbox(value=\"2IWI\",\n",
1130
+ " label=\"PDB ID\",\n",
1131
+ " placeholder=\"Enter PDB ID here...\")\n",
1132
+ " segment_input = gr.Textbox(value=\"A\",\n",
1133
+ " label=\"Chain ID (Segment)\",\n",
1134
+ " placeholder=\"Enter Chain ID here...\")\n",
1135
+ " visualize_btn = gr.Button(\"Visualize Sructure\")\n",
1136
+ " prediction_btn = gr.Button(\"Predict Ligand Binding Site\")\n",
1137
+ "\n",
1138
+ " molecule_output = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
1139
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
1140
+ " download_output = gr.File(label=\"Download Predictions\")\n",
1141
+ "\n",
1142
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)\n",
1143
+ " prediction_btn.click(\n",
1144
+ " process_pdb, \n",
1145
+ " inputs=[pdb_input, segment_input], \n",
1146
+ " outputs=[predictions_output, molecule_output, download_output]\n",
1147
+ " )\n",
1148
+ "\n",
1149
+ " gr.Markdown(\"## Examples\")\n",
1150
+ " gr.Examples(\n",
1151
+ " examples=[\n",
1152
+ " [\"2IWI\"],\n",
1153
+ " [\"7RPZ\"],\n",
1154
+ " [\"3TJN\"]\n",
1155
+ " ],\n",
1156
+ " inputs=[pdb_input, segment_input], \n",
1157
+ " outputs=[predictions_output, molecule_output, download_output]\n",
1158
+ " )\n",
1159
+ "\n",
1160
+ "demo.launch(share=True)"
1161
+ ]
1162
+ },
1163
+ {
1164
+ "cell_type": "code",
1165
+ "execution_count": null,
1166
+ "id": "4c61bac4-4f2e-4f4a-aa1f-30dca209747c",
1167
+ "metadata": {},
1168
+ "outputs": [],
1169
+ "source": []
1170
+ }
1171
+ ],
1172
+ "metadata": {
1173
+ "kernelspec": {
1174
+ "display_name": "Python (LLM)",
1175
+ "language": "python",
1176
+ "name": "llm"
1177
+ },
1178
+ "language_info": {
1179
+ "codemirror_mode": {
1180
+ "name": "ipython",
1181
+ "version": 3
1182
+ },
1183
+ "file_extension": ".py",
1184
+ "mimetype": "text/x-python",
1185
+ "name": "python",
1186
+ "nbconvert_exporter": "python",
1187
+ "pygments_lexer": "ipython3",
1188
+ "version": "3.12.7"
1189
+ }
1190
+ },
1191
+ "nbformat": 4,
1192
+ "nbformat_minor": 5
1193
+ }
app.py CHANGED
@@ -1,4 +1,11 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
  from model_loader import load_model
3
 
4
  import torch
@@ -7,8 +14,6 @@ import torch.nn.functional as F
7
  from torch.utils.data import DataLoader
8
 
9
  import re
10
- import numpy as np
11
- import os
12
  import pandas as pd
13
  import copy
14
 
@@ -20,18 +25,6 @@ from datasets import Dataset
20
 
21
  from scipy.special import expit
22
 
23
- import requests
24
-
25
- from gradio_molecule3d import Molecule3D
26
-
27
- # Biopython imports
28
- from Bio.PDB import PDBParser, Select, PDBIO
29
- from Bio.PDB.DSSP import DSSP
30
- from Bio.PDB import PDBList
31
-
32
- from matplotlib import cm # For color mapping
33
- from matplotlib.colors import Normalize
34
-
35
  # Load model and move to device
36
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
37
  max_length = 1500
@@ -40,23 +33,26 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
  model.to(device)
41
  model.eval()
42
 
43
- # Function to fetch a PDB file
 
 
 
 
 
 
 
 
 
44
  def fetch_pdb(pdb_id):
45
  pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
46
- pdb_path = f'pdb_files/{pdb_id}.pdb'
47
- os.makedirs('pdb_files', exist_ok=True)
48
  response = requests.get(pdb_url)
49
  if response.status_code == 200:
50
  with open(pdb_path, 'wb') as f:
51
  f.write(response.content)
52
  return pdb_path
53
- return None
54
-
55
-
56
- def normalize_scores(scores):
57
- min_score = np.min(scores)
58
- max_score = np.max(scores)
59
- return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
60
 
61
  def process_pdb(pdb_id, segment):
62
  pdb_path = fetch_pdb(pdb_id)
@@ -65,7 +61,11 @@ def process_pdb(pdb_id, segment):
65
 
66
  parser = PDBParser(QUIET=1)
67
  structure = parser.get_structure('protein', pdb_path)
68
- chain = structure[0][segment]
 
 
 
 
69
 
70
  # Comprehensive amino acid mapping
71
  aa_dict = {
@@ -77,11 +77,10 @@ def process_pdb(pdb_id, segment):
77
  }
78
 
79
  # Exclude non-amino acid residues
80
- sequence = "".join(
81
- aa_dict[residue.get_resname().strip()]
82
- for residue in chain
83
  if residue.get_resname().strip() in aa_dict
84
- )
85
 
86
  # Prepare input for model prediction
87
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
@@ -91,54 +90,166 @@ def process_pdb(pdb_id, segment):
91
  # Calculate scores and normalize them
92
  scores = expit(outputs[:, 1] - outputs[:, 0])
93
  normalized_scores = normalize_scores(scores)
 
 
 
 
 
94
 
95
- # Prepare the result string, including only amino acid residues
96
- result_str = "\n".join([
97
- f"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}"
98
- for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict
99
- ])
100
-
101
- # Save predictions to file
102
- with open(f"{pdb_id}_predictions.txt", "w") as f:
103
  f.write(result_str)
104
 
105
- return result_str, pdb_path, f"{pdb_id}_predictions.txt"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- reps = [{"model": 0, "style": "cartoon", "color": "spectrum"}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  # Gradio UI
110
  with gr.Blocks() as demo:
111
- gr.Markdown("# Protein Binding Site Prediction")
 
 
 
 
 
112
 
113
  with gr.Row():
114
- pdb_input = gr.Textbox(value="2IWI",
115
- label="PDB ID",
116
- placeholder="Enter PDB ID here...")
117
- segment_input = gr.Textbox(value="A",
118
- label="Chain ID (Segment)",
119
- placeholder="Enter Chain ID here...")
120
- visualize_btn = gr.Button("Visualize Sructure")
121
- prediction_btn = gr.Button("Predict Ligand Binding Site")
122
-
123
- molecule_output = Molecule3D(label="Protein Structure", reps=reps)
124
  predictions_output = gr.Textbox(label="Binding Site Predictions")
125
  download_output = gr.File(label="Download Predictions")
126
-
127
- visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)
128
- prediction_btn.click(
129
- process_pdb,
130
- inputs=[pdb_input, segment_input],
131
- outputs=[predictions_output, molecule_output, download_output]
132
- )
133
-
134
  gr.Markdown("## Examples")
135
  gr.Examples(
136
  examples=[
137
- ["2IWI"],
138
- ["7RPZ"],
139
- ["3TJN"]
140
  ],
141
- inputs=[pdb_input, segment_input],
142
  outputs=[predictions_output, molecule_output, download_output]
143
  )
144
 
 
1
  import gradio as gr
2
+ import requests
3
+ from Bio.PDB import PDBParser
4
+ import numpy as np
5
+ import os
6
+ from gradio_molecule3d import Molecule3D
7
+
8
+
9
  from model_loader import load_model
10
 
11
  import torch
 
14
  from torch.utils.data import DataLoader
15
 
16
  import re
 
 
17
  import pandas as pd
18
  import copy
19
 
 
25
 
26
  from scipy.special import expit
27
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # Load model and move to device
29
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
30
  max_length = 1500
 
33
  model.to(device)
34
  model.eval()
35
 
36
+ def normalize_scores(scores):
37
+ min_score = np.min(scores)
38
+ max_score = np.max(scores)
39
+ return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
40
+
41
+ def read_mol(pdb_path):
42
+ """Read PDB file and return its content as a string"""
43
+ with open(pdb_path, 'r') as f:
44
+ return f.read()
45
+
46
  def fetch_pdb(pdb_id):
47
  pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
48
+ pdb_path = f'{pdb_id}.pdb'
 
49
  response = requests.get(pdb_url)
50
  if response.status_code == 200:
51
  with open(pdb_path, 'wb') as f:
52
  f.write(response.content)
53
  return pdb_path
54
+ else:
55
+ return None
 
 
 
 
 
56
 
57
  def process_pdb(pdb_id, segment):
58
  pdb_path = fetch_pdb(pdb_id)
 
61
 
62
  parser = PDBParser(QUIET=1)
63
  structure = parser.get_structure('protein', pdb_path)
64
+
65
+ try:
66
+ chain = structure[0][segment]
67
+ except KeyError:
68
+ return "Invalid Chain ID", None, None
69
 
70
  # Comprehensive amino acid mapping
71
  aa_dict = {
 
77
  }
78
 
79
  # Exclude non-amino acid residues
80
+ sequence = [
81
+ residue for residue in chain
 
82
  if residue.get_resname().strip() in aa_dict
83
+ ]
84
 
85
  # Prepare input for model prediction
86
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
 
90
  # Calculate scores and normalize them
91
  scores = expit(outputs[:, 1] - outputs[:, 0])
92
  normalized_scores = normalize_scores(scores)
93
+
94
+ result_str = "\n".join(
95
+ f"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}"
96
+ for res, score in zip(sequence, normalized_scores)
97
+ )
98
 
99
+ # Save the predictions to a file
100
+ prediction_file = f"{pdb_id}_predictions.txt"
101
+ with open(prediction_file, "w") as f:
 
 
 
 
 
102
  f.write(result_str)
103
 
104
+ return result_str, molecule(pdb_path, random_scores, segment), prediction_file
105
+
106
+ def molecule(input_pdb, scores=None, segment='A'):
107
+ mol = read_mol(input_pdb) # Read PDB file content
108
+
109
+ # Prepare high-scoring residues script if scores are provided
110
+ high_score_script = ""
111
+ if scores is not None:
112
+ high_score_script = """
113
+ // Reset all styles first
114
+ viewer.getModel(0).setStyle({}, {});
115
+
116
+ // Show only the selected chain
117
+ viewer.getModel(0).setStyle(
118
+ {"chain": "%s"},
119
+ { cartoon: {colorscheme:"whiteCarbon"} }
120
+ );
121
+
122
+ // Highlight high-scoring residues only for the selected chain
123
+ let highScoreResidues = [%s];
124
+ viewer.getModel(0).setStyle(
125
+ {"chain": "%s", "resi": highScoreResidues},
126
+ {"stick": {"color": "red"}}
127
+ );
128
 
129
+ // Highlight high-scoring residues only for the selected chain
130
+ let highScoreResidues2 = [%s];
131
+ viewer.getModel(0).setStyle(
132
+ {"chain": "%s", "resi": highScoreResidues2},
133
+ {"stick": {"color": "orange"}}
134
+ );
135
+ """ % (segment,
136
+ ", ".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),
137
+ segment,
138
+ ", ".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),
139
+ segment)
140
+
141
+ html_content = f"""
142
+ <!DOCTYPE html>
143
+ <html>
144
+ <head>
145
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
146
+ <style>
147
+ .mol-container {{
148
+ width: 100%;
149
+ height: 700px;
150
+ position: relative;
151
+ }}
152
+ </style>
153
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js"></script>
154
+ <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
155
+ </head>
156
+ <body>
157
+ <div id="container" class="mol-container"></div>
158
+ <script>
159
+ let pdb = `{mol}`; // Use template literal to properly escape PDB content
160
+ $(document).ready(function () {{
161
+ let element = $("#container");
162
+ let config = {{ backgroundColor: "white" }};
163
+ let viewer = $3Dmol.createViewer(element, config);
164
+ viewer.addModel(pdb, "pdb");
165
+
166
+ // Reset all styles and show only selected chain
167
+ viewer.getModel(0).setStyle(
168
+ {{"chain": "{segment}"}},
169
+ {{ cartoon: {{ colorscheme:"whiteCarbon" }} }}
170
+ );
171
+
172
+ {high_score_script}
173
+
174
+ // Add hover functionality
175
+ viewer.setHoverable(
176
+ {{}},
177
+ true,
178
+ function(atom, viewer, event, container) {{
179
+ if (!atom.label) {{
180
+ atom.label = viewer.addLabel(
181
+ atom.resn + ":" + atom.atom,
182
+ {{
183
+ position: atom,
184
+ backgroundColor: 'mintcream',
185
+ fontColor: 'black',
186
+ fontSize: 12,
187
+ padding: 2
188
+ }}
189
+ );
190
+ }}
191
+ }},
192
+ function(atom, viewer) {{
193
+ if (atom.label) {{
194
+ viewer.removeLabel(atom.label);
195
+ delete atom.label;
196
+ }}
197
+ }}
198
+ );
199
+
200
+ viewer.zoomTo();
201
+ viewer.render();
202
+ viewer.zoom(0.8, 2000);
203
+ }});
204
+ </script>
205
+ </body>
206
+ </html>
207
+ """
208
+
209
+ # Return the HTML content within an iframe safely encoded for special characters
210
+ return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
211
+
212
+ reps = [
213
+ {
214
+ "model": 0,
215
+ "style": "cartoon",
216
+ "color": "whiteCarbon",
217
+ "residue_range": "",
218
+ "around": 0,
219
+ "byres": False,
220
+ }
221
+ ]
222
 
223
  # Gradio UI
224
  with gr.Blocks() as demo:
225
+ gr.Markdown("# Protein Binding Site Prediction (Random Scores)")
226
+ with gr.Row():
227
+ pdb_input = gr.Textbox(value="2IWI", label="PDB ID", placeholder="Enter PDB ID here...")
228
+ visualize_btn = gr.Button("Visualize Structure")
229
+
230
+ molecule_output2 = Molecule3D(label="Protein Structure", reps=reps)
231
 
232
  with gr.Row():
233
+ pdb_input = gr.Textbox(value="2IWI", label="PDB ID", placeholder="Enter PDB ID here...")
234
+ segment_input = gr.Textbox(value="A", label="Chain ID", placeholder="Enter Chain ID here...")
235
+ prediction_btn = gr.Button("Predict Random Binding Site Scores")
236
+
237
+ molecule_output = gr.HTML(label="Protein Structure")
 
 
 
 
 
238
  predictions_output = gr.Textbox(label="Binding Site Predictions")
239
  download_output = gr.File(label="Download Predictions")
240
+
241
+ visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)
242
+
243
+ prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])
244
+
 
 
 
245
  gr.Markdown("## Examples")
246
  gr.Examples(
247
  examples=[
248
+ ["2IWI", "A"],
249
+ ["7RPZ", "B"],
250
+ ["3TJN", "C"]
251
  ],
252
+ inputs=[pdb_input, segment_input],
253
  outputs=[predictions_output, molecule_output, download_output]
254
  )
255
 
model_loader.ipynb DELETED
@@ -1,871 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 38,
6
- "id": "14ff5741-629c-445a-a8a9-b3d9db1f3ddb",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "import torch\n",
11
- "import torch.nn as nn\n",
12
- "import torch.nn.functional as F\n",
13
- "from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n",
14
- "from torch.utils.data import DataLoader\n",
15
- "\n",
16
- "import re\n",
17
- "import numpy as np\n",
18
- "import os\n",
19
- "import pandas as pd\n",
20
- "import copy\n",
21
- "\n",
22
- "import transformers, datasets\n",
23
- "from transformers.modeling_outputs import TokenClassifierOutput\n",
24
- "from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack\n",
25
- "from transformers.utils.model_parallel_utils import assert_device_map, get_device_map\n",
26
- "from transformers import T5EncoderModel, T5Tokenizer\n",
27
- "from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel\n",
28
- "from transformers import AutoTokenizer\n",
29
- "from transformers import TrainingArguments, Trainer, set_seed\n",
30
- "from transformers import DataCollatorForTokenClassification\n",
31
- "\n",
32
- "from dataclasses import dataclass\n",
33
- "from typing import Dict, List, Optional, Tuple, Union\n",
34
- "\n",
35
- "# for custom DataCollator\n",
36
- "from transformers.data.data_collator import DataCollatorMixin\n",
37
- "from transformers.tokenization_utils_base import PreTrainedTokenizerBase\n",
38
- "from transformers.utils import PaddingStrategy\n",
39
- "\n",
40
- "from datasets import Dataset\n",
41
- "\n",
42
- "from scipy.special import expit\n",
43
- "\n",
44
- "import peft\n",
45
- "from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig"
46
- ]
47
- },
48
- {
49
- "cell_type": "code",
50
- "execution_count": 6,
51
- "id": "5ec16a71-ed5d-46a6-98b2-55bc5d0fbe07",
52
- "metadata": {},
53
- "outputs": [],
54
- "source": [
55
- "cnn_head=True #False set True for Rostlab/prot_t5_xl_half_uniref50-enc\n",
56
- "ffn_head=False #False\n",
57
- "transformer_head=False\n",
58
- "custom_lora=True #False #only true for Rostlab/prot_t5_xl_half_uniref50-enc"
59
- ]
60
- },
61
- {
62
- "cell_type": "code",
63
- "execution_count": 8,
64
- "id": "cc7151ca-0daf-4e75-a865-ab52f9b28f2e",
65
- "metadata": {},
66
- "outputs": [],
67
- "source": [
68
- "class ClassConfig:\n",
69
- " def __init__(self, dropout=0.2, num_labels=3):\n",
70
- " self.dropout_rate = dropout\n",
71
- " self.num_labels = num_labels\n",
72
- "\n",
73
- "class T5EncoderForTokenClassification(T5PreTrainedModel):\n",
74
- "\n",
75
- " def __init__(self, config: T5Config, class_config: ClassConfig):\n",
76
- " super().__init__(config)\n",
77
- " self.num_labels = class_config.num_labels\n",
78
- " self.config = config\n",
79
- "\n",
80
- " self.shared = nn.Embedding(config.vocab_size, config.d_model)\n",
81
- "\n",
82
- " encoder_config = copy.deepcopy(config)\n",
83
- " encoder_config.use_cache = False\n",
84
- " encoder_config.is_encoder_decoder = False\n",
85
- " self.encoder = T5Stack(encoder_config, self.shared)\n",
86
- "\n",
87
- " self.dropout = nn.Dropout(class_config.dropout_rate)\n",
88
- "\n",
89
- " # Initialize different heads based on class_config\n",
90
- " if cnn_head:\n",
91
- " self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1)\n",
92
- " self.classifier = nn.Linear(512, class_config.num_labels)\n",
93
- " elif ffn_head:\n",
94
- " # Multi-layer feed-forward network (FFN) head\n",
95
- " self.ffn = nn.Sequential(\n",
96
- " nn.Linear(config.hidden_size, 512),\n",
97
- " nn.ReLU(),\n",
98
- " nn.Linear(512, 256),\n",
99
- " nn.ReLU(),\n",
100
- " nn.Linear(256, class_config.num_labels)\n",
101
- " )\n",
102
- " elif transformer_head:\n",
103
- " # Transformer layer head\n",
104
- " encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=8)\n",
105
- " self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)\n",
106
- " self.classifier = nn.Linear(config.hidden_size, class_config.num_labels)\n",
107
- " else:\n",
108
- " # Default classification head\n",
109
- " self.classifier = nn.Linear(config.hidden_size, class_config.num_labels)\n",
110
- " \n",
111
- " self.post_init()\n",
112
- "\n",
113
- " # Model parallel\n",
114
- " self.model_parallel = False\n",
115
- " self.device_map = None\n",
116
- "\n",
117
- " def parallelize(self, device_map=None):\n",
118
- " self.device_map = (\n",
119
- " get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))\n",
120
- " if device_map is None\n",
121
- " else device_map\n",
122
- " )\n",
123
- " assert_device_map(self.device_map, len(self.encoder.block))\n",
124
- " self.encoder.parallelize(self.device_map)\n",
125
- " self.classifier = self.classifier.to(self.encoder.first_device)\n",
126
- " self.model_parallel = True\n",
127
- "\n",
128
- " def deparallelize(self):\n",
129
- " self.encoder.deparallelize()\n",
130
- " self.encoder = self.encoder.to(\"cpu\")\n",
131
- " self.model_parallel = False\n",
132
- " self.device_map = None\n",
133
- " torch.cuda.empty_cache()\n",
134
- "\n",
135
- " def get_input_embeddings(self):\n",
136
- " return self.shared\n",
137
- "\n",
138
- " def set_input_embeddings(self, new_embeddings):\n",
139
- " self.shared = new_embeddings\n",
140
- " self.encoder.set_input_embeddings(new_embeddings)\n",
141
- "\n",
142
- " def get_encoder(self):\n",
143
- " return self.encoder\n",
144
- "\n",
145
- " def _prune_heads(self, heads_to_prune):\n",
146
- " for layer, heads in heads_to_prune.items():\n",
147
- " self.encoder.layer[layer].attention.prune_heads(heads)\n",
148
- "\n",
149
- " def forward(\n",
150
- " self,\n",
151
- " input_ids=None,\n",
152
- " attention_mask=None,\n",
153
- " head_mask=None,\n",
154
- " inputs_embeds=None,\n",
155
- " labels=None,\n",
156
- " output_attentions=None,\n",
157
- " output_hidden_states=None,\n",
158
- " return_dict=None,\n",
159
- " ):\n",
160
- " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
161
- "\n",
162
- " outputs = self.encoder(\n",
163
- " input_ids=input_ids,\n",
164
- " attention_mask=attention_mask,\n",
165
- " inputs_embeds=inputs_embeds,\n",
166
- " head_mask=head_mask,\n",
167
- " output_attentions=output_attentions,\n",
168
- " output_hidden_states=output_hidden_states,\n",
169
- " return_dict=return_dict,\n",
170
- " )\n",
171
- "\n",
172
- " sequence_output = outputs[0]\n",
173
- " sequence_output = self.dropout(sequence_output)\n",
174
- "\n",
175
- " # Forward pass through the selected head\n",
176
- " if cnn_head:\n",
177
- " # CNN head\n",
178
- " sequence_output = sequence_output.permute(0, 2, 1) # Prepare shape for CNN\n",
179
- " cnn_output = self.cnn(sequence_output)\n",
180
- " cnn_output = F.relu(cnn_output)\n",
181
- " cnn_output = cnn_output.permute(0, 2, 1) # Shape back for classifier\n",
182
- " logits = self.classifier(cnn_output)\n",
183
- " elif ffn_head:\n",
184
- " # FFN head\n",
185
- " logits = self.ffn(sequence_output)\n",
186
- " elif transformer_head:\n",
187
- " # Transformer head\n",
188
- " transformer_output = self.transformer_encoder(sequence_output)\n",
189
- " logits = self.classifier(transformer_output)\n",
190
- " else:\n",
191
- " # Default classification head\n",
192
- " logits = self.classifier(sequence_output)\n",
193
- "\n",
194
- " loss = None\n",
195
- " if labels is not None:\n",
196
- " loss_fct = CrossEntropyLoss()\n",
197
- " active_loss = attention_mask.view(-1) == 1\n",
198
- " active_logits = logits.view(-1, self.num_labels)\n",
199
- " active_labels = torch.where(\n",
200
- " active_loss, labels.view(-1), torch.tensor(-100).type_as(labels)\n",
201
- " )\n",
202
- " valid_logits = active_logits[active_labels != -100]\n",
203
- " valid_labels = active_labels[active_labels != -100]\n",
204
- " valid_labels = valid_labels.to(valid_logits.device)\n",
205
- " valid_labels = valid_labels.long()\n",
206
- " loss = loss_fct(valid_logits, valid_labels)\n",
207
- "\n",
208
- " if not return_dict:\n",
209
- " output = (logits,) + outputs[2:]\n",
210
- " return ((loss,) + output) if loss is not None else output\n",
211
- "\n",
212
- " return TokenClassifierOutput(\n",
213
- " loss=loss,\n",
214
- " logits=logits,\n",
215
- " hidden_states=outputs.hidden_states,\n",
216
- " attentions=outputs.attentions,\n",
217
- " )"
218
- ]
219
- },
220
- {
221
- "cell_type": "code",
222
- "execution_count": 10,
223
- "id": "e5e751ba-f4d3-4a28-bea0-82633f1dabb4",
224
- "metadata": {},
225
- "outputs": [],
226
- "source": [
227
- "# Modifies an existing transformer and introduce the LoRA layers\n",
228
- "\n",
229
- "class CustomLoRAConfig:\n",
230
- " def __init__(self):\n",
231
- " self.lora_rank = 4\n",
232
- " self.lora_init_scale = 0.01\n",
233
- " self.lora_modules = \".*SelfAttention|.*EncDecAttention\"\n",
234
- " self.lora_layers = \"q|k|v|o\"\n",
235
- " self.trainable_param_names = \".*layer_norm.*|.*lora_[ab].*\"\n",
236
- " self.lora_scaling_rank = 1\n",
237
- " # lora_modules and lora_layers are specified with regular expressions\n",
238
- " # see https://www.w3schools.com/python/python_regex.asp for reference\n",
239
- " \n",
240
- "class LoRALinear(nn.Module):\n",
241
- " def __init__(self, linear_layer, rank, scaling_rank, init_scale):\n",
242
- " super().__init__()\n",
243
- " self.in_features = linear_layer.in_features\n",
244
- " self.out_features = linear_layer.out_features\n",
245
- " self.rank = rank\n",
246
- " self.scaling_rank = scaling_rank\n",
247
- " self.weight = linear_layer.weight\n",
248
- " self.bias = linear_layer.bias\n",
249
- " if self.rank > 0:\n",
250
- " self.lora_a = nn.Parameter(torch.randn(rank, linear_layer.in_features) * init_scale)\n",
251
- " if init_scale < 0:\n",
252
- " self.lora_b = nn.Parameter(torch.randn(linear_layer.out_features, rank) * init_scale)\n",
253
- " else:\n",
254
- " self.lora_b = nn.Parameter(torch.zeros(linear_layer.out_features, rank))\n",
255
- " if self.scaling_rank:\n",
256
- " self.multi_lora_a = nn.Parameter(\n",
257
- " torch.ones(self.scaling_rank, linear_layer.in_features)\n",
258
- " + torch.randn(self.scaling_rank, linear_layer.in_features) * init_scale\n",
259
- " )\n",
260
- " if init_scale < 0:\n",
261
- " self.multi_lora_b = nn.Parameter(\n",
262
- " torch.ones(linear_layer.out_features, self.scaling_rank)\n",
263
- " + torch.randn(linear_layer.out_features, self.scaling_rank) * init_scale\n",
264
- " )\n",
265
- " else:\n",
266
- " self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, self.scaling_rank))\n",
267
- "\n",
268
- " def forward(self, input):\n",
269
- " if self.scaling_rank == 1 and self.rank == 0:\n",
270
- " # parsimonious implementation for ia3 and lora scaling\n",
271
- " if self.multi_lora_a.requires_grad:\n",
272
- " hidden = F.linear((input * self.multi_lora_a.flatten()), self.weight, self.bias)\n",
273
- " else:\n",
274
- " hidden = F.linear(input, self.weight, self.bias)\n",
275
- " if self.multi_lora_b.requires_grad:\n",
276
- " hidden = hidden * self.multi_lora_b.flatten()\n",
277
- " return hidden\n",
278
- " else:\n",
279
- " # general implementation for lora (adding and scaling)\n",
280
- " weight = self.weight\n",
281
- " if self.scaling_rank:\n",
282
- " weight = weight * torch.matmul(self.multi_lora_b, self.multi_lora_a) / self.scaling_rank\n",
283
- " if self.rank:\n",
284
- " weight = weight + torch.matmul(self.lora_b, self.lora_a) / self.rank\n",
285
- " return F.linear(input, weight, self.bias)\n",
286
- "\n",
287
- " def extra_repr(self):\n",
288
- " return \"in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}\".format(\n",
289
- " self.in_features, self.out_features, self.bias is not None, self.rank, self.scaling_rank\n",
290
- " )\n",
291
- "\n",
292
- "\n",
293
- "def modify_with_lora(transformer, config):\n",
294
- " for m_name, module in dict(transformer.named_modules()).items():\n",
295
- " if re.fullmatch(config.lora_modules, m_name):\n",
296
- " for c_name, layer in dict(module.named_children()).items():\n",
297
- " if re.fullmatch(config.lora_layers, c_name):\n",
298
- " assert isinstance(\n",
299
- " layer, nn.Linear\n",
300
- " ), f\"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}.\"\n",
301
- " setattr(\n",
302
- " module,\n",
303
- " c_name,\n",
304
- " LoRALinear(layer, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale),\n",
305
- " )\n",
306
- " return transformer\n",
307
- "\n"
308
- ]
309
- },
310
- {
311
- "cell_type": "code",
312
- "execution_count": 12,
313
- "id": "43a56311-3279-466a-bc95-590381f1b13c",
314
- "metadata": {},
315
- "outputs": [],
316
- "source": [
317
- "def load_T5_model_classification(checkpoint, num_labels, half_precision, full = False, deepspeed=True):\n",
318
- " # Load model and tokenizer\n",
319
- "\n",
320
- " if \"ankh\" in checkpoint :\n",
321
- " model = T5EncoderModel.from_pretrained(checkpoint)\n",
322
- " tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
323
- "\n",
324
- " elif \"prot_t5\" in checkpoint:\n",
325
- " # possible to load the half precision model (thanks to @pawel-rezo for pointing that out)\n",
326
- " if half_precision and deepspeed:\n",
327
- " #tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)\n",
328
- " #model = T5EncoderModel.from_pretrained(\"Rostlab/prot_t5_xl_half_uniref50-enc\", torch_dtype=torch.float16)#.to(torch.device('cuda')\n",
329
- " tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False)\n",
330
- " model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda'))\n",
331
- " else:\n",
332
- " model = T5EncoderModel.from_pretrained(checkpoint)\n",
333
- " tokenizer = T5Tokenizer.from_pretrained(checkpoint)\n",
334
- " \n",
335
- " elif \"ProstT5\" in checkpoint:\n",
336
- " if half_precision and deepspeed: \n",
337
- " tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False)\n",
338
- " model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda'))\n",
339
- " else:\n",
340
- " model = T5EncoderModel.from_pretrained(checkpoint)\n",
341
- " tokenizer = T5Tokenizer.from_pretrained(checkpoint) \n",
342
- " \n",
343
- " # Create new Classifier model with PT5 dimensions\n",
344
- " class_config=ClassConfig(num_labels=num_labels)\n",
345
- " class_model=T5EncoderForTokenClassification(model.config,class_config)\n",
346
- " \n",
347
- " # Set encoder and embedding weights to checkpoint weights\n",
348
- " class_model.shared=model.shared\n",
349
- " class_model.encoder=model.encoder \n",
350
- " \n",
351
- " # Delete the checkpoint model\n",
352
- " model=class_model\n",
353
- " del class_model\n",
354
- " \n",
355
- " if full == True:\n",
356
- " return model, tokenizer \n",
357
- " \n",
358
- " # Print number of trainable parameters\n",
359
- " model_parameters = filter(lambda p: p.requires_grad, model.parameters())\n",
360
- " params = sum([np.prod(p.size()) for p in model_parameters])\n",
361
- " print(\"T5_Classfier\\nTrainable Parameter: \"+ str(params)) \n",
362
- "\n",
363
- " if custom_lora:\n",
364
- " #the linear CustomLoRAConfig allows better quality predictions, but more memory is needed\n",
365
- " # Add model modification lora\n",
366
- " config = CustomLoRAConfig()\n",
367
- " \n",
368
- " # Add LoRA layers\n",
369
- " model = modify_with_lora(model, config)\n",
370
- " \n",
371
- " # Freeze Embeddings and Encoder (except LoRA)\n",
372
- " for (param_name, param) in model.shared.named_parameters():\n",
373
- " param.requires_grad = False\n",
374
- " for (param_name, param) in model.encoder.named_parameters():\n",
375
- " param.requires_grad = False \n",
376
- " \n",
377
- " for (param_name, param) in model.named_parameters():\n",
378
- " if re.fullmatch(config.trainable_param_names, param_name):\n",
379
- " param.requires_grad = True\n",
380
- "\n",
381
- " else:\n",
382
- " # lora modification\n",
383
- " peft_config = LoraConfig(\n",
384
- " r=4, lora_alpha=1, bias=\"all\", target_modules=[\"q\",\"k\",\"v\",\"o\"]\n",
385
- " )\n",
386
- " \n",
387
- " model = inject_adapter_in_model(peft_config, model)\n",
388
- " \n",
389
- " # Unfreeze the prediction head\n",
390
- " for (param_name, param) in model.classifier.named_parameters():\n",
391
- " param.requires_grad = True \n",
392
- "\n",
393
- " # Print trainable Parameter \n",
394
- " model_parameters = filter(lambda p: p.requires_grad, model.parameters())\n",
395
- " params = sum([np.prod(p.size()) for p in model_parameters])\n",
396
- " print(\"T5_LoRA_Classfier\\nTrainable Parameter: \"+ str(params) + \"\\n\")\n",
397
- " \n",
398
- " return model, tokenizer"
399
- ]
400
- },
401
- {
402
- "cell_type": "code",
403
- "execution_count": 14,
404
- "id": "7ba720bc-a003-4984-a965-cb2f42344e85",
405
- "metadata": {},
406
- "outputs": [],
407
- "source": [
408
- "class EsmForTokenClassificationCustom(EsmPreTrainedModel):\n",
409
- " _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n",
410
- " _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"cnn\", r\"ffn\", r\"transformer\"]\n",
411
- "\n",
412
- " def __init__(self, config):\n",
413
- " super().__init__(config)\n",
414
- " self.num_labels = config.num_labels\n",
415
- " self.esm = EsmModel(config, add_pooling_layer=False)\n",
416
- " self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
417
- "\n",
418
- " if cnn_head:\n",
419
- " self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1)\n",
420
- " self.classifier = nn.Linear(512, config.num_labels)\n",
421
- " elif ffn_head:\n",
422
- " # Multi-layer feed-forward network (FFN) as an alternative head\n",
423
- " self.ffn = nn.Sequential(\n",
424
- " nn.Linear(config.hidden_size, 512),\n",
425
- " nn.ReLU(),\n",
426
- " nn.Linear(512, 256),\n",
427
- " nn.ReLU(),\n",
428
- " nn.Linear(256, config.num_labels)\n",
429
- " )\n",
430
- " elif transformer_head:\n",
431
- " # Transformer layer as an alternative head\n",
432
- " encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=8)\n",
433
- " self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)\n",
434
- " self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n",
435
- " else:\n",
436
- " # Default classification head\n",
437
- " self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n",
438
- "\n",
439
- " self.init_weights()\n",
440
- "\n",
441
- " def forward(\n",
442
- " self,\n",
443
- " input_ids: Optional[torch.LongTensor] = None,\n",
444
- " attention_mask: Optional[torch.Tensor] = None,\n",
445
- " position_ids: Optional[torch.LongTensor] = None,\n",
446
- " head_mask: Optional[torch.Tensor] = None,\n",
447
- " inputs_embeds: Optional[torch.FloatTensor] = None,\n",
448
- " labels: Optional[torch.LongTensor] = None,\n",
449
- " output_attentions: Optional[bool] = None,\n",
450
- " output_hidden_states: Optional[bool] = None,\n",
451
- " return_dict: Optional[bool] = None,\n",
452
- " ) -> Union[Tuple, TokenClassifierOutput]:\n",
453
- " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
454
- " outputs = self.esm(\n",
455
- " input_ids,\n",
456
- " attention_mask=attention_mask,\n",
457
- " position_ids=position_ids,\n",
458
- " head_mask=head_mask,\n",
459
- " inputs_embeds=inputs_embeds,\n",
460
- " output_attentions=output_attentions,\n",
461
- " output_hidden_states=output_hidden_states,\n",
462
- " return_dict=return_dict,\n",
463
- " )\n",
464
- " \n",
465
- " sequence_output = outputs[0]\n",
466
- " sequence_output = self.dropout(sequence_output)\n",
467
- "\n",
468
- " if cnn_head:\n",
469
- " sequence_output = sequence_output.transpose(1, 2)\n",
470
- " sequence_output = self.cnn(sequence_output)\n",
471
- " sequence_output = sequence_output.transpose(1, 2)\n",
472
- " logits = self.classifier(sequence_output)\n",
473
- " elif ffn_head:\n",
474
- " logits = self.ffn(sequence_output)\n",
475
- " elif transformer_head:\n",
476
- " # Apply transformer encoder for the transformer head\n",
477
- " sequence_output = self.transformer_encoder(sequence_output)\n",
478
- " logits = self.classifier(sequence_output)\n",
479
- " else:\n",
480
- " logits = self.classifier(sequence_output)\n",
481
- "\n",
482
- " loss = None\n",
483
- " if labels is not None:\n",
484
- " loss_fct = CrossEntropyLoss()\n",
485
- " active_loss = attention_mask.view(-1) == 1\n",
486
- " active_logits = logits.view(-1, self.num_labels)\n",
487
- " active_labels = torch.where(\n",
488
- " active_loss, labels.view(-1), torch.tensor(-100).type_as(labels)\n",
489
- " )\n",
490
- " valid_logits = active_logits[active_labels != -100]\n",
491
- " valid_labels = active_labels[active_labels != -100]\n",
492
- " valid_labels = valid_labels.type(torch.LongTensor).to('cuda:0')\n",
493
- " loss = loss_fct(valid_logits, valid_labels)\n",
494
- "\n",
495
- " if not return_dict:\n",
496
- " output = (logits,) + outputs[2:]\n",
497
- " return ((loss,) + output) if loss is not None else output\n",
498
- "\n",
499
- " return TokenClassifierOutput(\n",
500
- " loss=loss,\n",
501
- " logits=logits,\n",
502
- " hidden_states=outputs.hidden_states,\n",
503
- " attentions=outputs.attentions,\n",
504
- " )\n",
505
- "\n",
506
- " def _init_weights(self, module):\n",
507
- " if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d):\n",
508
- " module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n",
509
- " if module.bias is not None:\n",
510
- " module.bias.data.zero_()\n",
511
- "\n",
512
- "# based on transformers DataCollatorForTokenClassification\n",
513
- "@dataclass\n",
514
- "class DataCollatorForTokenClassificationESM(DataCollatorMixin):\n",
515
- " \"\"\"\n",
516
- " Data collator that will dynamically pad the inputs received, as well as the labels.\n",
517
- " Args:\n",
518
- " tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):\n",
519
- " The tokenizer used for encoding the data.\n",
520
- " padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):\n",
521
- " Select a strategy to pad the returned sequences (according to the model's padding side and padding index)\n",
522
- " among:\n",
523
- " - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single\n",
524
- " sequence is provided).\n",
525
- " - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum\n",
526
- " acceptable input length for the model if that argument is not provided.\n",
527
- " - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).\n",
528
- " max_length (`int`, *optional*):\n",
529
- " Maximum length of the returned list and optionally padding length (see above).\n",
530
- " pad_to_multiple_of (`int`, *optional*):\n",
531
- " If set will pad the sequence to a multiple of the provided value.\n",
532
- " This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=\n",
533
- " 7.5 (Volta).\n",
534
- " label_pad_token_id (`int`, *optional*, defaults to -100):\n",
535
- " The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).\n",
536
- " return_tensors (`str`):\n",
537
- " The type of Tensor to return. Allowable values are \"np\", \"pt\" and \"tf\".\n",
538
- " \"\"\"\n",
539
- "\n",
540
- " tokenizer: PreTrainedTokenizerBase\n",
541
- " padding: Union[bool, str, PaddingStrategy] = True\n",
542
- " max_length: Optional[int] = None\n",
543
- " pad_to_multiple_of: Optional[int] = None\n",
544
- " label_pad_token_id: int = -100\n",
545
- " return_tensors: str = \"pt\"\n",
546
- "\n",
547
- " def torch_call(self, features):\n",
548
- " import torch\n",
549
- "\n",
550
- " label_name = \"label\" if \"label\" in features[0].keys() else \"labels\"\n",
551
- " labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None\n",
552
- "\n",
553
- " no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]\n",
554
- "\n",
555
- " batch = self.tokenizer.pad(\n",
556
- " no_labels_features,\n",
557
- " padding=self.padding,\n",
558
- " max_length=self.max_length,\n",
559
- " pad_to_multiple_of=self.pad_to_multiple_of,\n",
560
- " return_tensors=\"pt\",\n",
561
- " )\n",
562
- "\n",
563
- " if labels is None:\n",
564
- " return batch\n",
565
- "\n",
566
- " sequence_length = batch[\"input_ids\"].shape[1]\n",
567
- " padding_side = self.tokenizer.padding_side\n",
568
- "\n",
569
- " def to_list(tensor_or_iterable):\n",
570
- " if isinstance(tensor_or_iterable, torch.Tensor):\n",
571
- " return tensor_or_iterable.tolist()\n",
572
- " return list(tensor_or_iterable)\n",
573
- "\n",
574
- " if padding_side == \"right\":\n",
575
- " batch[label_name] = [\n",
576
- " # to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels\n",
577
- " # changed to pad the special tokens at the beginning and end of the sequence\n",
578
- " [self.label_pad_token_id] + to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)-1) for label in labels\n",
579
- " ]\n",
580
- " else:\n",
581
- " batch[label_name] = [\n",
582
- " [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels\n",
583
- " ]\n",
584
- "\n",
585
- " batch[label_name] = torch.tensor(batch[label_name], dtype=torch.float)\n",
586
- " return batch\n",
587
- "\n",
588
- "def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):\n",
589
- " \"\"\"Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.\"\"\"\n",
590
- " import torch\n",
591
- "\n",
592
- " # Tensorize if necessary.\n",
593
- " if isinstance(examples[0], (list, tuple, np.ndarray)):\n",
594
- " examples = [torch.tensor(e, dtype=torch.long) for e in examples]\n",
595
- "\n",
596
- " length_of_first = examples[0].size(0)\n",
597
- "\n",
598
- " # Check if padding is necessary.\n",
599
- "\n",
600
- " are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)\n",
601
- " if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):\n",
602
- " return torch.stack(examples, dim=0)\n",
603
- "\n",
604
- " # If yes, check if we have a `pad_token`.\n",
605
- " if tokenizer._pad_token is None:\n",
606
- " raise ValueError(\n",
607
- " \"You are attempting to pad samples but the tokenizer you are using\"\n",
608
- " f\" ({tokenizer.__class__.__name__}) does not have a pad token.\"\n",
609
- " )\n",
610
- "\n",
611
- " # Creating the full tensor and filling it with our data.\n",
612
- " max_length = max(x.size(0) for x in examples)\n",
613
- " if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):\n",
614
- " max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of\n",
615
- " result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)\n",
616
- " for i, example in enumerate(examples):\n",
617
- " if tokenizer.padding_side == \"right\":\n",
618
- " result[i, : example.shape[0]] = example\n",
619
- " else:\n",
620
- " result[i, -example.shape[0] :] = example\n",
621
- " return result\n",
622
- "\n",
623
- "def tolist(x):\n",
624
- " if isinstance(x, list):\n",
625
- " return x\n",
626
- " elif hasattr(x, \"numpy\"): # Checks for TF tensors without needing the import\n",
627
- " x = x.numpy()\n",
628
- " return x.tolist()"
629
- ]
630
- },
631
- {
632
- "cell_type": "code",
633
- "execution_count": 16,
634
- "id": "ea511812-1244-4e51-b63c-b4da7822f0b7",
635
- "metadata": {},
636
- "outputs": [],
637
- "source": [
638
- "#load ESM2 models\n",
639
- "def load_esm_model_classification(checkpoint, num_labels, half_precision, full=False, deepspeed=True):\n",
640
- " \n",
641
- " tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
642
- "\n",
643
- " \n",
644
- " if half_precision and deepspeed:\n",
645
- " model = EsmForTokenClassificationCustom.from_pretrained(checkpoint, \n",
646
- " num_labels = num_labels, \n",
647
- " ignore_mismatched_sizes=True,\n",
648
- " torch_dtype = torch.float16)\n",
649
- " else:\n",
650
- " model = EsmForTokenClassificationCustom.from_pretrained(checkpoint, \n",
651
- " num_labels = num_labels,\n",
652
- " ignore_mismatched_sizes=True)\n",
653
- " \n",
654
- " if full == True:\n",
655
- " return model, tokenizer \n",
656
- " \n",
657
- " peft_config = LoraConfig(\n",
658
- " r=4, lora_alpha=1, bias=\"all\", target_modules=[\"query\",\"key\",\"value\",\"dense\"]\n",
659
- " )\n",
660
- " \n",
661
- " model = inject_adapter_in_model(peft_config, model)\n",
662
- "\n",
663
- " #model.gradient_checkpointing_enable()\n",
664
- " \n",
665
- " # Unfreeze the prediction head\n",
666
- " for (param_name, param) in model.classifier.named_parameters():\n",
667
- " param.requires_grad = True \n",
668
- " \n",
669
- " return model, tokenizer"
670
- ]
671
- },
672
- {
673
- "cell_type": "code",
674
- "execution_count": 22,
675
- "id": "8941bbbb-57c5-4f3d-89d9-12b2d306e7a1",
676
- "metadata": {},
677
- "outputs": [],
678
- "source": [
679
- "checkpoint='../Pretrained/Rostlab/prot_t5_xl_uniref50'\n",
680
- "best_model_path='../refined_models/ChallengeFinetuning/Rostlab/prot_t5_xl_uniref50/manual_checkpoint/cpt.pth'\n",
681
- "full=False\n",
682
- "deepspeed=False\n",
683
- "mixed=False \n",
684
- "num_labels=2"
685
- ]
686
- },
687
- {
688
- "cell_type": "code",
689
- "execution_count": null,
690
- "id": "4f007331-34d4-4c1d-9311-e91db23d9ed5",
691
- "metadata": {},
692
- "outputs": [],
693
- "source": [
694
- "/home/frohlkin/Projects/PLM/Publication/hf_webpage/pretrained"
695
- ]
696
- },
697
- {
698
- "cell_type": "code",
699
- "execution_count": 24,
700
- "id": "18d4ad06-b195-4cc6-a3c8-fa3e761838dc",
701
- "metadata": {},
702
- "outputs": [
703
- {
704
- "name": "stdout",
705
- "output_type": "stream",
706
- "text": [
707
- "../Pretrained/Rostlab/prot_t5_xl_uniref50 2 False False False\n",
708
- "T5_Classfier\n",
709
- "Trainable Parameter: 1209716226\n",
710
- "T5_LoRA_Classfier\n",
711
- "Trainable Parameter: 4082178\n",
712
- "\n"
713
- ]
714
- },
715
- {
716
- "data": {
717
- "text/plain": [
718
- "<All keys matched successfully>"
719
- ]
720
- },
721
- "execution_count": 24,
722
- "metadata": {},
723
- "output_type": "execute_result"
724
- }
725
- ],
726
- "source": [
727
- "print(checkpoint, num_labels, mixed, full, deepspeed)\n",
728
- " \n",
729
- "# Determine model type and load accordingly\n",
730
- "if \"esm\" in checkpoint:\n",
731
- " model, tokenizer = load_esm_model_classification(checkpoint, num_labels, mixed, full, deepspeed)\n",
732
- "else:\n",
733
- " model, tokenizer = load_T5_model_classification(checkpoint, num_labels, mixed, full, deepspeed)\n",
734
- "\n",
735
- "# Load the best model state\n",
736
- "state_dict = torch.load(best_model_path, weights_only=True)\n",
737
- "model.load_state_dict(state_dict)"
738
- ]
739
- },
740
- {
741
- "cell_type": "code",
742
- "execution_count": 30,
743
- "id": "4e215923-dfe2-4426-aedf-5cb81f7f0db2",
744
- "metadata": {},
745
- "outputs": [],
746
- "source": [
747
- "test_one_letter_sequence='AWYAAK'\n",
748
- "max_length=1500"
749
- ]
750
- },
751
- {
752
- "cell_type": "code",
753
- "execution_count": 40,
754
- "id": "7174ea02-ed51-46f5-84c0-6bcd760670d4",
755
- "metadata": {},
756
- "outputs": [
757
- {
758
- "data": {
759
- "text/plain": [
760
- "(7,)"
761
- ]
762
- },
763
- "execution_count": 40,
764
- "metadata": {},
765
- "output_type": "execute_result"
766
- }
767
- ],
768
- "source": [
769
- "def create_dataset(tokenizer,seqs,labels,checkpoint):\n",
770
- " \n",
771
- " tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)\n",
772
- " dataset = Dataset.from_dict(tokenized)\n",
773
- " \n",
774
- " if (\"esm\" in checkpoint) or (\"ProstT5\" in checkpoint):\n",
775
- " labels = [l[:max_length-2] for l in labels] \n",
776
- " else:\n",
777
- " labels = [l[:max_length-1] for l in labels] \n",
778
- " \n",
779
- " dataset = dataset.add_column(\"labels\", labels)\n",
780
- " \n",
781
- " return dataset\n",
782
- " \n",
783
- "def convert_predictions(input_logits):\n",
784
- " all_probs = []\n",
785
- " for logits in input_logits:\n",
786
- " logits = logits.reshape(-1, 2)\n",
787
- "\n",
788
- " # Mask out irrelevant regions\n",
789
- " # Compute probabilities for class 1\n",
790
- " probabilities_class1 = expit(logits[:, 1] - logits[:, 0])\n",
791
- " \n",
792
- " all_probs.append(probabilities_class1)\n",
793
- " \n",
794
- " return np.concatenate(all_probs)\n",
795
- " \n",
796
- " \n",
797
- "dummy_labels=[np.zeros(len(test_one_letter_sequence))]\n",
798
- "# Replace uncommon amino acids with \"X\"\n",
799
- "test_one_letter_sequence = test_one_letter_sequence.replace(\"O\", \"X\").replace(\"B\", \"X\").replace(\"U\", \"X\").replace(\"Z\", \"X\").replace(\"J\", \"X\")\n",
800
- "\n",
801
- "# Add spaces between each amino acid for ProtT5 and ProstT5 models\n",
802
- "if \"Rostlab\" in checkpoint:\n",
803
- " test_one_letter_sequence = \" \".join(test_one_letter_sequence)\n",
804
- "\n",
805
- "# Add <AA2fold> for ProstT5 model input format\n",
806
- "if \"ProstT5\" in checkpoint:\n",
807
- " test_one_letter_sequence = \"<AA2fold> \" + test_one_letter_sequence\n",
808
- " \n",
809
- "test_dataset=create_dataset(tokenizer,[test_one_letter_sequence],dummy_labels,checkpoint)\n",
810
- "\n",
811
- "if (\"esm\" in checkpoint) or (\"ProstT5\" in checkpoint):\n",
812
- " data_collator = DataCollatorForTokenClassificationESM(tokenizer)\n",
813
- "else:\n",
814
- " data_collator = DataCollatorForTokenClassification(tokenizer)\n",
815
- "\n",
816
- "test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)\n",
817
- "\n",
818
- "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
819
- "model.to(device)\n",
820
- "for batch in test_loader:\n",
821
- " input_ids = batch['input_ids'].to(device)\n",
822
- " attention_mask = batch['attention_mask'].to(device)\n",
823
- " labels = batch['labels'] # Ensure to get labels from batch\n",
824
- "\n",
825
- " outputs = model(input_ids, attention_mask=attention_mask)\n",
826
- " logits = outputs.logits.detach().cpu().numpy()\n",
827
- "\n",
828
- "logits=convert_predictions(logits)\n",
829
- "logits.shape\n",
830
- "\n",
831
- "def normalize_scores(scores):\n",
832
- " min_score = np.min(scores)\n",
833
- " max_score = np.max(scores)\n",
834
- " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
835
- "\n",
836
- "normalized_scores = normalize_scores(logits)\n",
837
- "\n",
838
- "normalized_scores.shape"
839
- ]
840
- },
841
- {
842
- "cell_type": "code",
843
- "execution_count": null,
844
- "id": "58b5ae4d-9e8e-4d07-ab46-76d23cc29016",
845
- "metadata": {},
846
- "outputs": [],
847
- "source": []
848
- }
849
- ],
850
- "metadata": {
851
- "kernelspec": {
852
- "display_name": "Python [conda env:LLM] *",
853
- "language": "python",
854
- "name": "conda-env-LLM-py"
855
- },
856
- "language_info": {
857
- "codemirror_mode": {
858
- "name": "ipython",
859
- "version": 3
860
- },
861
- "file_extension": ".py",
862
- "mimetype": "text/x-python",
863
- "name": "python",
864
- "nbconvert_exporter": "python",
865
- "pygments_lexer": "ipython3",
866
- "version": "3.12.2"
867
- }
868
- },
869
- "nbformat": 4,
870
- "nbformat_minor": 5
871
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -10,5 +10,4 @@ sentencepiece
10
  huggingface_hub>=0.15.0
11
  requests
12
  gradio_molecule3d
13
- biopython>=1.81
14
- matplotlib
 
10
  huggingface_hub>=0.15.0
11
  requests
12
  gradio_molecule3d
13
+ biopython>=1.81
 
test.ipynb ADDED
@@ -0,0 +1,846 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "id": "1f8ea359-674c-4263-9c2a-7a8e7e464249",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "* Running on local URL: http://127.0.0.1:7862\n",
14
+ "\n",
15
+ "To create a public link, set `share=True` in `launch()`.\n"
16
+ ]
17
+ },
18
+ {
19
+ "data": {
20
+ "text/html": [
21
+ "<div><iframe src=\"http://127.0.0.1:7862/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
22
+ ],
23
+ "text/plain": [
24
+ "<IPython.core.display.HTML object>"
25
+ ]
26
+ },
27
+ "metadata": {},
28
+ "output_type": "display_data"
29
+ },
30
+ {
31
+ "data": {
32
+ "text/plain": []
33
+ },
34
+ "execution_count": 3,
35
+ "metadata": {},
36
+ "output_type": "execute_result"
37
+ }
38
+ ],
39
+ "source": [
40
+ "import gradio as gr\n",
41
+ "import requests\n",
42
+ "from Bio.PDB import PDBParser\n",
43
+ "from gradio_molecule3d import Molecule3D\n",
44
+ "import numpy as np\n",
45
+ "\n",
46
+ "# Function to fetch a PDB file from RCSB PDB\n",
47
+ "def fetch_pdb(pdb_id):\n",
48
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
49
+ " pdb_path = f'{pdb_id}.pdb'\n",
50
+ " response = requests.get(pdb_url)\n",
51
+ " if response.status_code == 200:\n",
52
+ " with open(pdb_path, 'wb') as f:\n",
53
+ " f.write(response.content)\n",
54
+ " return pdb_path\n",
55
+ " else:\n",
56
+ " return None\n",
57
+ "\n",
58
+ "# Function to process the PDB file and return random predictions\n",
59
+ "def process_pdb(pdb_id, segment):\n",
60
+ " pdb_path = fetch_pdb(pdb_id)\n",
61
+ " if not pdb_path:\n",
62
+ " return \"Failed to fetch PDB file\", None, None\n",
63
+ "\n",
64
+ " parser = PDBParser(QUIET=True)\n",
65
+ " structure = parser.get_structure('protein', pdb_path)\n",
66
+ " \n",
67
+ " try:\n",
68
+ " chain = structure[0][segment]\n",
69
+ " except KeyError:\n",
70
+ " return \"Invalid Chain ID\", None, None\n",
71
+ "\n",
72
+ " sequence = [residue.get_resname() for residue in chain if residue.id[0] == ' ']\n",
73
+ " random_scores = np.random.rand(len(sequence))\n",
74
+ "\n",
75
+ " result_str = \"\\n\".join(\n",
76
+ " f\"{seq} {res.id[1]} {score:.2f}\" \n",
77
+ " for seq, res, score in zip(sequence, chain, random_scores)\n",
78
+ " )\n",
79
+ "\n",
80
+ " # Save the predictions to a file\n",
81
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
82
+ " with open(prediction_file, \"w\") as f:\n",
83
+ " f.write(result_str)\n",
84
+ " \n",
85
+ " return result_str, pdb_path, prediction_file\n",
86
+ "\n",
87
+ "#reps = [{\"model\": 0, \"style\": \"cartoon\", \"color\": \"spectrum\"}]\n",
88
+ "\n",
89
+ "reps = [\n",
90
+ " {\n",
91
+ " \"model\": 0,\n",
92
+ " \"style\": \"cartoon\",\n",
93
+ " \"color\": \"whiteCarbon\",\n",
94
+ " \"residue_range\": \"\",\n",
95
+ " \"around\": 0,\n",
96
+ " \"byres\": False,\n",
97
+ " },\n",
98
+ " {\n",
99
+ " \"model\": 0,\n",
100
+ " \"chain\": \"A\",\n",
101
+ " \"resname\": \"HIS\",\n",
102
+ " \"style\": \"stick\",\n",
103
+ " \"color\": \"red\"\n",
104
+ " }\n",
105
+ " ]\n",
106
+ "\n",
107
+ "\n",
108
+ "# Gradio UI\n",
109
+ "with gr.Blocks() as demo:\n",
110
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
111
+ "\n",
112
+ " with gr.Row():\n",
113
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
114
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
115
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
116
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
117
+ "\n",
118
+ " molecule_output = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
119
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
120
+ " download_output = gr.File(label=\"Download Predictions\")\n",
121
+ "\n",
122
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)\n",
123
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
124
+ "\n",
125
+ " gr.Markdown(\"## Examples\")\n",
126
+ " gr.Examples(\n",
127
+ " examples=[\n",
128
+ " [\"2IWI\", \"A\"],\n",
129
+ " [\"7RPZ\", \"B\"],\n",
130
+ " [\"3TJN\", \"C\"]\n",
131
+ " ],\n",
132
+ " inputs=[pdb_input, segment_input],\n",
133
+ " outputs=[predictions_output, molecule_output, download_output]\n",
134
+ " )\n",
135
+ "\n",
136
+ "demo.launch()"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "id": "bd50ff2e-ed03-498e-8af2-73c0fb8ea07e",
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": []
146
+ },
147
+ {
148
+ "cell_type": "raw",
149
+ "id": "88affe12-7c48-4bd6-9e46-32cdffa729fe",
150
+ "metadata": {},
151
+ "source": [
152
+ "import gradio as gr\n",
153
+ "from gradio_molecule3d import Molecule3D\n",
154
+ "\n",
155
+ "\n",
156
+ "example = Molecule3D().example_value()\n",
157
+ "\n",
158
+ "\n",
159
+ "reps = [\n",
160
+ " {\n",
161
+ " \"model\": 0,\n",
162
+ " \"style\": \"cartoon\",\n",
163
+ " \"color\": \"whiteCarbon\",\n",
164
+ " \"residue_range\": \"\",\n",
165
+ " \"around\": 0,\n",
166
+ " \"byres\": False,\n",
167
+ " },\n",
168
+ " {\n",
169
+ " \"model\": 0,\n",
170
+ " \"chain\": \"A\",\n",
171
+ " \"resname\": \"HIS\",\n",
172
+ " \"style\": \"stick\",\n",
173
+ " \"color\": \"red\"\n",
174
+ " }\n",
175
+ " ]\n",
176
+ "\n",
177
+ "\n",
178
+ "\n",
179
+ "def predict(x):\n",
180
+ " print(\"predict function\", x)\n",
181
+ " print(x.name)\n",
182
+ " return x\n",
183
+ "\n",
184
+ "with gr.Blocks() as demo:\n",
185
+ " gr.Markdown(\"# Molecule3D\")\n",
186
+ " inp = Molecule3D(label=\"Molecule3D\", reps=reps)\n",
187
+ " out = Molecule3D(label=\"Output\", reps=reps)\n",
188
+ "\n",
189
+ " btn = gr.Button(\"Predict\")\n",
190
+ " gr.Markdown(\"\"\" \n",
191
+ " You can configure the default rendering of the molecule by adding a list of representations\n",
192
+ " <pre>\n",
193
+ " reps = [\n",
194
+ " {\n",
195
+ " \"model\": 0,\n",
196
+ " \"style\": \"cartoon\",\n",
197
+ " \"color\": \"whiteCarbon\",\n",
198
+ " \"residue_range\": \"\",\n",
199
+ " \"around\": 0,\n",
200
+ " \"byres\": False,\n",
201
+ " },\n",
202
+ " {\n",
203
+ " \"model\": 0,\n",
204
+ " \"chain\": \"A\",\n",
205
+ " \"resname\": \"HIS\",\n",
206
+ " \"style\": \"stick\",\n",
207
+ " \"color\": \"red\"\n",
208
+ " }\n",
209
+ " ]\n",
210
+ " </pre>\n",
211
+ " \"\"\")\n",
212
+ " btn.click(predict, inputs=inp, outputs=out)\n",
213
+ "\n",
214
+ "\n",
215
+ "if __name__ == \"__main__\":\n",
216
+ " demo.launch()"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": null,
222
+ "id": "d27cc368-26a0-42c2-a68a-8833de7bb4a0",
223
+ "metadata": {},
224
+ "outputs": [],
225
+ "source": []
226
+ },
227
+ {
228
+ "cell_type": "raw",
229
+ "id": "2b970adb-3152-427f-bb58-b92974ff406e",
230
+ "metadata": {},
231
+ "source": [
232
+ "import gradio as gr\n",
233
+ "import os\n",
234
+ "import requests\n",
235
+ "from Bio.PDB import PDBParser, PDBIO\n",
236
+ "import biotite.structure.io as bsio\n",
237
+ "\n",
238
+ "def read_mol(pdb_path):\n",
239
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
240
+ " with open(pdb_path, 'r') as f:\n",
241
+ " return f.read()\n",
242
+ "\n",
243
+ "# Function to fetch or upload the PDB file\n",
244
+ "def get_pdb(pdb_code=\"\", filepath=\"\"):\n",
245
+ " if pdb_code and len(pdb_code) == 4:\n",
246
+ " pdb_file = f\"{pdb_code}.pdb\"\n",
247
+ " if not os.path.exists(pdb_file):\n",
248
+ " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n",
249
+ " return pdb_file\n",
250
+ " elif filepath is not None:\n",
251
+ " return filepath\n",
252
+ " else:\n",
253
+ " return None\n",
254
+ "\n",
255
+ "def molecule(input_pdb):\n",
256
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
257
+ " \n",
258
+ " html_content = f\"\"\"\n",
259
+ " <!DOCTYPE html>\n",
260
+ " <html>\n",
261
+ " <head> \n",
262
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
263
+ " <style>\n",
264
+ " .mol-container {{\n",
265
+ " width: 100%;\n",
266
+ " height: 700px;\n",
267
+ " position: relative;\n",
268
+ " }}\n",
269
+ " </style>\n",
270
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
271
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
272
+ " </head>\n",
273
+ " <body>\n",
274
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
275
+ " <script>\n",
276
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
277
+ " $(document).ready(function () {{\n",
278
+ " let element = $(\"#container\");\n",
279
+ " let config = {{ backgroundColor: \"white\" }};\n",
280
+ " let viewer = $3Dmol.createViewer(element, config);\n",
281
+ " viewer.addModel(pdb, \"pdb\");\n",
282
+ " viewer.getModel(0).setStyle({{}}, {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }});\n",
283
+ " viewer.zoomTo();\n",
284
+ " viewer.render();\n",
285
+ " viewer.zoom(0.8, 2000);\n",
286
+ " }});\n",
287
+ " </script>\n",
288
+ " </body>\n",
289
+ " </html>\n",
290
+ " \"\"\"\n",
291
+ " \n",
292
+ " # Return the HTML content within an iframe; only double and single quotes are entity-encoded for the srcdoc attribute\n",
293
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
294
+ "\n",
295
+ "# Gradio function to update the visualization\n",
296
+ "def update(inp, file):\n",
297
+ " pdb_path = get_pdb(inp, file)\n",
298
+ " if pdb_path:\n",
299
+ " return molecule(pdb_path)\n",
300
+ " else:\n",
301
+ " return \"Invalid input. Please provide a valid PDB code or upload a PDB file.\"\n",
302
+ "\n",
303
+ "# Gradio UI\n",
304
+ "demo = gr.Blocks()\n",
305
+ "with demo:\n",
306
+ " gr.Markdown(\"# PDB Viewer using 3Dmol.js\")\n",
307
+ " with gr.Row():\n",
308
+ " with gr.Column():\n",
309
+ " inp = gr.Textbox(\n",
310
+ " placeholder=\"PDB Code or upload file below\", label=\"Input structure\"\n",
311
+ " )\n",
312
+ " file = gr.File(file_count=\"single\")\n",
313
+ " btn = gr.Button(\"View structure\")\n",
314
+ " mol = gr.HTML()\n",
315
+ " btn.click(fn=update, inputs=[inp, file], outputs=mol)\n",
316
+ "\n",
317
+ "# Launch the Gradio interface \n",
318
+ "demo.launch(debug=True)"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": null,
324
+ "id": "ee215c16-a1fb-450f-bb93-37aaee6fb3f1",
325
+ "metadata": {},
326
+ "outputs": [],
327
+ "source": []
328
+ },
329
+ {
330
+ "cell_type": "raw",
331
+ "id": "050aa2e8-2dbe-4a28-8692-58ca7c50fccd",
332
+ "metadata": {},
333
+ "source": [
334
+ "import gradio as gr\n",
335
+ "import os\n",
336
+ "import requests\n",
337
+ "import numpy as np\n",
338
+ "from Bio.PDB import PDBParser\n",
339
+ "\n",
340
+ "def read_mol(pdb_path):\n",
341
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
342
+ " with open(pdb_path, 'r') as f:\n",
343
+ " return f.read()\n",
344
+ "\n",
345
+ "# Function to fetch a PDB file from RCSB PDB\n",
346
+ "def fetch_pdb(pdb_id):\n",
347
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
348
+ " pdb_path = f'{pdb_id}.pdb'\n",
349
+ " response = requests.get(pdb_url)\n",
350
+ " if response.status_code == 200:\n",
351
+ " with open(pdb_path, 'wb') as f:\n",
352
+ " f.write(response.content)\n",
353
+ " return molecule(pdb_path)\n",
354
+ " else:\n",
355
+ " return None\n",
356
+ "\n",
357
+ "# Function to process the PDB file and return random predictions\n",
358
+ "def process_pdb(pdb_id, segment):\n",
359
+ " pdb_path = fetch_pdb(pdb_id)\n",
360
+ " if not pdb_path:\n",
361
+ " return \"Failed to fetch PDB file\", None, None\n",
362
+ " \n",
363
+ " parser = PDBParser(QUIET=True)\n",
364
+ " structure = parser.get_structure('protein', pdb_path)\n",
365
+ " \n",
366
+ " try:\n",
367
+ " chain = structure[0][segment]\n",
368
+ " except KeyError:\n",
369
+ " return \"Invalid Chain ID\", None, None\n",
370
+ " \n",
371
+ " sequence = [residue.get_resname() for residue in chain if residue.id[0] == ' ']\n",
372
+ " random_scores = np.random.rand(len(sequence))\n",
373
+ " result_str = \"\\n\".join(\n",
374
+ " f\"{seq} {res.id[1]} {score:.2f}\" \n",
375
+ " for seq, res, score in zip(sequence, chain, random_scores)\n",
376
+ " )\n",
377
+ " \n",
378
+ " # Save the predictions to a file\n",
379
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
380
+ " with open(prediction_file, \"w\") as f:\n",
381
+ " f.write(result_str)\n",
382
+ " \n",
383
+ " return result_str, molecule(pdb_path), prediction_file\n",
384
+ "\n",
385
+ "def molecule(input_pdb):\n",
386
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
387
+ " \n",
388
+ " html_content = f\"\"\"\n",
389
+ " <!DOCTYPE html>\n",
390
+ " <html>\n",
391
+ " <head> \n",
392
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
393
+ " <style>\n",
394
+ " .mol-container {{\n",
395
+ " width: 100%;\n",
396
+ " height: 700px;\n",
397
+ " position: relative;\n",
398
+ " }}\n",
399
+ " </style>\n",
400
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
401
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
402
+ " </head>\n",
403
+ " <body>\n",
404
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
405
+ " <script>\n",
406
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
407
+ " $(document).ready(function () {{\n",
408
+ " let element = $(\"#container\");\n",
409
+ " let config = {{ backgroundColor: \"white\" }};\n",
410
+ " let viewer = $3Dmol.createViewer(element, config);\n",
411
+ " viewer.addModel(pdb, \"pdb\");\n",
412
+ " \n",
413
+ " // Set cartoon representation with white carbon color scheme\n",
414
+ " viewer.getModel(0).setStyle({{}}, {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }});\n",
415
+ " \n",
416
+ " // Highlight specific histidine residues in red stick representation\n",
417
+ " viewer.getModel(0).setStyle(\n",
418
+ " {{\"resn\": \"HIS\"}}, \n",
419
+ " {{\"stick\": {{\"color\": \"red\"}}}}\n",
420
+ " );\n",
421
+ " \n",
422
+ " viewer.zoomTo();\n",
423
+ " viewer.render();\n",
424
+ " viewer.zoom(0.8, 2000);\n",
425
+ " }});\n",
426
+ " </script>\n",
427
+ " </body>\n",
428
+ " </html>\n",
429
+ " \"\"\"\n",
430
+ " \n",
431
+ " # Return the HTML content within an iframe; only double and single quotes are entity-encoded for the srcdoc attribute\n",
432
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
433
+ "\n",
434
+ "# Gradio UI\n",
435
+ "with gr.Blocks() as demo:\n",
436
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
437
+ " with gr.Row():\n",
438
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
439
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
440
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
441
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
442
+ " \n",
443
+ " # Use HTML output instead of Molecule3D\n",
444
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
445
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
446
+ " download_output = gr.File(label=\"Download Predictions\")\n",
447
+ " \n",
448
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)\n",
449
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
450
+ " \n",
451
+ " gr.Markdown(\"## Examples\")\n",
452
+ " gr.Examples(\n",
453
+ " examples=[\n",
454
+ " [\"2IWI\", \"A\"],\n",
455
+ " [\"7RPZ\", \"B\"],\n",
456
+ " [\"3TJN\", \"C\"]\n",
457
+ " ],\n",
458
+ " inputs=[pdb_input, segment_input],\n",
459
+ " outputs=[predictions_output, molecule_output, download_output]\n",
460
+ " )\n",
461
+ "\n",
462
+ "demo.launch(debug=True)"
463
+ ]
464
+ },
465
+ {
466
+ "cell_type": "code",
467
+ "execution_count": null,
468
+ "id": "9a5facd9-855c-4b35-8dd3-2c0c8c7dd356",
469
+ "metadata": {},
470
+ "outputs": [],
471
+ "source": []
472
+ },
473
+ {
474
+ "cell_type": "raw",
475
+ "id": "a762170f-92a9-473d-b18d-53607a780e3b",
476
+ "metadata": {},
477
+ "source": [
478
+ "import gradio as gr\n",
479
+ "import requests\n",
480
+ "from Bio.PDB import PDBParser\n",
481
+ "import numpy as np\n",
482
+ "import os\n",
483
+ "\n",
484
+ "def read_mol(pdb_path):\n",
485
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
486
+ " with open(pdb_path, 'r') as f:\n",
487
+ " return f.read()\n",
488
+ "\n",
489
+ "# Function to fetch a PDB file from RCSB PDB\n",
490
+ "def fetch_pdb(pdb_id):\n",
491
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
492
+ " pdb_path = f'{pdb_id}.pdb'\n",
493
+ " response = requests.get(pdb_url)\n",
494
+ " if response.status_code == 200:\n",
495
+ " with open(pdb_path, 'wb') as f:\n",
496
+ " f.write(response.content)\n",
497
+ " return pdb_path\n",
498
+ " else:\n",
499
+ " return None\n",
500
+ "\n",
501
+ "# Function to process the PDB file and return random predictions\n",
502
+ "def process_pdb(pdb_id, segment):\n",
503
+ " pdb_path = fetch_pdb(pdb_id)\n",
504
+ " if not pdb_path:\n",
505
+ " return \"Failed to fetch PDB file\", None, None\n",
506
+ " parser = PDBParser(QUIET=True)\n",
507
+ " structure = parser.get_structure('protein', pdb_path)\n",
508
+ " \n",
509
+ " try:\n",
510
+ " chain = structure[0][segment]\n",
511
+ " except KeyError:\n",
512
+ " return \"Invalid Chain ID\", None, None\n",
513
+ " sequence = [residue.get_resname() for residue in chain if residue.id[0] == ' ']\n",
514
+ " random_scores = np.random.rand(len(sequence))\n",
515
+ " result_str = \"\\n\".join(\n",
516
+ " f\"{seq} {res.id[1]} {score:.2f}\" \n",
517
+ " for seq, res, score in zip(sequence, chain, random_scores)\n",
518
+ " )\n",
519
+ " # Save the predictions to a file\n",
520
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
521
+ " with open(prediction_file, \"w\") as f:\n",
522
+ " f.write(result_str)\n",
523
+ " \n",
524
+ " return result_str, molecule(pdb_path), prediction_file\n",
525
+ "\n",
526
+ "def molecule(input_pdb):\n",
527
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
528
+ " \n",
529
+ " html_content = f\"\"\"\n",
530
+ " <!DOCTYPE html>\n",
531
+ " <html>\n",
532
+ " <head> \n",
533
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
534
+ " <style>\n",
535
+ " .mol-container {{\n",
536
+ " width: 100%;\n",
537
+ " height: 700px;\n",
538
+ " position: relative;\n",
539
+ " }}\n",
540
+ " </style>\n",
541
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
542
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
543
+ " </head>\n",
544
+ " <body>\n",
545
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
546
+ " <script>\n",
547
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
548
+ " $(document).ready(function () {{\n",
549
+ " let element = $(\"#container\");\n",
550
+ " let config = {{ backgroundColor: \"white\" }};\n",
551
+ " let viewer = $3Dmol.createViewer(element, config);\n",
552
+ " viewer.addModel(pdb, \"pdb\");\n",
553
+ " \n",
554
+ " // Set cartoon representation with white carbon color scheme\n",
555
+ " viewer.getModel(0).setStyle({{}}, {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }});\n",
556
+ " \n",
557
+ " // Highlight specific histidine residues in red stick representation\n",
558
+ " viewer.getModel(0).setStyle(\n",
559
+ " {{\"resn\": \"HIS\"}}, \n",
560
+ " {{\"stick\": {{\"color\": \"red\"}}}}\n",
561
+ " );\n",
562
+ " \n",
563
+ " viewer.zoomTo();\n",
564
+ " viewer.render();\n",
565
+ " viewer.zoom(0.8, 2000);\n",
566
+ " }});\n",
567
+ " </script>\n",
568
+ " </body>\n",
569
+ " </html>\n",
570
+ " \"\"\"\n",
571
+ " \n",
572
+ " # Return the HTML content within an iframe; only double and single quotes are entity-encoded for the srcdoc attribute\n",
573
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
574
+ "\n",
575
+ "# Gradio UI\n",
576
+ "with gr.Blocks() as demo:\n",
577
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
578
+ " with gr.Row():\n",
579
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
580
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
581
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
582
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
583
+ " \n",
584
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
585
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
586
+ " download_output = gr.File(label=\"Download Predictions\")\n",
587
+ " \n",
588
+ " # Update to explicitly use molecule() function for visualization\n",
589
+ " visualize_btn.click(\n",
590
+ " fn=lambda pdb_id: molecule(fetch_pdb(pdb_id)), \n",
591
+ " inputs=[pdb_input], \n",
592
+ " outputs=molecule_output\n",
593
+ " )\n",
594
+ " \n",
595
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
596
+ " \n",
597
+ " gr.Markdown(\"## Examples\")\n",
598
+ " gr.Examples(\n",
599
+ " examples=[\n",
600
+ " [\"2IWI\", \"A\"],\n",
601
+ " [\"7RPZ\", \"B\"],\n",
602
+ " [\"3TJN\", \"C\"]\n",
603
+ " ],\n",
604
+ " inputs=[pdb_input, segment_input],\n",
605
+ " outputs=[predictions_output, molecule_output, download_output]\n",
606
+ " )\n",
607
+ "\n",
608
+ "demo.launch()"
609
+ ]
610
+ },
611
+ {
612
+ "cell_type": "code",
613
+ "execution_count": null,
614
+ "id": "15527a58-c449-4da0-8fab-3baaede15e41",
615
+ "metadata": {},
616
+ "outputs": [],
617
+ "source": []
618
+ },
619
+ {
620
+ "cell_type": "code",
621
+ "execution_count": 2,
622
+ "id": "9ef3e330-cb88-4c29-b84a-2f8652883cfc",
623
+ "metadata": {},
624
+ "outputs": [
625
+ {
626
+ "name": "stdout",
627
+ "output_type": "stream",
628
+ "text": [
629
+ "* Running on local URL: http://127.0.0.1:7860\n",
630
+ "\n",
631
+ "To create a public link, set `share=True` in `launch()`.\n"
632
+ ]
633
+ },
634
+ {
635
+ "data": {
636
+ "text/html": [
637
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
638
+ ],
639
+ "text/plain": [
640
+ "<IPython.core.display.HTML object>"
641
+ ]
642
+ },
643
+ "metadata": {},
644
+ "output_type": "display_data"
645
+ },
646
+ {
647
+ "data": {
648
+ "text/plain": []
649
+ },
650
+ "execution_count": 2,
651
+ "metadata": {},
652
+ "output_type": "execute_result"
653
+ }
654
+ ],
655
+ "source": [
656
+ "import gradio as gr\n",
657
+ "import requests\n",
658
+ "from Bio.PDB import PDBParser\n",
659
+ "import numpy as np\n",
660
+ "import os\n",
661
+ "from gradio_molecule3d import Molecule3D\n",
662
+ "\n",
663
+ "def read_mol(pdb_path):\n",
664
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
665
+ " with open(pdb_path, 'r') as f:\n",
666
+ " return f.read()\n",
667
+ "\n",
668
+ "def fetch_pdb(pdb_id):\n",
669
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
670
+ " pdb_path = f'{pdb_id}.pdb'\n",
671
+ " response = requests.get(pdb_url)\n",
672
+ " if response.status_code == 200:\n",
673
+ " with open(pdb_path, 'wb') as f:\n",
674
+ " f.write(response.content)\n",
675
+ " return pdb_path\n",
676
+ " else:\n",
677
+ " return None\n",
678
+ "\n",
679
+ "def process_pdb(pdb_id, segment):\n",
680
+ " pdb_path = fetch_pdb(pdb_id)\n",
681
+ " if not pdb_path:\n",
682
+ " return \"Failed to fetch PDB file\", None, None\n",
683
+ " parser = PDBParser(QUIET=True)\n",
684
+ " structure = parser.get_structure('protein', pdb_path)\n",
685
+ " \n",
686
+ " try:\n",
687
+ " chain = structure[0][segment]\n",
688
+ " except KeyError:\n",
689
+ " return \"Invalid Chain ID\", None, None\n",
690
+ " sequence = [residue.get_resname() for residue in chain if residue.id[0] == ' ']\n",
691
+ " random_scores = np.random.rand(len(sequence))\n",
692
+ " result_str = \"\\n\".join(\n",
693
+ " f\"{seq} {res.id[1]} {score:.2f}\" \n",
694
+ " for seq, res, score in zip(sequence, chain, random_scores)\n",
695
+ " )\n",
696
+ " # Save the predictions to a file\n",
697
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
698
+ " with open(prediction_file, \"w\") as f:\n",
699
+ " f.write(result_str)\n",
700
+ " \n",
701
+ " return result_str, molecule(pdb_path, random_scores), prediction_file\n",
702
+ "\n",
703
+ "def molecule(input_pdb, scores=None):\n",
704
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
705
+ " \n",
706
+ " # Prepare high-scoring residues script if scores are provided\n",
707
+ " high_score_script = \"\"\n",
708
+ " if scores is not None:\n",
709
+ " high_score_script = \"\"\"\n",
710
+ " // Highlight residues with high scores\n",
711
+ " let highScoreResidues = [{}];\n",
712
+ " viewer.getModel(0).setStyle(\n",
713
+ " {{\"resi\": highScoreResidues}}, \n",
714
+ " {{\"stick\": {{\"color\": \"red\"}}}}\n",
715
+ " );\n",
716
+ " \"\"\".format(\n",
717
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8)\n",
718
+ " )\n",
719
+ " \n",
720
+ " html_content = f\"\"\"\n",
721
+ " <!DOCTYPE html>\n",
722
+ " <html>\n",
723
+ " <head> \n",
724
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
725
+ " <style>\n",
726
+ " .mol-container {{\n",
727
+ " width: 100%;\n",
728
+ " height: 700px;\n",
729
+ " position: relative;\n",
730
+ " }}\n",
731
+ " </style>\n",
732
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
733
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
734
+ " </head>\n",
735
+ " <body>\n",
736
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
737
+ " <script>\n",
738
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
739
+ " $(document).ready(function () {{\n",
740
+ " let element = $(\"#container\");\n",
741
+ " let config = {{ backgroundColor: \"white\" }};\n",
742
+ " let viewer = $3Dmol.createViewer(element, config);\n",
743
+ " viewer.addModel(pdb, \"pdb\");\n",
744
+ " \n",
745
+ " // Set cartoon representation with white carbon color scheme\n",
746
+ " viewer.getModel(0).setStyle({{}}, {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }});\n",
747
+ " \n",
748
+ " {high_score_script}\n",
749
+ " \n",
750
+ " viewer.zoomTo();\n",
751
+ " viewer.render();\n",
752
+ " viewer.zoom(0.8, 2000);\n",
753
+ " }});\n",
754
+ " </script>\n",
755
+ " </body>\n",
756
+ " </html>\n",
757
+ " \"\"\"\n",
758
+ " \n",
759
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
760
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
761
+ "\n",
762
+ "reps = [\n",
763
+ " {\n",
764
+ " \"model\": 0,\n",
765
+ " \"style\": \"cartoon\",\n",
766
+ " \"color\": \"whiteCarbon\",\n",
767
+ " \"residue_range\": \"\",\n",
768
+ " \"around\": 0,\n",
769
+ " \"byres\": False,\n",
770
+ " }\n",
771
+ " ]\n",
772
+ "# Gradio UI\n",
773
+ "with gr.Blocks() as demo:\n",
774
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
775
+ " with gr.Row():\n",
776
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
777
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
778
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
779
+ " #prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
780
+ "\n",
781
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
782
+ "\n",
783
+ " with gr.Row():\n",
784
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
785
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
786
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
787
+ "\n",
788
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
789
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
790
+ " download_output = gr.File(label=\"Download Predictions\")\n",
791
+ " \n",
792
+ " #visualize_btn.click(\n",
793
+ " # fn=lambda pdb_id: molecule(fetch_pdb(pdb_id)), \n",
794
+ " # inputs=[pdb_input], \n",
795
+ " # outputs=molecule_output\n",
796
+ " #)\n",
797
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
798
+ " \n",
799
+ " \n",
800
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
801
+ " \n",
802
+ " gr.Markdown(\"## Examples\")\n",
803
+ " gr.Examples(\n",
804
+ " examples=[\n",
805
+ " [\"2IWI\", \"A\"],\n",
806
+ " [\"7RPZ\", \"B\"],\n",
807
+ " [\"3TJN\", \"C\"]\n",
808
+ " ],\n",
809
+ " inputs=[pdb_input, segment_input],\n",
810
+ " outputs=[predictions_output, molecule_output, download_output]\n",
811
+ " )\n",
812
+ "\n",
813
+ "demo.launch()"
814
+ ]
815
+ },
816
+ {
817
+ "cell_type": "code",
818
+ "execution_count": null,
819
+ "id": "14605615-8610-4d9e-841b-db7618cde844",
820
+ "metadata": {},
821
+ "outputs": [],
822
+ "source": []
823
+ }
824
+ ],
825
+ "metadata": {
826
+ "kernelspec": {
827
+ "display_name": "Python (LLM)",
828
+ "language": "python",
829
+ "name": "llm"
830
+ },
831
+ "language_info": {
832
+ "codemirror_mode": {
833
+ "name": "ipython",
834
+ "version": 3
835
+ },
836
+ "file_extension": ".py",
837
+ "mimetype": "text/x-python",
838
+ "name": "python",
839
+ "nbconvert_exporter": "python",
840
+ "pygments_lexer": "ipython3",
841
+ "version": "3.12.7"
842
+ }
843
+ },
844
+ "nbformat": 4,
845
+ "nbformat_minor": 5
846
+ }
test2.ipynb ADDED
@@ -0,0 +1,1193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "f3b7f6b0-6685-4a5c-9529-45e0ca905a3b",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "* Running on local URL: http://127.0.0.1:7860\n",
14
+ "\n",
15
+ "To create a public link, set `share=True` in `launch()`.\n"
16
+ ]
17
+ },
18
+ {
19
+ "data": {
20
+ "text/html": [
21
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
22
+ ],
23
+ "text/plain": [
24
+ "<IPython.core.display.HTML object>"
25
+ ]
26
+ },
27
+ "metadata": {},
28
+ "output_type": "display_data"
29
+ },
30
+ {
31
+ "data": {
32
+ "text/plain": []
33
+ },
34
+ "execution_count": 2,
35
+ "metadata": {},
36
+ "output_type": "execute_result"
37
+ }
38
+ ],
39
+ "source": [
40
+ "import gradio as gr\n",
41
+ "import requests\n",
42
+ "from Bio.PDB import PDBParser\n",
43
+ "import numpy as np\n",
44
+ "import os\n",
45
+ "from gradio_molecule3d import Molecule3D\n",
46
+ "\n",
47
+ "def read_mol(pdb_path):\n",
48
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
49
+ " with open(pdb_path, 'r') as f:\n",
50
+ " return f.read()\n",
51
+ "\n",
52
+ "def fetch_pdb(pdb_id):\n",
53
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
54
+ " pdb_path = f'{pdb_id}.pdb'\n",
55
+ " response = requests.get(pdb_url)\n",
56
+ " if response.status_code == 200:\n",
57
+ " with open(pdb_path, 'wb') as f:\n",
58
+ " f.write(response.content)\n",
59
+ " return pdb_path\n",
60
+ " else:\n",
61
+ " return None\n",
62
+ "\n",
63
+ "def process_pdb(pdb_id, segment):\n",
64
+ " pdb_path = fetch_pdb(pdb_id)\n",
65
+ " if not pdb_path:\n",
66
+ " return \"Failed to fetch PDB file\", None, None\n",
67
+ " \n",
68
+ " parser = PDBParser(QUIET=1)\n",
69
+ " structure = parser.get_structure('protein', pdb_path)\n",
70
+ " \n",
71
+ " try:\n",
72
+ " chain = structure[0][segment]\n",
73
+ " except KeyError:\n",
74
+ " return \"Invalid Chain ID\", None, None\n",
75
+ " \n",
76
+ " # Comprehensive amino acid mapping\n",
77
+ " aa_dict = {\n",
78
+ " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
79
+ " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
80
+ " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
81
+ " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
82
+ " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
83
+ " }\n",
84
+ " \n",
85
+ " # Exclude non-amino acid residues\n",
86
+ " sequence = [\n",
87
+ " residue for residue in chain \n",
88
+ " if residue.get_resname().strip() in aa_dict\n",
89
+ " ]\n",
90
+ " \n",
91
+ " random_scores = np.random.rand(len(sequence))\n",
92
+ " result_str = \"\\n\".join(\n",
93
+ " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n",
94
+ " for res, score in zip(sequence, random_scores)\n",
95
+ " )\n",
96
+ " \n",
97
+ " # Save the predictions to a file\n",
98
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
99
+ " with open(prediction_file, \"w\") as f:\n",
100
+ " f.write(result_str)\n",
101
+ " \n",
102
+ " return result_str, molecule(pdb_path, random_scores, segment), prediction_file\n",
103
+ "\n",
104
+ "def molecule(input_pdb, scores=None, segment='A'):\n",
105
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
106
+ " \n",
107
+ " # Prepare high-scoring residues script if scores are provided\n",
108
+ " high_score_script = \"\"\n",
109
+ " if scores is not None:\n",
110
+ " high_score_script = \"\"\"\n",
111
+ " // Reset all styles first\n",
112
+ " viewer.getModel(0).setStyle({}, {});\n",
113
+ " \n",
114
+ " // Show only the selected chain\n",
115
+ " viewer.getModel(0).setStyle(\n",
116
+ " {\"chain\": \"%s\"}, \n",
117
+ " { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
118
+ " );\n",
119
+ " \n",
120
+ " // Highlight high-scoring residues only for the selected chain\n",
121
+ " let highScoreResidues = [%s];\n",
122
+ " viewer.getModel(0).setStyle(\n",
123
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n",
124
+ " {\"stick\": {\"color\": \"red\"}}\n",
125
+ " );\n",
126
+ " \"\"\" % (segment, \n",
127
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n",
128
+ " segment)\n",
129
+ " \n",
130
+ " html_content = f\"\"\"\n",
131
+ " <!DOCTYPE html>\n",
132
+ " <html>\n",
133
+ " <head> \n",
134
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
135
+ " <style>\n",
136
+ " .mol-container {{\n",
137
+ " width: 100%;\n",
138
+ " height: 700px;\n",
139
+ " position: relative;\n",
140
+ " }}\n",
141
+ " </style>\n",
142
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
143
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
144
+ " </head>\n",
145
+ " <body>\n",
146
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
147
+ " <script>\n",
148
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
149
+ " $(document).ready(function () {{\n",
150
+ " let element = $(\"#container\");\n",
151
+ " let config = {{ backgroundColor: \"white\" }};\n",
152
+ " let viewer = $3Dmol.createViewer(element, config);\n",
153
+ " viewer.addModel(pdb, \"pdb\");\n",
154
+ " \n",
155
+ " // Reset all styles and show only selected chain\n",
156
+ " viewer.getModel(0).setStyle(\n",
157
+ " {{\"chain\": \"{segment}\"}}, \n",
158
+ " {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }}\n",
159
+ " );\n",
160
+ " \n",
161
+ " {high_score_script}\n",
162
+ " \n",
163
+ " viewer.zoomTo();\n",
164
+ " viewer.render();\n",
165
+ " viewer.zoom(0.8, 2000);\n",
166
+ " }});\n",
167
+ " </script>\n",
168
+ " </body>\n",
169
+ " </html>\n",
170
+ " \"\"\"\n",
171
+ " \n",
172
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
173
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
174
+ "\n",
175
+ "reps = [\n",
176
+ " {\n",
177
+ " \"model\": 0,\n",
178
+ " \"style\": \"cartoon\",\n",
179
+ " \"color\": \"whiteCarbon\",\n",
180
+ " \"residue_range\": \"\",\n",
181
+ " \"around\": 0,\n",
182
+ " \"byres\": False,\n",
183
+ " }\n",
184
+ " ]\n",
185
+ "# Gradio UI\n",
186
+ "with gr.Blocks() as demo:\n",
187
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
188
+ " with gr.Row():\n",
189
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
190
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
191
+ "\n",
192
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
193
+ "\n",
194
+ " with gr.Row():\n",
195
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
196
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
197
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
198
+ "\n",
199
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
200
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
201
+ " download_output = gr.File(label=\"Download Predictions\")\n",
202
+ " \n",
203
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
204
+ " \n",
205
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
206
+ " \n",
207
+ " gr.Markdown(\"## Examples\")\n",
208
+ " gr.Examples(\n",
209
+ " examples=[\n",
210
+ " [\"2IWI\", \"A\"],\n",
211
+ " [\"7RPZ\", \"B\"],\n",
212
+ " [\"3TJN\", \"C\"]\n",
213
+ " ],\n",
214
+ " inputs=[pdb_input, segment_input],\n",
215
+ " outputs=[predictions_output, molecule_output, download_output]\n",
216
+ " )\n",
217
+ "\n",
218
+ "demo.launch()"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 6,
224
+ "id": "28f8f28c-48d3-4e35-9766-3de9882179b5",
225
+ "metadata": {},
226
+ "outputs": [
227
+ {
228
+ "name": "stdout",
229
+ "output_type": "stream",
230
+ "text": [
231
+ "* Running on local URL: http://127.0.0.1:7864\n",
232
+ "\n",
233
+ "To create a public link, set `share=True` in `launch()`.\n"
234
+ ]
235
+ },
236
+ {
237
+ "data": {
238
+ "text/html": [
239
+ "<div><iframe src=\"http://127.0.0.1:7864/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
240
+ ],
241
+ "text/plain": [
242
+ "<IPython.core.display.HTML object>"
243
+ ]
244
+ },
245
+ "metadata": {},
246
+ "output_type": "display_data"
247
+ },
248
+ {
249
+ "data": {
250
+ "text/plain": []
251
+ },
252
+ "execution_count": 6,
253
+ "metadata": {},
254
+ "output_type": "execute_result"
255
+ }
256
+ ],
257
+ "source": [
258
+ "import gradio as gr\n",
259
+ "import requests\n",
260
+ "from Bio.PDB import PDBParser\n",
261
+ "import numpy as np\n",
262
+ "import os\n",
263
+ "from gradio_molecule3d import Molecule3D\n",
264
+ "\n",
265
+ "def read_mol(pdb_path):\n",
266
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
267
+ " with open(pdb_path, 'r') as f:\n",
268
+ " return f.read()\n",
269
+ "\n",
270
+ "def fetch_pdb(pdb_id):\n",
271
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
272
+ " pdb_path = f'{pdb_id}.pdb'\n",
273
+ " response = requests.get(pdb_url)\n",
274
+ " if response.status_code == 200:\n",
275
+ " with open(pdb_path, 'wb') as f:\n",
276
+ " f.write(response.content)\n",
277
+ " return pdb_path\n",
278
+ " else:\n",
279
+ " return None\n",
280
+ "\n",
281
+ "def process_pdb(pdb_id, segment):\n",
282
+ " pdb_path = fetch_pdb(pdb_id)\n",
283
+ " if not pdb_path:\n",
284
+ " return \"Failed to fetch PDB file\", None, None\n",
285
+ " \n",
286
+ " parser = PDBParser(QUIET=1)\n",
287
+ " structure = parser.get_structure('protein', pdb_path)\n",
288
+ " \n",
289
+ " try:\n",
290
+ " chain = structure[0][segment]\n",
291
+ " except KeyError:\n",
292
+ " return \"Invalid Chain ID\", None, None\n",
293
+ " \n",
294
+ " # Comprehensive amino acid mapping\n",
295
+ " aa_dict = {\n",
296
+ " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
297
+ " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
298
+ " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
299
+ " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
300
+ " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
301
+ " }\n",
302
+ " \n",
303
+ " # Exclude non-amino acid residues\n",
304
+ " sequence = [\n",
305
+ " residue for residue in chain \n",
306
+ " if residue.get_resname().strip() in aa_dict\n",
307
+ " ]\n",
308
+ " \n",
309
+ " random_scores = np.random.rand(len(sequence))\n",
310
+ " result_str = \"\\n\".join(\n",
311
+ " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n",
312
+ " for res, score in zip(sequence, random_scores)\n",
313
+ " )\n",
314
+ " \n",
315
+ " # Save the predictions to a file\n",
316
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
317
+ " with open(prediction_file, \"w\") as f:\n",
318
+ " f.write(result_str)\n",
319
+ " \n",
320
+ " return result_str, molecule(pdb_path, random_scores, segment), prediction_file\n",
321
+ "\n",
322
+ "def molecule(input_pdb, scores=None, segment='A'):\n",
323
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
324
+ " \n",
325
+ " # Prepare high-scoring residues script if scores are provided\n",
326
+ " high_score_script = \"\"\n",
327
+ " if scores is not None:\n",
328
+ " high_score_script = \"\"\"\n",
329
+ " // Reset all styles first\n",
330
+ " viewer.getModel(0).setStyle({}, {});\n",
331
+ " \n",
332
+ " // Show only the selected chain\n",
333
+ " viewer.getModel(0).setStyle(\n",
334
+ " {\"chain\": \"%s\"}, \n",
335
+ " { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
336
+ " );\n",
337
+ " \n",
338
+ " // Highlight high-scoring residues only for the selected chain\n",
339
+ " let highScoreResidues = [%s];\n",
340
+ " viewer.getModel(0).setStyle(\n",
341
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n",
342
+ " {\"stick\": {\"color\": \"red\"}}\n",
343
+ " );\n",
344
+ " \"\"\" % (segment, \n",
345
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n",
346
+ " segment)\n",
347
+ " \n",
348
+ " html_content = f\"\"\"\n",
349
+ " <!DOCTYPE html>\n",
350
+ " <html>\n",
351
+ " <head> \n",
352
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
353
+ " <style>\n",
354
+ " .mol-container {{\n",
355
+ " width: 100%;\n",
356
+ " height: 700px;\n",
357
+ " position: relative;\n",
358
+ " }}\n",
359
+ " </style>\n",
360
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
361
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
362
+ " </head>\n",
363
+ " <body>\n",
364
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
365
+ " <script>\n",
366
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
367
+ " $(document).ready(function () {{\n",
368
+ " let element = $(\"#container\");\n",
369
+ " let config = {{ backgroundColor: \"white\" }};\n",
370
+ " let viewer = $3Dmol.createViewer(element, config);\n",
371
+ " viewer.addModel(pdb, \"pdb\");\n",
372
+ " \n",
373
+ " // Reset all styles and show only selected chain\n",
374
+ " viewer.getModel(0).setStyle(\n",
375
+ " {{\"chain\": \"{segment}\"}}, \n",
376
+ " {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }}\n",
377
+ " );\n",
378
+ " \n",
379
+ " {high_score_script}\n",
380
+ " \n",
381
+ " // Add hover functionality\n",
382
+ " viewer.setHoverable(\n",
383
+ " {{}}, \n",
384
+ " true, \n",
385
+ " function(atom, viewer, event, container) {{\n",
386
+ " if (!atom.label) {{\n",
387
+ " atom.label = viewer.addLabel(\n",
388
+ " atom.resn + \":\" + atom.atom, \n",
389
+ " {{\n",
390
+ " position: atom, \n",
391
+ " backgroundColor: 'mintcream', \n",
392
+ " fontColor: 'black',\n",
393
+ " fontSize: 12,\n",
394
+ " padding: 2\n",
395
+ " }}\n",
396
+ " );\n",
397
+ " }}\n",
398
+ " }},\n",
399
+ " function(atom, viewer) {{\n",
400
+ " if (atom.label) {{\n",
401
+ " viewer.removeLabel(atom.label);\n",
402
+ " delete atom.label;\n",
403
+ " }}\n",
404
+ " }}\n",
405
+ " );\n",
406
+ " \n",
407
+ " viewer.zoomTo();\n",
408
+ " viewer.render();\n",
409
+ " viewer.zoom(0.8, 2000);\n",
410
+ " }});\n",
411
+ " </script>\n",
412
+ " </body>\n",
413
+ " </html>\n",
414
+ " \"\"\"\n",
415
+ " \n",
416
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
417
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
418
+ "\n",
419
+ "reps = [\n",
420
+ " {\n",
421
+ " \"model\": 0,\n",
422
+ " \"style\": \"cartoon\",\n",
423
+ " \"color\": \"whiteCarbon\",\n",
424
+ " \"residue_range\": \"\",\n",
425
+ " \"around\": 0,\n",
426
+ " \"byres\": False,\n",
427
+ " }\n",
428
+ " ]\n",
429
+ "\n",
430
+ "# Gradio UI\n",
431
+ "with gr.Blocks() as demo:\n",
432
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
433
+ " with gr.Row():\n",
434
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
435
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
436
+ "\n",
437
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
438
+ "\n",
439
+ " with gr.Row():\n",
440
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
441
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
442
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
443
+ "\n",
444
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
445
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
446
+ " download_output = gr.File(label=\"Download Predictions\")\n",
447
+ " \n",
448
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
449
+ " \n",
450
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
451
+ " \n",
452
+ " gr.Markdown(\"## Examples\")\n",
453
+ " gr.Examples(\n",
454
+ " examples=[\n",
455
+ " [\"2IWI\", \"A\"],\n",
456
+ " [\"7RPZ\", \"B\"],\n",
457
+ " [\"3TJN\", \"C\"]\n",
458
+ " ],\n",
459
+ " inputs=[pdb_input, segment_input],\n",
460
+ " outputs=[predictions_output, molecule_output, download_output]\n",
461
+ " )\n",
462
+ "\n",
463
+ "demo.launch()"
464
+ ]
465
+ },
466
+ {
467
+ "cell_type": "code",
468
+ "execution_count": null,
469
+ "id": "517a2fe7-419f-4d0b-a9ed-62a22c1c1284",
470
+ "metadata": {},
471
+ "outputs": [],
472
+ "source": []
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "execution_count": 11,
477
+ "id": "d62be1b5-762e-4b69-aed4-e4ba2a44482f",
478
+ "metadata": {},
479
+ "outputs": [
480
+ {
481
+ "name": "stdout",
482
+ "output_type": "stream",
483
+ "text": [
484
+ "* Running on local URL: http://127.0.0.1:7867\n",
485
+ "\n",
486
+ "To create a public link, set `share=True` in `launch()`.\n"
487
+ ]
488
+ },
489
+ {
490
+ "data": {
491
+ "text/html": [
492
+ "<div><iframe src=\"http://127.0.0.1:7867/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
493
+ ],
494
+ "text/plain": [
495
+ "<IPython.core.display.HTML object>"
496
+ ]
497
+ },
498
+ "metadata": {},
499
+ "output_type": "display_data"
500
+ },
501
+ {
502
+ "data": {
503
+ "text/plain": []
504
+ },
505
+ "execution_count": 11,
506
+ "metadata": {},
507
+ "output_type": "execute_result"
508
+ }
509
+ ],
510
+ "source": [
511
+ "import gradio as gr\n",
512
+ "import requests\n",
513
+ "from Bio.PDB import PDBParser\n",
514
+ "import numpy as np\n",
515
+ "import os\n",
516
+ "from gradio_molecule3d import Molecule3D\n",
517
+ "\n",
518
+ "def read_mol(pdb_path):\n",
519
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
520
+ " with open(pdb_path, 'r') as f:\n",
521
+ " return f.read()\n",
522
+ "\n",
523
+ "def fetch_pdb(pdb_id):\n",
524
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
525
+ " pdb_path = f'{pdb_id}.pdb'\n",
526
+ " response = requests.get(pdb_url)\n",
527
+ " if response.status_code == 200:\n",
528
+ " with open(pdb_path, 'wb') as f:\n",
529
+ " f.write(response.content)\n",
530
+ " return pdb_path\n",
531
+ " else:\n",
532
+ " return None\n",
533
+ "\n",
534
+ "def process_pdb(pdb_id, segment):\n",
535
+ " pdb_path = fetch_pdb(pdb_id)\n",
536
+ " if not pdb_path:\n",
537
+ " return \"Failed to fetch PDB file\", None, None\n",
538
+ " \n",
539
+ " parser = PDBParser(QUIET=1)\n",
540
+ " structure = parser.get_structure('protein', pdb_path)\n",
541
+ " \n",
542
+ " try:\n",
543
+ " chain = structure[0][segment]\n",
544
+ " except KeyError:\n",
545
+ " return \"Invalid Chain ID\", None, None\n",
546
+ " \n",
547
+ " # Comprehensive amino acid mapping\n",
548
+ " aa_dict = {\n",
549
+ " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
550
+ " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
551
+ " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
552
+ " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
553
+ " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
554
+ " }\n",
555
+ " \n",
556
+ " # Exclude non-amino acid residues\n",
557
+ " sequence = [\n",
558
+ " residue for residue in chain \n",
559
+ " if residue.get_resname().strip() in aa_dict\n",
560
+ " ]\n",
561
+ " \n",
562
+ " random_scores = np.random.rand(len(sequence))\n",
563
+ " result_str = \"\\n\".join(\n",
564
+ " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n",
565
+ " for res, score in zip(sequence, random_scores)\n",
566
+ " )\n",
567
+ " \n",
568
+ " # Save the predictions to a file\n",
569
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
570
+ " with open(prediction_file, \"w\") as f:\n",
571
+ " f.write(result_str)\n",
572
+ " \n",
573
+ " return result_str, molecule(pdb_path, random_scores, segment), prediction_file\n",
574
+ "\n",
575
+ "def molecule(input_pdb, scores=None, segment='A'):\n",
576
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
577
+ " \n",
578
+ " # Prepare high-scoring residues script if scores are provided\n",
579
+ " high_score_script = \"\"\n",
580
+ " if scores is not None:\n",
581
+ " high_score_script = \"\"\"\n",
582
+ " // Reset all styles first\n",
583
+ " viewer.getModel(0).setStyle({}, {});\n",
584
+ " \n",
585
+ " // Show only the selected chain\n",
586
+ " viewer.getModel(0).setStyle(\n",
587
+ " {\"chain\": \"%s\"}, \n",
588
+ " { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
589
+ " );\n",
590
+ " \n",
591
+ " // Highlight high-scoring residues only for the selected chain\n",
592
+ " let highScoreResidues = [%s];\n",
593
+ " viewer.getModel(0).setStyle(\n",
594
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n",
595
+ " {\"stick\": {\"color\": \"red\"}}\n",
596
+ " );\n",
597
+ "\n",
598
+ " // Highlight high-scoring residues only for the selected chain\n",
599
+ " let highScoreResidues2 = [%s];\n",
600
+ " viewer.getModel(0).setStyle(\n",
601
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues2}, \n",
602
+ " {\"stick\": {\"color\": \"orange\"}}\n",
603
+ " );\n",
604
+ " \"\"\" % (segment, \n",
605
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n",
606
+ " segment,\n",
607
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),\n",
608
+ " segment)\n",
609
+ " \n",
610
+ " html_content = f\"\"\"\n",
611
+ " <!DOCTYPE html>\n",
612
+ " <html>\n",
613
+ " <head> \n",
614
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
615
+ " <style>\n",
616
+ " .mol-container {{\n",
617
+ " width: 100%;\n",
618
+ " height: 700px;\n",
619
+ " position: relative;\n",
620
+ " }}\n",
621
+ " </style>\n",
622
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
623
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
624
+ " </head>\n",
625
+ " <body>\n",
626
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
627
+ " <script>\n",
628
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
629
+ " $(document).ready(function () {{\n",
630
+ " let element = $(\"#container\");\n",
631
+ " let config = {{ backgroundColor: \"white\" }};\n",
632
+ " let viewer = $3Dmol.createViewer(element, config);\n",
633
+ " viewer.addModel(pdb, \"pdb\");\n",
634
+ " \n",
635
+ " // Reset all styles and show only selected chain\n",
636
+ " viewer.getModel(0).setStyle(\n",
637
+ " {{\"chain\": \"{segment}\"}}, \n",
638
+ " {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }}\n",
639
+ " );\n",
640
+ " \n",
641
+ " {high_score_script}\n",
642
+ " \n",
643
+ " // Add hover functionality\n",
644
+ " viewer.setHoverable(\n",
645
+ " {{}}, \n",
646
+ " true, \n",
647
+ " function(atom, viewer, event, container) {{\n",
648
+ " if (!atom.label) {{\n",
649
+ " atom.label = viewer.addLabel(\n",
650
+ " atom.resn + \":\" + atom.atom, \n",
651
+ " {{\n",
652
+ " position: atom, \n",
653
+ " backgroundColor: 'mintcream', \n",
654
+ " fontColor: 'black',\n",
655
+ " fontSize: 12,\n",
656
+ " padding: 2\n",
657
+ " }}\n",
658
+ " );\n",
659
+ " }}\n",
660
+ " }},\n",
661
+ " function(atom, viewer) {{\n",
662
+ " if (atom.label) {{\n",
663
+ " viewer.removeLabel(atom.label);\n",
664
+ " delete atom.label;\n",
665
+ " }}\n",
666
+ " }}\n",
667
+ " );\n",
668
+ " \n",
669
+ " viewer.zoomTo();\n",
670
+ " viewer.render();\n",
671
+ " viewer.zoom(0.8, 2000);\n",
672
+ " }});\n",
673
+ " </script>\n",
674
+ " </body>\n",
675
+ " </html>\n",
676
+ " \"\"\"\n",
677
+ " \n",
678
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
679
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
680
+ "\n",
681
+ "reps = [\n",
682
+ " {\n",
683
+ " \"model\": 0,\n",
684
+ " \"style\": \"cartoon\",\n",
685
+ " \"color\": \"whiteCarbon\",\n",
686
+ " \"residue_range\": \"\",\n",
687
+ " \"around\": 0,\n",
688
+ " \"byres\": False,\n",
689
+ " }\n",
690
+ " ]\n",
691
+ "\n",
692
+ "# Gradio UI\n",
693
+ "with gr.Blocks() as demo:\n",
694
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
695
+ " with gr.Row():\n",
696
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
697
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
698
+ "\n",
699
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
700
+ "\n",
701
+ " with gr.Row():\n",
702
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
703
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
704
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
705
+ "\n",
706
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
707
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
708
+ " download_output = gr.File(label=\"Download Predictions\")\n",
709
+ " \n",
710
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
711
+ " \n",
712
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
713
+ " \n",
714
+ " gr.Markdown(\"## Examples\")\n",
715
+ " gr.Examples(\n",
716
+ " examples=[\n",
717
+ " [\"2IWI\", \"A\"],\n",
718
+ " [\"7RPZ\", \"B\"],\n",
719
+ " [\"3TJN\", \"C\"]\n",
720
+ " ],\n",
721
+ " inputs=[pdb_input, segment_input],\n",
722
+ " outputs=[predictions_output, molecule_output, download_output]\n",
723
+ " )\n",
724
+ "\n",
725
+ "demo.launch()"
726
+ ]
727
+ },
728
+ {
729
+ "cell_type": "code",
730
+ "execution_count": null,
731
+ "id": "30f35243-852f-4771-9a4b-5cdd198552b5",
732
+ "metadata": {},
733
+ "outputs": [],
734
+ "source": []
735
+ },
736
+ {
737
+ "cell_type": "code",
738
+ "execution_count": null,
739
+ "id": "5eca6754-4aa1-463f-881a-25d2a0d6bb5b",
740
+ "metadata": {},
741
+ "outputs": [],
742
+ "source": [
743
+ "import gradio as gr\n",
744
+ "import requests\n",
745
+ "from Bio.PDB import PDBParser\n",
746
+ "import numpy as np\n",
747
+ "import os\n",
748
+ "from gradio_molecule3d import Molecule3D\n",
749
+ "\n",
750
+ "\n",
751
+ "from model_loader import load_model\n",
752
+ "\n",
753
+ "import torch\n",
754
+ "import torch.nn as nn\n",
755
+ "import torch.nn.functional as F\n",
756
+ "from torch.utils.data import DataLoader\n",
757
+ "\n",
758
+ "import re\n",
759
+ "import pandas as pd\n",
760
+ "import copy\n",
761
+ "\n",
762
+ "import transformers, datasets\n",
763
+ "from transformers import AutoTokenizer\n",
764
+ "from transformers import DataCollatorForTokenClassification\n",
765
+ "\n",
766
+ "from datasets import Dataset\n",
767
+ "\n",
768
+ "from scipy.special import expit\n",
769
+ "\n",
770
+ "# Load model and move to device\n",
771
+ "checkpoint = 'ThorbenF/prot_t5_xl_uniref50'\n",
772
+ "max_length = 1500\n",
773
+ "model, tokenizer = load_model(checkpoint, max_length)\n",
774
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
775
+ "model.to(device)\n",
776
+ "model.eval()\n",
777
+ "\n",
778
+ "def normalize_scores(scores):\n",
779
+ " min_score = np.min(scores)\n",
780
+ " max_score = np.max(scores)\n",
781
+ " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
782
+ " \n",
783
+ "def read_mol(pdb_path):\n",
784
+ " \"\"\"Read PDB file and return its content as a string\"\"\"\n",
785
+ " with open(pdb_path, 'r') as f:\n",
786
+ " return f.read()\n",
787
+ "\n",
788
+ "def fetch_pdb(pdb_id):\n",
789
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
790
+ " pdb_path = f'{pdb_id}.pdb'\n",
791
+ " response = requests.get(pdb_url)\n",
792
+ " if response.status_code == 200:\n",
793
+ " with open(pdb_path, 'wb') as f:\n",
794
+ " f.write(response.content)\n",
795
+ " return pdb_path\n",
796
+ " else:\n",
797
+ " return None\n",
798
+ "\n",
799
+ "def process_pdb(pdb_id, segment):\n",
800
+ " pdb_path = fetch_pdb(pdb_id)\n",
801
+ " if not pdb_path:\n",
802
+ " return \"Failed to fetch PDB file\", None, None\n",
803
+ " \n",
804
+ " parser = PDBParser(QUIET=1)\n",
805
+ " structure = parser.get_structure('protein', pdb_path)\n",
806
+ " \n",
807
+ " try:\n",
808
+ " chain = structure[0][segment]\n",
809
+ " except KeyError:\n",
810
+ " return \"Invalid Chain ID\", None, None\n",
811
+ " \n",
812
+ " # Comprehensive amino acid mapping\n",
813
+ " aa_dict = {\n",
814
+ " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
815
+ " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
816
+ " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
817
+ " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
818
+ " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
819
+ " }\n",
820
+ " \n",
821
+ " # Exclude non-amino acid residues\n",
822
+ " sequence = [\n",
823
+ " residue for residue in chain \n",
824
+ " if residue.get_resname().strip() in aa_dict\n",
825
+ " ]\n",
826
+ " \n",
827
+ " # Prepare input for model prediction\n",
828
+ " input_ids = tokenizer(\" \".join(sequence), return_tensors=\"pt\").input_ids.to(device)\n",
829
+ " with torch.no_grad():\n",
830
+ " outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()\n",
831
+ "\n",
832
+ " # Calculate scores and normalize them\n",
833
+ " scores = expit(outputs[:, 1] - outputs[:, 0])\n",
834
+ " normalized_scores = normalize_scores(scores)\n",
835
+ "\n",
836
+ " result_str = \"\\n\".join(\n",
837
+ " f\"{aa_dict[res.get_resname()]} {res.id[1]} {score:.2f}\" \n",
838
+ " for res, score in zip(sequence, normalized_scores)\n",
839
+ " )\n",
840
+ " \n",
841
+ " # Save the predictions to a file\n",
842
+ " prediction_file = f\"{pdb_id}_predictions.txt\"\n",
843
+ " with open(prediction_file, \"w\") as f:\n",
844
+ " f.write(result_str)\n",
845
+ " \n",
846
+ " return result_str, molecule(pdb_path, random_scores, segment), prediction_file\n",
847
+ "\n",
848
+ "def molecule(input_pdb, scores=None, segment='A'):\n",
849
+ " mol = read_mol(input_pdb) # Read PDB file content\n",
850
+ " \n",
851
+ " # Prepare high-scoring residues script if scores are provided\n",
852
+ " high_score_script = \"\"\n",
853
+ " if scores is not None:\n",
854
+ " high_score_script = \"\"\"\n",
855
+ " // Reset all styles first\n",
856
+ " viewer.getModel(0).setStyle({}, {});\n",
857
+ " \n",
858
+ " // Show only the selected chain\n",
859
+ " viewer.getModel(0).setStyle(\n",
860
+ " {\"chain\": \"%s\"}, \n",
861
+ " { cartoon: {colorscheme:\"whiteCarbon\"} }\n",
862
+ " );\n",
863
+ " \n",
864
+ " // Highlight high-scoring residues only for the selected chain\n",
865
+ " let highScoreResidues = [%s];\n",
866
+ " viewer.getModel(0).setStyle(\n",
867
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues}, \n",
868
+ " {\"stick\": {\"color\": \"red\"}}\n",
869
+ " );\n",
870
+ "\n",
871
+ "        // Highlight medium-scoring residues (between 0.5 and 0.8) only for the selected chain\n",
872
+ " let highScoreResidues2 = [%s];\n",
873
+ " viewer.getModel(0).setStyle(\n",
874
+ " {\"chain\": \"%s\", \"resi\": highScoreResidues2}, \n",
875
+ " {\"stick\": {\"color\": \"orange\"}}\n",
876
+ " );\n",
877
+ " \"\"\" % (segment, \n",
878
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if score > 0.8),\n",
879
+ " segment,\n",
880
+ " \", \".join(str(i+1) for i, score in enumerate(scores) if (score > 0.5) and (score < 0.8)),\n",
881
+ " segment)\n",
882
+ " \n",
883
+ " html_content = f\"\"\"\n",
884
+ " <!DOCTYPE html>\n",
885
+ " <html>\n",
886
+ " <head> \n",
887
+ " <meta http-equiv=\"content-type\" content=\"text/html; charset=UTF-8\" />\n",
888
+ " <style>\n",
889
+ " .mol-container {{\n",
890
+ " width: 100%;\n",
891
+ " height: 700px;\n",
892
+ " position: relative;\n",
893
+ " }}\n",
894
+ " </style>\n",
895
+ " <script src=\"https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js\"></script>\n",
896
+ " <script src=\"https://3Dmol.csb.pitt.edu/build/3Dmol-min.js\"></script>\n",
897
+ " </head>\n",
898
+ " <body>\n",
899
+ " <div id=\"container\" class=\"mol-container\"></div>\n",
900
+ " <script>\n",
901
+ " let pdb = `{mol}`; // Use template literal to properly escape PDB content\n",
902
+ " $(document).ready(function () {{\n",
903
+ " let element = $(\"#container\");\n",
904
+ " let config = {{ backgroundColor: \"white\" }};\n",
905
+ " let viewer = $3Dmol.createViewer(element, config);\n",
906
+ " viewer.addModel(pdb, \"pdb\");\n",
907
+ " \n",
908
+ " // Reset all styles and show only selected chain\n",
909
+ " viewer.getModel(0).setStyle(\n",
910
+ " {{\"chain\": \"{segment}\"}}, \n",
911
+ " {{ cartoon: {{ colorscheme:\"whiteCarbon\" }} }}\n",
912
+ " );\n",
913
+ " \n",
914
+ " {high_score_script}\n",
915
+ " \n",
916
+ " // Add hover functionality\n",
917
+ " viewer.setHoverable(\n",
918
+ " {{}}, \n",
919
+ " true, \n",
920
+ " function(atom, viewer, event, container) {{\n",
921
+ " if (!atom.label) {{\n",
922
+ " atom.label = viewer.addLabel(\n",
923
+ " atom.resn + \":\" + atom.atom, \n",
924
+ " {{\n",
925
+ " position: atom, \n",
926
+ " backgroundColor: 'mintcream', \n",
927
+ " fontColor: 'black',\n",
928
+ " fontSize: 12,\n",
929
+ " padding: 2\n",
930
+ " }}\n",
931
+ " );\n",
932
+ " }}\n",
933
+ " }},\n",
934
+ " function(atom, viewer) {{\n",
935
+ " if (atom.label) {{\n",
936
+ " viewer.removeLabel(atom.label);\n",
937
+ " delete atom.label;\n",
938
+ " }}\n",
939
+ " }}\n",
940
+ " );\n",
941
+ " \n",
942
+ " viewer.zoomTo();\n",
943
+ " viewer.render();\n",
944
+ " viewer.zoom(0.8, 2000);\n",
945
+ " }});\n",
946
+ " </script>\n",
947
+ " </body>\n",
948
+ " </html>\n",
949
+ " \"\"\"\n",
950
+ " \n",
951
+ " # Return the HTML content within an iframe safely encoded for special characters\n",
952
+ " return f'<iframe width=\"100%\" height=\"700\" srcdoc=\"{html_content.replace(chr(34), \"&quot;\").replace(chr(39), \"&#39;\")}\"></iframe>'\n",
953
+ "\n",
954
+ "reps = [\n",
955
+ " {\n",
956
+ " \"model\": 0,\n",
957
+ " \"style\": \"cartoon\",\n",
958
+ " \"color\": \"whiteCarbon\",\n",
959
+ " \"residue_range\": \"\",\n",
960
+ " \"around\": 0,\n",
961
+ " \"byres\": False,\n",
962
+ " }\n",
963
+ " ]\n",
964
+ "\n",
965
+ "# Gradio UI\n",
966
+ "with gr.Blocks() as demo:\n",
967
+ " gr.Markdown(\"# Protein Binding Site Prediction (Random Scores)\")\n",
968
+ " with gr.Row():\n",
969
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
970
+ " visualize_btn = gr.Button(\"Visualize Structure\")\n",
971
+ "\n",
972
+ " molecule_output2 = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
973
+ "\n",
974
+ " with gr.Row():\n",
975
+ " pdb_input = gr.Textbox(value=\"2IWI\", label=\"PDB ID\", placeholder=\"Enter PDB ID here...\")\n",
976
+ " segment_input = gr.Textbox(value=\"A\", label=\"Chain ID\", placeholder=\"Enter Chain ID here...\")\n",
977
+ " prediction_btn = gr.Button(\"Predict Random Binding Site Scores\")\n",
978
+ "\n",
979
+ " molecule_output = gr.HTML(label=\"Protein Structure\")\n",
980
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
981
+ " download_output = gr.File(label=\"Download Predictions\")\n",
982
+ " \n",
983
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)\n",
984
+ " \n",
985
+ " prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])\n",
986
+ " \n",
987
+ " gr.Markdown(\"## Examples\")\n",
988
+ " gr.Examples(\n",
989
+ " examples=[\n",
990
+ " [\"2IWI\", \"A\"],\n",
991
+ " [\"7RPZ\", \"B\"],\n",
992
+ " [\"3TJN\", \"C\"]\n",
993
+ " ],\n",
994
+ " inputs=[pdb_input, segment_input],\n",
995
+ " outputs=[predictions_output, molecule_output, download_output]\n",
996
+ " )\n",
997
+ "\n",
998
+ "demo.launch()"
999
+ ]
1000
+ },
1001
+ {
1002
+ "cell_type": "code",
1003
+ "execution_count": null,
1004
+ "id": "95046d1c-ec7c-4e3e-8a98-1802cb09a25b",
1005
+ "metadata": {},
1006
+ "outputs": [],
1007
+ "source": []
1008
+ },
1009
+ {
1010
+ "cell_type": "code",
1011
+ "execution_count": null,
1012
+ "id": "a37cbe6f-d57f-41e5-8ae1-38258da39d47",
1013
+ "metadata": {},
1014
+ "outputs": [],
1015
+ "source": [
1016
+ "import gradio as gr\n",
1017
+ "from model_loader import load_model\n",
1018
+ "\n",
1019
+ "import torch\n",
1020
+ "import torch.nn as nn\n",
1021
+ "import torch.nn.functional as F\n",
1022
+ "from torch.utils.data import DataLoader\n",
1023
+ "\n",
1024
+ "import re\n",
1025
+ "import numpy as np\n",
1026
+ "import os\n",
1027
+ "import pandas as pd\n",
1028
+ "import copy\n",
1029
+ "\n",
1030
+ "import transformers, datasets\n",
1031
+ "from transformers import AutoTokenizer\n",
1032
+ "from transformers import DataCollatorForTokenClassification\n",
1033
+ "\n",
1034
+ "from datasets import Dataset\n",
1035
+ "\n",
1036
+ "from scipy.special import expit\n",
1037
+ "\n",
1038
+ "import requests\n",
1039
+ "\n",
1040
+ "from gradio_molecule3d import Molecule3D\n",
1041
+ "\n",
1042
+ "# Biopython imports\n",
1043
+ "from Bio.PDB import PDBParser, Select, PDBIO\n",
1044
+ "from Bio.PDB.DSSP import DSSP\n",
1045
+ "from Bio.PDB import PDBList\n",
1046
+ "\n",
1047
+ "from matplotlib import cm # For color mapping\n",
1048
+ "from matplotlib.colors import Normalize\n",
1049
+ "\n",
1050
+ "# Load model and move to device\n",
1051
+ "checkpoint = 'ThorbenF/prot_t5_xl_uniref50'\n",
1052
+ "max_length = 1500\n",
1053
+ "model, tokenizer = load_model(checkpoint, max_length)\n",
1054
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
1055
+ "model.to(device)\n",
1056
+ "model.eval()\n",
1057
+ "\n",
1058
+ "# Function to fetch a PDB file\n",
1059
+ "def fetch_pdb(pdb_id):\n",
1060
+ " pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'\n",
1061
+ " pdb_path = f'pdb_files/{pdb_id}.pdb'\n",
1062
+ " os.makedirs('pdb_files', exist_ok=True)\n",
1063
+ " response = requests.get(pdb_url)\n",
1064
+ " if response.status_code == 200:\n",
1065
+ " with open(pdb_path, 'wb') as f:\n",
1066
+ " f.write(response.content)\n",
1067
+ " return pdb_path\n",
1068
+ " return None\n",
1069
+ "\n",
1070
+ "\n",
1071
+ "def normalize_scores(scores):\n",
1072
+ " min_score = np.min(scores)\n",
1073
+ " max_score = np.max(scores)\n",
1074
+ " return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores\n",
1075
+ "\n",
1076
+ "def process_pdb(pdb_id, segment):\n",
1077
+ " pdb_path = fetch_pdb(pdb_id)\n",
1078
+ " if not pdb_path:\n",
1079
+ " return \"Failed to fetch PDB file\", None, None\n",
1080
+ " \n",
1081
+ " parser = PDBParser(QUIET=1)\n",
1082
+ " structure = parser.get_structure('protein', pdb_path)\n",
1083
+ " chain = structure[0][segment]\n",
1084
+ " \n",
1085
+ " # Comprehensive amino acid mapping\n",
1086
+ " aa_dict = {\n",
1087
+ " 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',\n",
1088
+ " 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',\n",
1089
+ " 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',\n",
1090
+ " 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',\n",
1091
+ " 'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'\n",
1092
+ " }\n",
1093
+ " \n",
1094
+ " # Exclude non-amino acid residues\n",
1095
+ " sequence = \"\".join(\n",
1096
+ " aa_dict[residue.get_resname().strip()] \n",
1097
+ " for residue in chain \n",
1098
+ " if residue.get_resname().strip() in aa_dict\n",
1099
+ " )\n",
1100
+ " \n",
1101
+ " # Prepare input for model prediction\n",
1102
+ " input_ids = tokenizer(\" \".join(sequence), return_tensors=\"pt\").input_ids.to(device)\n",
1103
+ " with torch.no_grad():\n",
1104
+ " outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()\n",
1105
+ "\n",
1106
+ " # Calculate scores and normalize them\n",
1107
+ " scores = expit(outputs[:, 1] - outputs[:, 0])\n",
1108
+ " normalized_scores = normalize_scores(scores)\n",
1109
+ " \n",
1110
+ " # Prepare the result string, including only amino acid residues\n",
1111
+ " result_str = \"\\n\".join([\n",
1112
+ " f\"{res.get_resname()} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}\" \n",
1113
+ " for i, res in enumerate(chain) if res.get_resname().strip() in aa_dict\n",
1114
+ " ])\n",
1115
+ " \n",
1116
+ " # Save predictions to file\n",
1117
+ " with open(f\"{pdb_id}_predictions.txt\", \"w\") as f:\n",
1118
+ " f.write(result_str)\n",
1119
+ " \n",
1120
+ " return result_str, pdb_path, f\"{pdb_id}_predictions.txt\"\n",
1121
+ "\n",
1122
+ "reps = [{\"model\": 0, \"style\": \"cartoon\", \"color\": \"spectrum\"}]\n",
1123
+ "\n",
1124
+ "# Gradio UI\n",
1125
+ "with gr.Blocks() as demo:\n",
1126
+ " gr.Markdown(\"# Protein Binding Site Prediction\")\n",
1127
+ "\n",
1128
+ " with gr.Row():\n",
1129
+ " pdb_input = gr.Textbox(value=\"2IWI\",\n",
1130
+ " label=\"PDB ID\",\n",
1131
+ " placeholder=\"Enter PDB ID here...\")\n",
1132
+ " segment_input = gr.Textbox(value=\"A\",\n",
1133
+ " label=\"Chain ID (Segment)\",\n",
1134
+ " placeholder=\"Enter Chain ID here...\")\n",
1135
+ "        visualize_btn = gr.Button(\"Visualize Structure\")\n",
1136
+ " prediction_btn = gr.Button(\"Predict Ligand Binding Site\")\n",
1137
+ "\n",
1138
+ " molecule_output = Molecule3D(label=\"Protein Structure\", reps=reps)\n",
1139
+ " predictions_output = gr.Textbox(label=\"Binding Site Predictions\")\n",
1140
+ " download_output = gr.File(label=\"Download Predictions\")\n",
1141
+ "\n",
1142
+ " visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)\n",
1143
+ " prediction_btn.click(\n",
1144
+ " process_pdb, \n",
1145
+ " inputs=[pdb_input, segment_input], \n",
1146
+ " outputs=[predictions_output, molecule_output, download_output]\n",
1147
+ " )\n",
1148
+ "\n",
1149
+ " gr.Markdown(\"## Examples\")\n",
1150
+ " gr.Examples(\n",
1151
+ " examples=[\n",
1152
+ " [\"2IWI\"],\n",
1153
+ " [\"7RPZ\"],\n",
1154
+ " [\"3TJN\"]\n",
1155
+ " ],\n",
1156
+ " inputs=[pdb_input, segment_input], \n",
1157
+ " outputs=[predictions_output, molecule_output, download_output]\n",
1158
+ " )\n",
1159
+ "\n",
1160
+ "demo.launch(share=True)"
1161
+ ]
1162
+ },
1163
+ {
1164
+ "cell_type": "code",
1165
+ "execution_count": null,
1166
+ "id": "4c61bac4-4f2e-4f4a-aa1f-30dca209747c",
1167
+ "metadata": {},
1168
+ "outputs": [],
1169
+ "source": []
1170
+ }
1171
+ ],
1172
+ "metadata": {
1173
+ "kernelspec": {
1174
+ "display_name": "Python (LLM)",
1175
+ "language": "python",
1176
+ "name": "llm"
1177
+ },
1178
+ "language_info": {
1179
+ "codemirror_mode": {
1180
+ "name": "ipython",
1181
+ "version": 3
1182
+ },
1183
+ "file_extension": ".py",
1184
+ "mimetype": "text/x-python",
1185
+ "name": "python",
1186
+ "nbconvert_exporter": "python",
1187
+ "pygments_lexer": "ipython3",
1188
+ "version": "3.12.7"
1189
+ }
1190
+ },
1191
+ "nbformat": 4,
1192
+ "nbformat_minor": 5
1193
+ }