ThorbenF committed on
Commit
11bcc1a
·
1 Parent(s): 01ff8b6
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -4,7 +4,6 @@ from model_loader import load_model
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
8
  from torch.utils.data import DataLoader
9
 
10
  import re
@@ -14,53 +13,25 @@ import pandas as pd
14
  import copy
15
 
16
  import transformers, datasets
17
- from transformers.modeling_outputs import TokenClassifierOutput
18
- from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
19
- from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
20
- from transformers import T5EncoderModel, T5Tokenizer
21
- from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
22
  from transformers import AutoTokenizer
23
- from transformers import TrainingArguments, Trainer, set_seed
24
  from transformers import DataCollatorForTokenClassification
25
 
26
- from dataclasses import dataclass
27
- from typing import Dict, List, Optional, Tuple, Union
28
-
29
- # for custom DataCollator
30
- from transformers.data.data_collator import DataCollatorMixin
31
- from transformers.tokenization_utils_base import PreTrainedTokenizerBase
32
- from transformers.utils import PaddingStrategy
33
-
34
  from datasets import Dataset
35
 
36
  from scipy.special import expit
37
 
38
  import requests
39
 
40
- from gradio_molecule3d import Molecule3D
 
 
41
 
42
- #import peft
43
- #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
44
 
45
  # Configuration
46
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
47
  max_length = 1500
48
 
49
- # Default representations for molecule rendering
50
- reps = [
51
- {
52
- "model": 0,
53
- "chain": "",
54
- "resname": "",
55
- "style": "cartoon",
56
- "color": "spectrum",
57
- "residue_range": "",
58
- "around": 0,
59
- "byres": False,
60
- "visible": True
61
- }
62
- ]
63
-
64
  # Load model and move to device
65
  model, tokenizer = load_model(checkpoint, max_length)
66
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -142,9 +113,7 @@ def predict_protein_sequence(test_one_letter_sequence):
142
  normalized_scores = normalize_scores(logits)
143
  test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
144
 
145
- result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(test_one_letter_sequence, normalized_scores)])
146
-
147
- return result_str
148
 
149
  def fetch_pdb(pdb_id):
150
  try:
@@ -169,14 +138,88 @@ def fetch_pdb(pdb_id):
169
  print(f"Error fetching PDB: {e}")
170
  return None
171
 
172
- def process_input(sequence, pdb_id):
173
- # Predict binding sites
174
- binding_site_predictions = predict_protein_sequence(sequence)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
 
 
 
176
  # Fetch PDB file
177
  pdb_path = fetch_pdb(pdb_id)
178
 
179
- return binding_site_predictions, pdb_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  # Create Gradio interface
182
  with gr.Blocks() as demo:
@@ -184,18 +227,11 @@ with gr.Blocks() as demo:
184
 
185
  with gr.Row():
186
  with gr.Column():
187
- # Sequence input
188
- sequence_input = gr.Textbox(
189
- lines=2,
190
- placeholder="Enter protein sequence here...",
191
- label="Protein Sequence"
192
- )
193
-
194
- # PDB ID input
195
  pdb_input = gr.Textbox(
196
- lines=1,
197
- placeholder="Enter PDB ID here...",
198
- label="PDB ID for 3D Visualization"
199
  )
200
 
201
  # Predict button
@@ -210,24 +246,26 @@ with gr.Blocks() as demo:
210
  # 3D Molecule visualization
211
  molecule_output = Molecule3D(
212
  label="Protein Structure",
213
- reps=reps
214
  )
215
 
216
  # Prediction logic
217
  predict_btn.click(
218
- process_input,
219
- inputs=[sequence_input, pdb_input],
220
- outputs=[predictions_output, molecule_output]
221
  )
222
 
223
  # Add some example inputs
224
  gr.Markdown("## Examples")
225
  gr.Examples(
226
  examples=[
227
- ["MKVLWAALLVTFLAGCQAKVEQAVETEPEPELRQQTEWQSGQRWELALGRFWDYLRWVQTLSEQVQEELLSSQVTQELRALMDETMKELKAYKSELEEQLTPVAEETRARLSKELQAAQARLGADMEDVCGRLVQYRGEVQAMLGQSTEELRVRLASHLRKLRKRLLRDADDLQKRLAVYQAGAREGAERGLSAIRERLGPLVEQGRVRAATVGSLAGQPLQERAQAWGERLRARMEEMGSRTRDRLDEVKEQVAEVRAKLEEQAQQRL", "1ABC"],
 
 
228
  ],
229
- inputs=[sequence_input, pdb_input],
230
- outputs=[predictions_output, molecule_output]
231
  )
