ThorbenF committed on
Commit
4499595
·
1 Parent(s): 1705a41
Files changed (2) hide show
  1. .ipynb_checkpoints/app-checkpoint.py +49 -250
  2. app.py +49 -250
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -32,275 +32,74 @@ from Bio.PDB import PDBList
32
  from matplotlib import cm # For color mapping
33
  from matplotlib.colors import Normalize
34
 
35
- # Configuration
36
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
37
  max_length = 1500
38
-
39
- # Load model and move to device
40
  model, tokenizer = load_model(checkpoint, max_length)
41
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
  model.to(device)
43
  model.eval()
44
 
45
- reps = [
46
- {
47
- "model": 0,
48
- "chain": "",
49
- "resname": "",
50
- "style": "cartoon",
51
- "color": "spectrum",
52
- "residue_range": "",
53
- "around": 0,
54
- "byres": False,
55
- "visible": True
56
- }
57
- ]
58
 
59
- def is_valid_sequence_length(length: int) -> bool:
60
- """Check if sequence length is within valid range."""
61
- return 100 <= length <= 1500
62
-
63
- def is_nucleic_acid_chain(chain) -> bool:
64
- """Check if chain contains nucleic acids."""
65
- nucleic_acids = {'A', 'C', 'G', 'T', 'U', 'DA', 'DC', 'DG', 'DT', 'DU', 'UNK'}
66
- return any(residue.get_resname().strip() in nucleic_acids for residue in chain)
67
-
68
- def extract_protein_sequence(pdb_path):
69
- """
70
- Extract the longest protein sequence from a PDB file with improved logic
71
- """
 
 
 
 
 
72
  parser = PDBParser(QUIET=1)
73
  structure = parser.get_structure('protein', pdb_path)
 
74
 
75
- # Comprehensive amino acid mapping
76
- aa_dict = {
77
- # Standard amino acids (20 canonical)
78
- 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
79
- 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
80
- 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
81
- 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
82
-
83
- # Modified amino acids and alternative names
84
- 'MSE': 'M', # Selenomethionine
85
- 'SEP': 'S', # Phosphoserine
86
- 'TPO': 'T', # Phosphothreonine
87
- 'CSO': 'C', # S-hydroxycysteine
88
- 'PTR': 'Y', # Phosphotyrosine
89
- 'HYP': 'P', # Hydroxyproline
90
- }
91
-
92
- # Ligand and nucleic acid exclusion set
93
- ligand_exclusion_set = {'HOH', 'WAT', 'DOD', 'SO4', 'PO4', 'GOL', 'ACT', 'EDO'}
94
-
95
- # Find the longest protein chain
96
- longest_sequence = ""
97
- longest_chain = None
98
-
99
- for model in structure:
100
- for chain in model:
101
- # Skip nucleic acid chains
102
- if is_nucleic_acid_chain(chain):
103
- continue
104
-
105
- # Extract and convert sequence
106
- sequence = ""
107
- for residue in chain:
108
- # Check if residue is a standard amino acid or a known modified amino acid
109
- res_name = residue.get_resname().strip()
110
- if res_name in aa_dict:
111
- sequence += aa_dict[res_name]
112
-
113
- # Check for valid length and update longest sequence
114
- if (10 < len(sequence) < 1500 and
115
- len(sequence) > len(longest_sequence)):
116
- longest_sequence = sequence
117
- longest_chain = chain
118
 
119
- if not longest_sequence:
120
- return None, None, pdb_path
121
-
122
- # Save filtered PDB if needed
123
- if longest_chain:
124
- io = PDBIO()
125
- io.set_structure(longest_chain.get_parent().get_parent())
126
- filtered_pdb_path = pdb_path.replace('.pdb', '_filtered.pdb')
127
- io.save(filtered_pdb_path)
128
- return longest_sequence, longest_chain, filtered_pdb_path
129
-
130
- return longest_sequence, longest_chain, pdb_path
131
-
132
- def create_dataset(tokenizer, seqs, labels, checkpoint):
133
- tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
134
- dataset = Dataset.from_dict(tokenized)
135
-
136
- # Adjust labels based on checkpoint
137
- if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
138
- labels = [l[:max_length-2] for l in labels]
139
- else:
140
- labels = [l[:max_length-1] for l in labels]
141
-
142
- dataset = dataset.add_column("labels", labels)
143
-
144
- return dataset
145
-
146
- def convert_predictions(input_logits):
147
- all_probs = []
148
- for logits in input_logits:
149
- logits = logits.reshape(-1, 2)
150
- probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
151
- all_probs.append(probabilities_class1)
152
-
153
- return np.concatenate(all_probs)
154
 