232
 
233
  demo.launch()
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
 
7
  from torch.utils.data import DataLoader
8
 
9
  import re
 
13
  import copy
14
 
15
  import transformers, datasets
 
 
 
 
 
16
  from transformers import AutoTokenizer
 
17
  from transformers import DataCollatorForTokenClassification
18
 
 
 
 
 
 
 
 
 
19
  from datasets import Dataset
20
 
21
  from scipy.special import expit
22
 
23
  import requests
24
 
25
+ # Biopython imports
26
+ from Bio.PDB import PDBParser, Select
27
+ from Bio.PDB.DSSP import DSSP
28
 
29
+ from gradio_molecule3d import Molecule3D
 
30
 
31
  # Configuration
32
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
33
  max_length = 1500
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # Load model and move to device
36
  model, tokenizer = load_model(checkpoint, max_length)
37
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
113
  normalized_scores = normalize_scores(logits)
114
  test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
115
 
116
+ return test_one_letter_sequence, normalized_scores
 
 
117
 
118
  def fetch_pdb(pdb_id):
119
  try:
 
138
  print(f"Error fetching PDB: {e}")
139
  return None
140
 
141
def extract_protein_sequence(pdb_path):
    """Return the longest usable protein chain found in a PDB file.

    Each chain of each model in the parsed structure is translated into a
    one-letter amino-acid string. Residues whose name is not one of the 20
    standard three-letter codes (e.g. waters, ligands, modified residues)
    are skipped. The longest string with more than 10 and fewer than 1500
    residues is selected; on a tie the chain encountered first is kept.

    Args:
        pdb_path: Location of the PDB file, handed to
            ``PDBParser.get_structure``.

    Returns:
        A ``(sequence, chain)`` tuple: the one-letter sequence and the
        Bio.PDB chain object it was read from, or ``("", None)`` when no
        chain satisfies the length bounds.
    """
    # Three-letter -> one-letter codes for the 20 standard amino acids.
    # Built once per call instead of once per chain.
    three_to_one = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    }

    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_path)

    longest_sequence = ""
    longest_chain = None
    for model in structure:
        for chain in model:
            # Translate residue by residue. The previous version glued the
            # three-letter names into one string and then looked up every
            # single character, which turned the whole chain into 'X's and
            # tripled its length; it also filtered with the base Select,
            # which accepts every residue.
            residue_names = (
                residue.get_resname().strip().upper() for residue in chain
            )
            one_letter_sequence = ''.join(
                three_to_one[name]
                for name in residue_names
                if name in three_to_one
            )

            # Keep the longest chain inside the accepted length window.
            if (len(one_letter_sequence) > len(longest_sequence)
                    and 10 < len(one_letter_sequence) < 1500):
                longest_sequence = one_letter_sequence
                longest_chain = chain

    return longest_sequence, longest_chain
181
+
182
def process_pdb(pdb_id):
    """Fetch a PDB entry, score its longest chain and build per-residue reps.

    Args:
        pdb_id: Identifier forwarded to ``fetch_pdb``.

    Returns:
        A ``(result_text, reps, pdb_path)`` tuple. ``result_text`` lists one
        ``"<residue>: <score>"`` line per residue, ``reps`` holds one
        representation dict per residue coloured from blue (low score) to
        red (high score). On failure the tuple is
        ``(error_message, None, None)``.
    """
    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    protein_sequence, chain = extract_protein_sequence(pdb_path)
    if not protein_sequence:
        return "No suitable protein sequence found", None, None

    sequence, normalized_scores = predict_protein_sequence(protein_sequence)
    scored_residues = list(zip(sequence, normalized_scores))

    def _residue_rep(position, letter, score):
        # Red channel grows with the score, blue channel shrinks with it.
        red = int(score * 255)
        return {
            "model": 0,
            "chain": chain.id,
            "resname": letter,
            "resnum": position,
            "style": "cartoon",
            "color": f'rgb({red}, 0, {255-red})',
            "residue_range": f"{position}-{position}",
            "around": 0,
            "byres": True,
            "visible": True
        }

    # NOTE(review): positions are 1-based indices into the extracted
    # sequence, not the residue numbers stored in the PDB file - confirm
    # that this is what the viewer expects.
    reps = [
        _residue_rep(position, letter, score)
        for position, (letter, score) in enumerate(scored_residues, start=1)
    ]

    result_str = "\n".join(
        f"{letter}: {score:.2f}" for letter, score in scored_residues
    )

    return result_str, reps, pdb_path