155
- def normalize_scores(scores):
156
- min_score = np.min(scores)
157
- max_score = np.max(scores)
158
- return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
159
-
160
- def predict_protein_sequence(test_one_letter_sequence):
161
- # Sanitize input sequence
162
- test_one_letter_sequence = test_one_letter_sequence.replace("O", "X") \
163
- .replace("B", "X").replace("U", "X") \
164
- .replace("Z", "X").replace("J", "X")
165
-
166
- # Prepare sequence for different model types
167
- if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
168
- test_one_letter_sequence = " ".join(test_one_letter_sequence)
169
-
170
- if "ProstT5" in checkpoint:
171
- test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence
172
-
173
- # Create dummy labels
174
- dummy_labels = [np.zeros(len(test_one_letter_sequence))]
175
-
176
- # Create dataset
177
- test_dataset = create_dataset(tokenizer,
178
- [test_one_letter_sequence],
179
- dummy_labels,
180
- checkpoint)
181
 
182
- # Select appropriate data collator
183
- data_collator = (DataCollatorForTokenClassification(tokenizer)
184
- if "esm" not in checkpoint and "ProstT5" not in checkpoint
185
- else DataCollatorForTokenClassification(tokenizer))
186
 
187
- # Create data loader
188
- test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
189
-
190
- # Predict
191
- for batch in test_loader:
192
- input_ids = batch['input_ids'].to(device)
193
- attention_mask = batch['attention_mask'].to(device)
194
-
195
- with torch.no_grad():
196
- outputs = model(input_ids, attention_mask=attention_mask)
197
- logits = outputs.logits.detach().cpu().numpy()
198
-
199
- # Process logits
200
- logits = logits[:, :-1] # Remove last element for prot_t5
201
- logits = convert_predictions(logits)
202
-
203
- # Normalize and format results
204
- normalized_scores = normalize_scores(logits)
205
- test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
206
-
207
- return test_one_letter_sequence, normalized_scores
208
-
209
- def fetch_pdb(pdb_id):
210
- try:
211
- # Create a directory to store PDB files if it doesn't exist
212
- os.makedirs('pdb_files', exist_ok=True)
213
-
214
- # Fetch the PDB structure from RCSB
215
- pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
216
- pdb_path = f'pdb_files/{pdb_id}.pdb'
217
-
218
- # Download the file
219
- response = requests.get(pdb_url)
220
-
221
- if response.status_code == 200:
222
- with open(pdb_path, 'wb') as f:
223
- f.write(response.content)
224
- return pdb_path
225
- else:
226
- return None
227
-
228
- except Exception as e:
229
- print(f"Error fetching PDB: {e}")
230
- return None
231
-
232
- def score_to_color(score):
233
- norm = Normalize(vmin=0, vmax=1) # Normalize scores between 0 and 1
234
- color_map = cm.coolwarm # Directly use the colormap (e.g., 'cividis', 'coolwarm', etc.)
235
- rgba = color_map(norm(score)) # Get RGBA values
236
- hex_color = '#{:02x}{:02x}{:02x}'.format(int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255))
237
- return hex_color
238
-
239
- def process_pdb(pdb_id):
240
- # Fetch PDB file
241
- pdbl = PDBList()
242
- pdb_path = pdbl.retrieve_pdb_file(pdb_id, pdir='pdb_files', file_format='pdb')
243
-
244
- if not pdb_path or not os.path.exists(pdb_path):
245
- return "Failed to fetch PDB file", None
246
-
247
- # Extract protein sequence and chain
248
- protein_sequence, chain, filtered_pdb_path = extract_protein_sequence(pdb_path)
249
 
250
- if not protein_sequence:
251
- return "No suitable protein sequence found", None
252
-
253
- # Predict binding sites
254
- sequence, normalized_scores = predict_protein_sequence(protein_sequence)
255
-
256
- # Prepare result string
257
- result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(sequence, normalized_scores)])
258
-
259
- pdb_path = fetch_pdb(pdb_id)
260
-
261
- return result_str, pdb_path
262
-
263
- # Create Gradio interface
264
  with gr.Blocks() as demo:
265
  gr.Markdown("# Protein Binding Site Prediction")
266
 
267
  with gr.Row():
268
- with gr.Column():
269
- pdb_input = gr.Textbox(
270
- value="2IWI",
271
- label="PDB ID",
272
- placeholder="Enter PDB ID here..."
273
- )
274
- predict_btn = gr.Button("Predict Binding Sites")
275
-
276
- with gr.Column():
277
- # Binding site predictions output
278
- predictions_output = gr.Textbox(
279
- label="Binding Site Predictions"
280
- )
281
-
282
- # 3D Molecule visualization
283
- molecule_output = Molecule3D(
284
- label="Protein Structure",
285
- reps=reps
286
- )
287
-
288
- # Prediction logic
289
- predict_btn.click(
290
- process_pdb,
291
- inputs=[pdb_input],
292
- outputs=[predictions_output, molecule_output]
293
- )
294
-
295
- gr.Markdown("## Examples")
296
- gr.Examples(
297
- examples=[
298
- ["2IWI"],
299
- ["7RPZ"],
300
- ["3TJN"]
301
- ],
302
- inputs=[pdb_input],
303
- outputs=[predictions_output, molecule_output]
304
  )
305
 
306
- demo.launch()
 
32
  from matplotlib import cm # For color mapping
33
  from matplotlib.colors import Normalize
34
 
35
+ # Load model and move to device
36
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
37
  max_length = 1500
 
 
38
  model, tokenizer = load_model(checkpoint, max_length)
39
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
  model.to(device)
41
  model.eval()
42
 
43
+ reps = [{"model": 0, "style": "cartoon", "color": "spectrum"}]
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
def fetch_pdb(pdb_id):
    """Download a PDB entry from RCSB and store it under ``pdb_files/``.

    Args:
        pdb_id: PDB identifier as entered by the user (e.g. ``"2IWI"``).
            Surrounding whitespace is ignored.

    Returns:
        Path of the saved ``.pdb`` file, or ``None`` when the identifier is
        empty or not purely alphanumeric, the request fails, or the server
        does not answer with HTTP 200.
    """
    pdb_id = (pdb_id or "").strip()
    # The identifier is interpolated into both a URL and a local file path,
    # so refuse anything that could carry path separators or URL syntax.
    if not pdb_id.isalnum():
        return None

    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'pdb_files/{pdb_id}.pdb'
    os.makedirs('pdb_files', exist_ok=True)

    try:
        # A timeout keeps a stalled download from blocking the UI callback.
        response = requests.get(pdb_url, timeout=30)
    except requests.RequestException as e:
        print(f"Error fetching PDB: {e}")
        return None

    if response.status_code != 200:
        return None

    with open(pdb_path, 'wb') as f:
        f.write(response.content)
    return pdb_path
56
+
57
# Three-letter residue names mapped to the one-letter codes the model is fed.
# Common modified residues are mapped to their parent amino acid; anything not
# listed here (waters, ligands, nucleotides) is left out of the sequence.
_RESIDUE_TO_LETTER = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
    'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
    'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
    'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    'MSE': 'M',  # Selenomethionine
    'SEP': 'S',  # Phosphoserine
    'TPO': 'T',  # Phosphothreonine
    'CSO': 'C',  # S-hydroxycysteine
    'PTR': 'Y',  # Phosphotyrosine
    'HYP': 'P',  # Hydroxyproline
}