223
 
224
  # Create Gradio interface
225
  with gr.Blocks() as demo:
 
227
 
228
  with gr.Row():
229
  with gr.Column():
230
+ # PDB ID input with default suggestion
 
 
 
 
 
 
 
231
  pdb_input = gr.Textbox(
232
+ value="2IWI",
233
+ label="PDB ID",
234
+ placeholder="Enter PDB ID here..."
235
  )
236
 
237
  # Predict button
 
246
  # 3D Molecule visualization
247
  molecule_output = Molecule3D(
248
  label="Protein Structure",
249
+ reps=[] # Start with empty representations
250
  )
251
 
252
  # Prediction logic
253
  predict_btn.click(
254
+ process_pdb,
255
+ inputs=[pdb_input],
256
+ outputs=[predictions_output, molecule_output, molecule_output]
257
  )
258
 
259
  # Add some example inputs
260
  gr.Markdown("## Examples")
261
  gr.Examples(
262
  examples=[
263
+ ["2IWI"],
264
+ ["1ABC"],
265
+ ["4HHB"]
266
  ],
267
+ inputs=[pdb_input],
268
+ outputs=[predictions_output, molecule_output, molecule_output]
269
  )
270
 
271
  demo.launch()
.ipynb_checkpoints/requirements-checkpoint.txt CHANGED
@@ -9,4 +9,5 @@ scikit-learn>=0.24.0
9
  sentencepiece
10
  huggingface_hub>=0.15.0
11
  requests
12
- gradio_molecule3d
 
 
9
  sentencepiece
10
  huggingface_hub>=0.15.0
11
  requests
12
+ gradio_molecule3d
13
+ biopython>=1.81
app.py CHANGED
@@ -4,7 +4,6 @@ from model_loader import load_model
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
8
  from torch.utils.data import DataLoader
9
 
10
  import re
@@ -14,53 +13,25 @@ import pandas as pd
14
  import copy
15
 
16
  import transformers, datasets
17
- from transformers.modeling_outputs import TokenClassifierOutput
18
- from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
19
- from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
20
- from transformers import T5EncoderModel, T5Tokenizer
21
- from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
22
  from transformers import AutoTokenizer
23
- from transformers import TrainingArguments, Trainer, set_seed
24
  from transformers import DataCollatorForTokenClassification
25
 
26
- from dataclasses import dataclass
27
- from typing import Dict, List, Optional, Tuple, Union
28
-
29
- # for custom DataCollator
30
- from transformers.data.data_collator import DataCollatorMixin
31
- from transformers.tokenization_utils_base import PreTrainedTokenizerBase
32
- from transformers.utils import PaddingStrategy
33
-
34
  from datasets import Dataset
35
 
36
  from scipy.special import expit
37
 
38
  import requests
39
 
40
- from gradio_molecule3d import Molecule3D
 
 
41
 
42
- #import peft
43
- #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
44
 
45
  # Configuration
46
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
47
  max_length = 1500
48
 
49
- # Default representations for molecule rendering
50
- reps = [
51
- {
52
- "model": 0,
53
- "chain": "",
54
- "resname": "",
55
- "style": "cartoon",
56
- "color": "spectrum",
57
- "residue_range": "",
58
- "around": 0,
59
- "byres": False,
60
- "visible": True
61
- }
62
- ]
63
-
64
  # Load model and move to device
65
  model, tokenizer = load_model(checkpoint, max_length)
66
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -142,9 +113,7 @@ def predict_protein_sequence(test_one_letter_sequence):
142
  normalized_scores = normalize_scores(logits)
143
  test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
144
 
145
- result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(test_one_letter_sequence, normalized_scores)])
146
-
147
- return result_str
148
 
149
  def fetch_pdb(pdb_id):
150
  try:
@@ -169,14 +138,88 @@ def fetch_pdb(pdb_id):
169
  print(f"Error fetching PDB: {e}")
170
  return None
171
 
172
- def process_input(sequence, pdb_id):
173
- # Predict binding sites
174
- binding_site_predictions = predict_protein_sequence(sequence)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
 
 
 
176
  # Fetch PDB file
177
  pdb_path = fetch_pdb(pdb_id)
178
 
179
- return binding_site_predictions, pdb_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  # Create Gradio interface
182
  with gr.Blocks() as demo:
@@ -184,18 +227,11 @@ with gr.Blocks() as demo:
184
 
185
  with gr.Row():
186
  with gr.Column():
187
- # Sequence input
188
- sequence_input = gr.Textbox(
189
- lines=2,
190
- placeholder="Enter protein sequence here...",
191
- label="Protein Sequence"
192
- )
193
-
194
- # PDB ID input
195
  pdb_input = gr.Textbox(
196
- lines=1,
197
- placeholder="Enter PDB ID here...",
198
- label="PDB ID for 3D Visualization"
199
  )
200
 
201
  # Predict button
@@ -210,24 +246,26 @@ with gr.Blocks() as demo:
210
  # 3D Molecule visualization
211
  molecule_output = Molecule3D(
212
  label="Protein Structure",
213
- reps=reps
214
  )
215
 
216
  # Prediction logic
217
  predict_btn.click(
218
- process_input,
219
- inputs=[sequence_input, pdb_input],
220
- outputs=[predictions_output, molecule_output]
221
  )
222
 
223
  # Add some example inputs
224
  gr.Markdown("## Examples")
225
  gr.Examples(
226
  examples=[
227
- ["MKVLWAALLVTFLAGCQAKVEQAVETEPEPELRQQTEWQSGQRWELALGRFWDYLRWVQTLSEQVQEELLSSQVTQELRALMDETMKELKAYKSELEEQLTPVAEETRARLSKELQAAQARLGADMEDVCGRLVQYRGEVQAMLGQSTEELRVRLASHLRKLRKRLLRDADDLQKRLAVYQAGAREGAERGLSAIRERLGPLVEQGRVRAATVGSLAGQPLQERAQAWGERLRARMEEMGSRTRDRLDEVKEQVAEVRAKLEEQAQQRL", "1ABC"],
 
 
228
  ],
229
- inputs=[sequence_input, pdb_input],
230
- outputs=[predictions_output, molecule_output]
231
  )
232
 
233
  demo.launch()
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
 
7
  from torch.utils.data import DataLoader
8
 
9
  import re
 
13
  import copy
14
 
15
  import transformers, datasets
 
 
 
 
 
16
  from transformers import AutoTokenizer
 
17
  from transformers import DataCollatorForTokenClassification
18
 
 
 
 
 
 
 
 
 
19
  from datasets import Dataset
20
 
21
  from scipy.special import expit
22
 
23
  import requests
24
 
25
+ # Biopython imports
26
+ from Bio.PDB import PDBParser, Select
27
+ from Bio.PDB.DSSP import DSSP
28
 
29
+ from gradio_molecule3d import Molecule3D
 
30
 
31
  # Configuration
32
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
33
  max_length = 1500
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # Load model and move to device
36
  model, tokenizer = load_model(checkpoint, max_length)
37
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
113
  normalized_scores = normalize_scores(logits)
114
  test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
115
 
116
+ return test_one_letter_sequence, normalized_scores
 
 
117
 
118
  def fetch_pdb(pdb_id):
119
  try:
 
138
  print(f"Error fetching PDB: {e}")
139
  return None
140
 