def process_pdb(pdb_id, segment):
    """Fetch a PDB entry and score every amino-acid residue of one chain.

    Args:
        pdb_id: PDB identifier passed on to ``fetch_pdb``.
        segment: Chain ID to analyse within the first model of the structure.

    Returns:
        A 3-tuple ``(text, pdb_path, predictions_path)``. ``text`` holds one
        line per residue in the form
        ``"<resname> <resnum> <one-letter code> <score>"``, where the score is
        the class-1 logit minus the class-0 logit. When the download fails the
        result is ``("Failed to fetch PDB file", None, None)``; when the chain
        is missing or contains no amino acids, ``text`` is an error message
        and ``predictions_path`` is ``None``.
    """
    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    parser = PDBParser(QUIET=1)
    structure = parser.get_structure('protein', pdb_path)

    segment = (segment or "").strip()
    try:
        chain = structure[0][segment]
    except KeyError:
        return f"Chain '{segment}' not found in PDB file", pdb_path, None

    # Keep residues and their one-letter codes in lockstep so every score can
    # be reported against the residue it belongs to.
    residues = []
    letters = []
    for residue in chain:
        letter = _RESIDUE_TO_LETTER.get(residue.get_resname().strip())
        if letter is not None:
            residues.append(residue)
            letters.append(letter)

    if not residues:
        return "No amino acid residues found in chain", pdb_path, None

    sequence = "".join(letters)

    # The sequence is passed to the tokenizer as space-separated letters.
    input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        # Index the single batch entry; assumes logits are (1, tokens, 2).
        logits = model(input_ids).logits.detach().cpu().numpy()[0]

    # NOTE(review): the tokenizer presumably appends an end-of-sequence token
    # (the previous revision dropped the last position for prot_t5); slicing
    # to the residue count discards any such trailing positions.
    scores = (logits[:, 1] - logits[:, 0])[:len(residues)]

    result_str = "\n".join(
        f"{res.get_resname()} {res.id[1]} {letter} {score:.2f}"
        for res, letter, score in zip(residues, letters, scores)
    )

    # Name the output after the downloaded file rather than the raw user
    # input, so the text typed into the UI cannot steer the write location.
    stem = os.path.splitext(os.path.basename(pdb_path))[0]
    predictions_path = f"{stem}_predictions.txt"
    with open(predictions_path, "w") as f:
        f.write(result_str)

    return result_str, pdb_path, predictions_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ # Gradio UI
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  with gr.Blocks() as demo:
86
  gr.Markdown("# Protein Binding Site Prediction")
87
 
88
  with gr.Row():
89
+ pdb_input = gr.Textbox(label="PDB ID")
90
+ segment_input = gr.Textbox(label="Segment (Chain ID)")
91
+ visualize_btn = gr.Button("Visualize")
92
+ prediction_btn = gr.Button("Predict")
93
+
94
+ molecule_output = Molecule3D(label="Protein Structure", reps=reps)
95
+ predictions_output = gr.Textbox(label="Binding Site Predictions")
96
+ download_output = gr.File(label="Download Predictions")
97
+
98
+ visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)
99
+ prediction_btn.click(
100
+ process_pdb,
101
+ inputs=[pdb_input, segment_input],
102
+ outputs=[predictions_output, molecule_output, download_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  )
104
 
105
+ demo.launch(share=True)
app.py CHANGED
@@ -32,275 +32,74 @@ from Bio.PDB import PDBList
32
  from matplotlib import cm # For color mapping
33
  from matplotlib.colors import Normalize
34
 
35
- # Configuration
36
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
37
  max_length = 1500
38
-
39
- # Load model and move to device
40
  model, tokenizer = load_model(checkpoint, max_length)
41
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
  model.to(device)
43
  model.eval()
44
 
45
- reps = [
46
- {
47
- "model": 0,
48
- "chain": "",
49
- "resname": "",
50
- "style": "cartoon",
51
- "color": "spectrum",
52
- "residue_range": "",
53
- "around": 0,
54
- "byres": False,
55
- "visible": True
56
- }
57
- ]
58
 
59
- def is_valid_sequence_length(length: int) -> bool:
60
- """Check if sequence length is within valid range."""
61
- return 100 <= length <= 1500
62
-
63
- def is_nucleic_acid_chain(chain) -> bool:
64
- """Check if chain contains nucleic acids."""
65
- nucleic_acids = {'A', 'C', 'G', 'T', 'U', 'DA', 'DC', 'DG', 'DT', 'DU', 'UNK'}
66
- return any(residue.get_resname().strip() in nucleic_acids for residue in chain)
67
-
68
- def extract_protein_sequence(pdb_path):
69
- """
70
- Extract the longest protein sequence from a PDB file with improved logic
71
- """
 
 
 
 
 
72
  parser = PDBParser(QUIET=1)
73
  structure = parser.get_structure('protein', pdb_path)
 
74
 
75
- # Comprehensive amino acid mapping
76
- aa_dict = {
77
- # Standard amino acids (20 canonical)
78
- 'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
79
- 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
80
- 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
81
- 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
82
-
83
- # Modified amino acids and alternative names
84
- 'MSE': 'M', # Selenomethionine
85
- 'SEP': 'S', # Phosphoserine
86
- 'TPO': 'T', # Phosphothreonine
87
- 'CSO': 'C', # S-hydroxycysteine
88
- 'PTR': 'Y', # Phosphotyrosine
89
- 'HYP': 'P', # Hydroxyproline
90
- }
91
-
92
- # Ligand and nucleic acid exclusion set
93
- ligand_exclusion_set = {'HOH', 'WAT', 'DOD', 'SO4', 'PO4', 'GOL', 'ACT', 'EDO'}
94
-
95
- # Find the longest protein chain
96
- longest_sequence = ""
97
- longest_chain = None
98
-
99
- for model in structure:
100
- for chain in model:
101
- # Skip nucleic acid chains
102
- if is_nucleic_acid_chain(chain):
103
- continue
104
-
105
- # Extract and convert sequence
106
- sequence = ""
107
- for residue in chain:
108
- # Check if residue is a standard amino acid or a known modified amino acid
109
- res_name = residue.get_resname().strip()
110
- if res_name in aa_dict:
111
- sequence += aa_dict[res_name]
112
-
113
- # Check for valid length and update longest sequence
114
- if (10 < len(sequence) < 1500 and
115
- len(sequence) > len(longest_sequence)):
116
- longest_sequence = sequence
117
- longest_chain = chain
118
 
119
- if not longest_sequence:
120
- return None, None, pdb_path
121
-
122
- # Save filtered PDB if needed
123
- if longest_chain:
124
- io = PDBIO()
125
- io.set_structure(longest_chain.get_parent().get_parent())
126
- filtered_pdb_path = pdb_path.replace('.pdb', '_filtered.pdb')
127
- io.save(filtered_pdb_path)
128
- return longest_sequence, longest_chain, filtered_pdb_path
129
-
130
- return longest_sequence, longest_chain, pdb_path
131
-
132
- def create_dataset(tokenizer, seqs, labels, checkpoint):
133
- tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
134
- dataset = Dataset.from_dict(tokenized)
135
-
136
- # Adjust labels based on checkpoint
137
- if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
138
- labels = [l[:max_length-2] for l in labels]
139
- else:
140
- labels = [l[:max_length-1] for l in labels]
141
-
142
- dataset = dataset.add_column("labels", labels)
143
-
144
- return dataset
145
-
146
- def convert_predictions(input_logits):
147
- all_probs = []
148
- for logits in input_logits:
149
- logits = logits.reshape(-1, 2)
150
- probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
151
- all_probs.append(probabilities_class1)
152
-
153
- return np.concatenate(all_probs)
154
 
155
- def normalize_scores(scores):
156
- min_score = np.min(scores)
157
- max_score = np.max(scores)
158
- return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
159
-
160
- def predict_protein_sequence(test_one_letter_sequence):
161
- # Sanitize input sequence
162
- test_one_letter_sequence = test_one_letter_sequence.replace("O", "X") \
163
- .replace("B", "X").replace("U", "X") \
164
- .replace("Z", "X").replace("J", "X")
165
-
166
- # Prepare sequence for different model types
167
- if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
168
- test_one_letter_sequence = " ".join(test_one_letter_sequence)
169
-
170
- if "ProstT5" in checkpoint:
171
- test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence
172
-
173
- # Create dummy labels
174
- dummy_labels = [np.zeros(len(test_one_letter_sequence))]
175
-
176
- # Create dataset
177
- test_dataset = create_dataset(tokenizer,
178
- [test_one_letter_sequence],
179
- dummy_labels,
180
- checkpoint)
181
 
182
- # Select appropriate data collator
183
- data_collator = (DataCollatorForTokenClassification(tokenizer)
184
- if "esm" not in checkpoint and "ProstT5" not in checkpoint
185
- else DataCollatorForTokenClassification(tokenizer))
186
 
187
- # Create data loader
188
- test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
189
-
190
- # Predict
191
- for batch in test_loader:
192
- input_ids = batch['input_ids'].to(device)
193
- attention_mask = batch['attention_mask'].to(device)
194
-
195
- with torch.no_grad():
196
- outputs = model(input_ids, attention_mask=attention_mask)
197
- logits = outputs.logits.detach().cpu().numpy()
198
-
199
- # Process logits
200
- logits = logits[:, :-1] # Remove last element for prot_t5
201
- logits = convert_predictions(logits)
202
-
203
- # Normalize and format results
204
- normalized_scores = normalize_scores(logits)
205
- test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
206
-
207
- return test_one_letter_sequence, normalized_scores
208
-
209
- def fetch_pdb(pdb_id):
210
- try:
211
- # Create a directory to store PDB files if it doesn't exist
212
- os.makedirs('pdb_files', exist_ok=True)
213
-
214
- # Fetch the PDB structure from RCSB
215
- pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
216
- pdb_path = f'pdb_files/{pdb_id}.pdb'
217
-
218
- # Download the file
219
- response = requests.get(pdb_url)
220
-
221
- if response.status_code == 200:
222
- with open(pdb_path, 'wb') as f:
223
- f.write(response.content)
224
- return pdb_path
225
- else:
226
- return None
227
-
228
- except Exception as e:
229
- print(f"Error fetching PDB: {e}")
230
- return None
231
-
232
- def score_to_color(score):
233
- norm = Normalize(vmin=0, vmax=1) # Normalize scores between 0 and 1
234
- color_map = cm.coolwarm # Directly use the colormap (e.g., 'cividis', 'coolwarm', etc.)
235
- rgba = color_map(norm(score)) # Get RGBA values
236
- hex_color = '#{:02x}{:02x}{:02x}'.format(int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255))
237
- return hex_color
238
-
239
- def process_pdb(pdb_id):
240
- # Fetch PDB file
241
- pdbl = PDBList()
242
- pdb_path = pdbl.retrieve_pdb_file(pdb_id, pdir='pdb_files', file_format='pdb')
243
-
244
- if not pdb_path or not os.path.exists(pdb_path):
245
- return "Failed to fetch PDB file", None
246
-
247
- # Extract protein sequence and chain
248
- protein_sequence, chain, filtered_pdb_path = extract_protein_sequence(pdb_path)
249
 