141
def extract_protein_sequence(pdb_path):
    """Return the longest usable protein chain found in a PDB file.

    Each chain of each model in the parsed structure is translated into a
    one-letter amino-acid string. Residues whose name is not one of the 20
    standard three-letter codes (e.g. waters, ligands, modified residues)
    are skipped. The longest string with more than 10 and fewer than 1500
    residues is selected; on a tie the chain encountered first is kept.

    Args:
        pdb_path: Location of the PDB file, handed to
            ``PDBParser.get_structure``.

    Returns:
        A ``(sequence, chain)`` tuple: the one-letter sequence and the
        Bio.PDB chain object it was read from, or ``("", None)`` when no
        chain satisfies the length bounds.
    """
    # Three-letter -> one-letter codes for the 20 standard amino acids.
    # Built once per call instead of once per chain.
    three_to_one = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    }

    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_path)

    longest_sequence = ""
    longest_chain = None
    for model in structure:
        for chain in model:
            # Translate residue by residue. The previous version glued the
            # three-letter names into one string and then looked up every
            # single character, which turned the whole chain into 'X's and
            # tripled its length; it also filtered with the base Select,
            # which accepts every residue.
            residue_names = (
                residue.get_resname().strip().upper() for residue in chain
            )
            one_letter_sequence = ''.join(
                three_to_one[name]
                for name in residue_names
                if name in three_to_one
            )

            # Keep the longest chain inside the accepted length window.
            if (len(one_letter_sequence) > len(longest_sequence)
                    and 10 < len(one_letter_sequence) < 1500):
                longest_sequence = one_letter_sequence
                longest_chain = chain

    return longest_sequence, longest_chain
181
+
182
def process_pdb(pdb_id):
    """Fetch a PDB entry, score its longest chain and build per-residue reps.

    Args:
        pdb_id: Identifier forwarded to ``fetch_pdb``.

    Returns:
        A ``(result_text, reps, pdb_path)`` tuple. ``result_text`` lists one
        ``"<residue>: <score>"`` line per residue, ``reps`` holds one
        representation dict per residue coloured from blue (low score) to
        red (high score). On failure the tuple is
        ``(error_message, None, None)``.
    """
    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    protein_sequence, chain = extract_protein_sequence(pdb_path)
    if not protein_sequence:
        return "No suitable protein sequence found", None, None

    sequence, normalized_scores = predict_protein_sequence(protein_sequence)
    scored_residues = list(zip(sequence, normalized_scores))

    def _residue_rep(position, letter, score):
        # Red channel grows with the score, blue channel shrinks with it.
        red = int(score * 255)
        return {
            "model": 0,
            "chain": chain.id,
            "resname": letter,
            "resnum": position,
            "style": "cartoon",
            "color": f'rgb({red}, 0, {255-red})',
            "residue_range": f"{position}-{position}",
            "around": 0,
            "byres": True,
            "visible": True
        }

    # NOTE(review): positions are 1-based indices into the extracted
    # sequence, not the residue numbers stored in the PDB file - confirm
    # that this is what the viewer expects.
    reps = [
        _residue_rep(position, letter, score)
        for position, (letter, score) in enumerate(scored_residues, start=1)
    ]

    result_str = "\n".join(
        f"{letter}: {score:.2f}" for letter, score in scored_residues
    )

    return result_str, reps, pdb_path
223
 
224
  # Create Gradio interface
225
  with gr.Blocks() as demo:
 
227
 
228
  with gr.Row():
229
  with gr.Column():
230
+ # PDB ID input with default suggestion
 
 
 
 
 
 
 
231
  pdb_input = gr.Textbox(
232
+ value="2IWI",
233
+ label="PDB ID",
234
+ placeholder="Enter PDB ID here..."
235
  )
236
 
237
  # Predict button
 
246
  # 3D Molecule visualization
247
  molecule_output = Molecule3D(
248
  label="Protein Structure",
249
+ reps=[] # Start with empty representations
250
  )
251
 
252
  # Prediction logic
253
  predict_btn.click(
254
+ process_pdb,
255
+ inputs=[pdb_input],
256
+ outputs=[predictions_output, molecule_output, molecule_output]
257
  )
258
 
259
  # Add some example inputs
260
  gr.Markdown("## Examples")
261
  gr.Examples(
262
  examples=[
263
+ ["2IWI"],
264
+ ["1ABC"],
265
+ ["4HHB"]
266
  ],
267
+ inputs=[pdb_input],
268
+ outputs=[predictions_output, molecule_output, molecule_output]
269
  )
270
 
271
  demo.launch()
requirements.txt CHANGED
@@ -9,4 +9,5 @@ scikit-learn>=0.24.0
9
  sentencepiece
10
  huggingface_hub>=0.15.0
11
  requests
12
- gradio_molecule3d
 
 
9
  sentencepiece
10
  huggingface_hub>=0.15.0
11
  requests
12
+ gradio_molecule3d
13
+ biopython>=1.81