250
- if not protein_sequence:
251
- return "No suitable protein sequence found", None
252
-
253
- # Predict binding sites
254
- sequence, normalized_scores = predict_protein_sequence(protein_sequence)
255
-
256
- # Prepare result string
257
- result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(sequence, normalized_scores)])
258
-
259
- pdb_path = fetch_pdb(pdb_id)
260
-
261
- return result_str, pdb_path
262
-
263
- # Create Gradio interface
264
  with gr.Blocks() as demo:
265
  gr.Markdown("# Protein Binding Site Prediction")
266
 
267
  with gr.Row():
268
- with gr.Column():
269
- pdb_input = gr.Textbox(
270
- value="2IWI",
271
- label="PDB ID",
272
- placeholder="Enter PDB ID here..."
273
- )
274
- predict_btn = gr.Button("Predict Binding Sites")
275
-
276
- with gr.Column():
277
- # Binding site predictions output
278
- predictions_output = gr.Textbox(
279
- label="Binding Site Predictions"
280
- )
281
-
282
- # 3D Molecule visualization
283
- molecule_output = Molecule3D(
284
- label="Protein Structure",
285
- reps=reps
286
- )
287
-
288
- # Prediction logic
289
- predict_btn.click(
290
- process_pdb,
291
- inputs=[pdb_input],
292
- outputs=[predictions_output, molecule_output]
293
- )
294
-
295
- gr.Markdown("## Examples")
296
- gr.Examples(
297
- examples=[
298
- ["2IWI"],
299
- ["7RPZ"],
300
- ["3TJN"]
301
- ],
302
- inputs=[pdb_input],
303
- outputs=[predictions_output, molecule_output]
304
  )
305
 
306
- demo.launch()
 
32
  from matplotlib import cm # For color mapping
33
  from matplotlib.colors import Normalize
34
 
35
+ # Load model and move to device
36
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
37
  max_length = 1500
 
 
38
  model, tokenizer = load_model(checkpoint, max_length)
39
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
  model.to(device)
41
  model.eval()
42
 
43
+ reps = [{"model": 0, "style": "cartoon", "color": "spectrum"}]
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
def fetch_pdb(pdb_id):
    """Download a PDB entry from RCSB and store it under ``pdb_files/``.

    Args:
        pdb_id: PDB identifier as entered by the user (e.g. ``"2IWI"``).
            Surrounding whitespace is ignored.

    Returns:
        Path of the saved ``.pdb`` file, or ``None`` when the identifier is
        empty or not purely alphanumeric, the request fails, or the server
        does not answer with HTTP 200.
    """
    pdb_id = (pdb_id or "").strip()
    # The identifier is interpolated into both a URL and a local file path,
    # so refuse anything that could carry path separators or URL syntax.
    if not pdb_id.isalnum():
        return None

    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'pdb_files/{pdb_id}.pdb'
    os.makedirs('pdb_files', exist_ok=True)

    try:
        # A timeout keeps a stalled download from blocking the UI callback.
        response = requests.get(pdb_url, timeout=30)
    except requests.RequestException as e:
        print(f"Error fetching PDB: {e}")
        return None

    if response.status_code != 200:
        return None

    with open(pdb_path, 'wb') as f:
        f.write(response.content)
    return pdb_path
56
+
57
# Three-letter residue names mapped to the one-letter codes the model is fed.
# Common modified residues are mapped to their parent amino acid; anything not
# listed here (waters, ligands, nucleotides) is left out of the sequence.
_RESIDUE_TO_LETTER = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
    'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
    'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
    'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    'MSE': 'M',  # Selenomethionine
    'SEP': 'S',  # Phosphoserine
    'TPO': 'T',  # Phosphothreonine
    'CSO': 'C',  # S-hydroxycysteine
    'PTR': 'Y',  # Phosphotyrosine
    'HYP': 'P',  # Hydroxyproline
}


def process_pdb(pdb_id, segment):
    """Fetch a PDB entry and score every amino-acid residue of one chain.

    Args:
        pdb_id: PDB identifier passed on to ``fetch_pdb``.
        segment: Chain ID to analyse within the first model of the structure.

    Returns:
        A 3-tuple ``(text, pdb_path, predictions_path)``. ``text`` holds one
        line per residue in the form
        ``"<resname> <resnum> <one-letter code> <score>"``, where the score is
        the class-1 logit minus the class-0 logit. When the download fails the
        result is ``("Failed to fetch PDB file", None, None)``; when the chain
        is missing or contains no amino acids, ``text`` is an error message
        and ``predictions_path`` is ``None``.
    """
    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    parser = PDBParser(QUIET=1)
    structure = parser.get_structure('protein', pdb_path)

    segment = (segment or "").strip()
    try:
        chain = structure[0][segment]
    except KeyError:
        return f"Chain '{segment}' not found in PDB file", pdb_path, None

    # Keep residues and their one-letter codes in lockstep so every score can
    # be reported against the residue it belongs to.
    residues = []
    letters = []
    for residue in chain:
        letter = _RESIDUE_TO_LETTER.get(residue.get_resname().strip())
        if letter is not None:
            residues.append(residue)
            letters.append(letter)

    if not residues:
        return "No amino acid residues found in chain", pdb_path, None

    sequence = "".join(letters)

    # The sequence is passed to the tokenizer as space-separated letters.
    input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        # Index the single batch entry; assumes logits are (1, tokens, 2).
        logits = model(input_ids).logits.detach().cpu().numpy()[0]

    # NOTE(review): the tokenizer presumably appends an end-of-sequence token
    # (the previous revision dropped the last position for prot_t5); slicing
    # to the residue count discards any such trailing positions.
    scores = (logits[:, 1] - logits[:, 0])[:len(residues)]

    result_str = "\n".join(
        f"{res.get_resname()} {res.id[1]} {letter} {score:.2f}"
        for res, letter, score in zip(residues, letters, scores)
    )

    # Name the output after the downloaded file rather than the raw user
    # input, so the text typed into the UI cannot steer the write location.
    stem = os.path.splitext(os.path.basename(pdb_path))[0]
    predictions_path = f"{stem}_predictions.txt"
    with open(predictions_path, "w") as f:
        f.write(result_str)

    return result_str, pdb_path, predictions_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ # Gradio UI
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  with gr.Blocks() as demo:
86
  gr.Markdown("# Protein Binding Site Prediction")
87
 
88
  with gr.Row():
89
+ pdb_input = gr.Textbox(label="PDB ID")
90
+ segment_input = gr.Textbox(label="Segment (Chain ID)")
91
+ visualize_btn = gr.Button("Visualize")
92
+ prediction_btn = gr.Button("Predict")
93
+
94
+ molecule_output = Molecule3D(label="Protein Structure", reps=reps)
95
+ predictions_output = gr.Textbox(label="Binding Site Predictions")
96
+ download_output = gr.File(label="Download Predictions")
97
+
98
+ visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)
99
+ prediction_btn.click(
100
+ process_pdb,
101
+ inputs=[pdb_input, segment_input],
102
+ outputs=[predictions_output, molecule_output, download_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  )
104
 
105
+ demo.launch(share=True)