ThorbenFroehlking commited on
Commit
17ad0e5
·
1 Parent(s): 5dbe94b
Files changed (3) hide show
  1. .ipynb_checkpoints/app-checkpoint.py +158 -144
  2. app-Copy1.py +89 -22
  3. app.py +158 -144
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -27,10 +27,7 @@ from datasets import Dataset
27
 
28
  from scipy.special import expit
29
 
30
-
31
  # Load model and move to device
32
- #checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
33
- #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_cryptic'
34
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50_database'
35
  max_length = 1500
36
  model, tokenizer = load_model(checkpoint, max_length)
@@ -45,45 +42,32 @@ def normalize_scores(scores):
45
 
46
  def read_mol(pdb_path):
47
  """Read PDB file and return its content as a string"""
48
- try:
49
- with open(pdb_path, 'r') as f:
50
- return f.read()
51
- except FileNotFoundError:
52
- print(f"File not found: {pdb_path}")
53
- raise
54
- except Exception as e:
55
- print(f"Error reading file {pdb_path}: {str(e)}")
56
- raise
57
-
58
- def fetch_structure(pdb_id: str, output_dir: str = ".") -> Optional[str]:
59
  """
60
  Fetch the structure file for a given PDB ID. Prioritizes CIF files.
61
  If a structure file already exists locally, it uses that.
62
  """
63
  file_path = download_structure(pdb_id, output_dir)
64
- if file_path:
65
- return file_path
66
- else:
67
- return None
68
 
69
- def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:
70
  """
71
  Attempt to download the structure file in CIF or PDB format.
72
- Returns the path to the downloaded file, or None if download fails.
73
  """
74
  for ext in ['.cif', '.pdb']:
75
  file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
76
  if os.path.exists(file_path):
77
  return file_path
78
  url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
79
- try:
80
- response = requests.get(url, timeout=10)
81
- if response.status_code == 200:
82
- with open(file_path, 'wb') as f:
83
- f.write(response.content)
84
- return file_path
85
- except Exception as e:
86
- print(f"Download error for {pdb_id}{ext}: {e}")
87
  return None
88
 
89
  def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
@@ -100,8 +84,6 @@ def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
100
 
101
  def fetch_pdb(pdb_id):
102
  pdb_path = fetch_structure(pdb_id)
103
- if not pdb_path:
104
- return None
105
  _, ext = os.path.splitext(pdb_path)
106
  if ext == '.cif':
107
  pdb_path = convert_cif_to_pdb(pdb_path)
@@ -111,11 +93,9 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
111
  """
112
  Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores
113
  """
114
- # Read the original PDB file
115
  parser = PDBParser(QUIET=True)
116
  structure = parser.get_structure('protein', input_pdb)
117
 
118
- # Prepare a new structure with only the specified chain and selected residues
119
  output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"
120
 
121
  # Create scores dictionary for easy lookup
@@ -148,8 +128,57 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
148
 
149
  return output_pdb
150
 
151
- def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
 
 
 
 
 
 
 
 
 
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  # Determine if input is a PDB ID or file path
154
  if pdb_id_or_file.endswith('.pdb'):
155
  pdb_path = pdb_id_or_file
@@ -158,51 +187,37 @@ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
158
  pdb_id = pdb_id_or_file
159
  pdb_path = fetch_pdb(pdb_id)
160
 
161
- if not pdb_path:
162
- return "Failed to fetch PDB file", None, None
163
-
164
  # Determine the file format and choose the appropriate parser
165
  _, ext = os.path.splitext(pdb_path)
166
  parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
167
 
168
- try:
169
- # Parse the structure file
170
- structure = parser.get_structure('protein', pdb_path)
171
- except Exception as e:
172
- return f"Error parsing structure file: {e}", None, None
173
 
174
  # Extract the specified chain
175
- try:
176
- chain = structure[0][segment]
177
- except KeyError:
178
- return "Invalid Chain ID", None, None
179
 
180
  protein_residues = [res for res in chain if is_aa(res)]
181
  sequence = "".join(seq1(res.resname) for res in protein_residues)
182
  sequence_id = [res.id[1] for res in protein_residues]
183
 
184
- visualized_sequence = "".join(seq1(res.resname) for res in protein_residues)
185
- if sequence != visualized_sequence:
186
- raise ValueError("The visualized sequence does not match the prediction sequence")
187
-
188
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
189
  with torch.no_grad():
190
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
191
 
192
  # Calculate scores and normalize them
193
- scores = expit(outputs[:, 1] - outputs[:, 0])
194
- normalized_scores = normalize_scores(scores)
195
 
196
  # Choose which scores to use based on score_type
197
- display_scores = normalized_scores if score_type == 'normalized' else scores
198
 
199
  # Zip residues with scores to track the residue ID and score
200
  residue_scores = [(resi, score) for resi, score in zip(sequence_id, display_scores)]
201
 
202
  # Also save both score types for later use
203
- raw_residue_scores = [(resi, score) for resi, score in zip(sequence_id, scores)]
204
  norm_residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]
205
-
206
 
207
  # Define the score brackets
208
  score_brackets = {
@@ -223,79 +238,35 @@ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
223
  residues_by_bracket[bracket].append(resi)
224
  break
225
 
226
- # Preparing the result string
227
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
228
- result_str = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\n\n"
229
- result_str += "Residues by Score Brackets:\n\n"
230
 
231
- # Add residues for each bracket
232
- for bracket, residues in residues_by_bracket.items():
233
- result_str += f"Bracket {bracket}:\n"
234
- result_str += "Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\n"
235
- result_str += "\n".join([
236
- f"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}"
237
- for i, res in enumerate(protein_residues) if res.id[1] in residues
238
- ])
239
- result_str += "\n\n"
240
-
241
  # Create chain-specific PDB with scores in B-factor
242
  scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)
243
 
244
  # Molecule visualization with updated script with color mapping
245
- mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)
246
-
247
- # Improved PyMOL command suggestions
248
- current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
249
- pymol_commands = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\n\n"
250
-
251
- pymol_commands += f"""
252
- # PyMOL Visualization Commands
253
- fetch {pdb_id}, protein
254
- hide everything, all
255
- show cartoon, chain {segment}
256
- color white, chain {segment}
257
- """
258
 
259
- # Define colors for each score bracket
260
- bracket_colors = {
261
- "0.0-0.2": "white",
262
- "0.2-0.4": "lightorange",
263
- "0.4-0.6": "orange",
264
- "0.6-0.8": "red",
265
- "0.8-1.0": "firebrick"
266
- }
267
-
268
- # Add PyMOL commands for each score bracket
269
- for bracket, residues in residues_by_bracket.items():
270
- if residues: # Only add commands if there are residues in this bracket
271
- color = bracket_colors[bracket]
272
- resi_list = '+'.join(map(str, residues))
273
- pymol_commands += f"""
274
- select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}
275
- show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}
276
- color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}
277
- """
278
- # Create prediction and scored PDB files
279
- prediction_file = f"{pdb_id}_binding_site_residues.txt"
280
  with open(prediction_file, "w") as f:
281
  f.write(result_str)
282
 
283
- return pymol_commands, mol_vis, [prediction_file, scored_pdb],raw_residue_scores,norm_residue_scores
284
-
 
 
285
 
286
  def molecule(input_pdb, residue_scores=None, segment='A'):
287
- # Check if the file exists
288
- if not os.path.isfile(input_pdb):
289
- return f"<p>Error: PDB file not found at {input_pdb}</p>"
290
-
291
- try:
292
- # Read PDB file content
293
- mol = read_mol(input_pdb)
294
- except Exception as e:
295
- return f"<p>Error reading PDB file: {str(e)}</p>"
296
- # More granular scoring for visualization
297
- #mol = read_mol(input_pdb) # Read PDB file content
298
-
299
  # Prepare high-scoring residues script if scores are provided
300
  high_score_script = ""
301
  if residue_scores is not None:
@@ -491,9 +462,9 @@ with gr.Blocks(css="""
491
  Score dependent colorcoding:
492
  - 0.0-0.2: white
493
  - 0.2–0.4: light orange
494
- - 0.4–0.6: orange
495
- - 0.6–0.8: red
496
- - 0.8–1.0: firebrick
497
  """)
498
  predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
499
  gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
@@ -504,6 +475,7 @@ with gr.Blocks(css="""
504
  norm_scores_state = gr.State(None)
505
  last_pdb_path = gr.State(None)
506
  last_segment = gr.State(None)
 
507
 
508
  def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
509
  selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
@@ -511,12 +483,10 @@ with gr.Blocks(css="""
511
  # First get the actual PDB file path
512
  if mode == "PDB ID":
513
  pdb_path = fetch_pdb(pdb_id) # Get the actual file path
514
- if not pdb_path:
515
- return "Failed to fetch PDB file", None, None, None, None, None, None
516
 
517
- pymol_cmd, mol_vis, files, raw_scores, norm_scores = process_pdb(pdb_path, chain_id, selected_score_type)
518
  # Store the actual file path, not just the PDB ID
519
- return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id
520
  elif mode == "Upload File":
521
  _, ext = os.path.splitext(pdb_file.name)
522
  file_path = os.path.join('./', f"{_}{ext}")
@@ -525,24 +495,70 @@ with gr.Blocks(css="""
525
  else:
526
  pdb_path = file_path
527
 
528
- pymol_cmd, mol_vis, files, raw_scores, norm_scores = process_pdb(pdb_path, chain_id, selected_score_type)
529
- return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id
530
- else:
531
- return "Error: Invalid mode selected", None, None, None, None, None, None
532
-
533
- def update_visualization(score_type_val, raw_scores, norm_scores, pdb_path, segment):
534
- if raw_scores is None or norm_scores is None or pdb_path is None or segment is None:
535
- return None
536
-
537
- # Verify the file exists
538
- if not os.path.exists(pdb_path):
539
- return f"Error: File not found at {pdb_path}"
540
 
541
  # Choose scores based on radio button selection
542
- selected_scores = norm_scores if score_type_val == "Normalized Scores" else raw_scores
 
543
 
544
  # Generate visualization with selected scores
545
- return molecule(pdb_path, selected_scores, segment)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
 
547
  def fetch_interface(mode, pdb_id, pdb_file):
548
  if mode == "PDB ID":
@@ -555,8 +571,6 @@ with gr.Blocks(css="""
555
  else:
556
  pdb_path= file_path
557
  return pdb_path
558
- else:
559
- return "Error: Invalid mode selected"
560
 
561
  def toggle_mode(selected_mode):
562
  if selected_mode == "PDB ID":
@@ -574,14 +588,14 @@ with gr.Blocks(css="""
574
  process_interface,
575
  inputs=[mode, pdb_input, pdb_file, segment_input, score_type],
576
  outputs=[predictions_output, molecule_output, download_output,
577
- raw_scores_state, norm_scores_state, last_pdb_path, last_segment]
578
  )
579
 
580
- # Update visualization when score type changes
581
  score_type.change(
582
- update_visualization,
583
- inputs=[score_type, raw_scores_state, norm_scores_state, last_pdb_path, last_segment],
584
- outputs=[molecule_output]
585
  )
586
 
587
  visualize_btn.click(
@@ -600,4 +614,4 @@ with gr.Blocks(css="""
600
  inputs=[pdb_input, segment_input],
601
  outputs=[predictions_output, molecule_output, download_output]
602
  )
603
- demo.launch(share=True)
 
27
 
28
  from scipy.special import expit
29
 
 
30
  # Load model and move to device
 
 
31
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50_database'
32
  max_length = 1500
33
  model, tokenizer = load_model(checkpoint, max_length)
 
42
 
43
  def read_mol(pdb_path):
44
  """Read PDB file and return its content as a string"""
45
+ with open(pdb_path, 'r') as f:
46
+ return f.read()
47
+
48
+ def fetch_structure(pdb_id: str, output_dir: str = ".") -> str:
 
 
 
 
 
 
 
49
  """
50
  Fetch the structure file for a given PDB ID. Prioritizes CIF files.
51
  If a structure file already exists locally, it uses that.
52
  """
53
  file_path = download_structure(pdb_id, output_dir)
54
+ return file_path
 
 
 
55
 
56
+ def download_structure(pdb_id: str, output_dir: str) -> str:
57
  """
58
  Attempt to download the structure file in CIF or PDB format.
59
+ Returns the path to the downloaded file.
60
  """
61
  for ext in ['.cif', '.pdb']:
62
  file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
63
  if os.path.exists(file_path):
64
  return file_path
65
  url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
66
+ response = requests.get(url, timeout=10)
67
+ if response.status_code == 200:
68
+ with open(file_path, 'wb') as f:
69
+ f.write(response.content)
70
+ return file_path
 
 
 
71
  return None
72
 
73
  def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
 
84
 
85
  def fetch_pdb(pdb_id):
86
  pdb_path = fetch_structure(pdb_id)
 
 
87
  _, ext = os.path.splitext(pdb_path)
88
  if ext == '.cif':
89
  pdb_path = convert_cif_to_pdb(pdb_path)
 
93
  """
94
  Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores
95
  """
 
96
  parser = PDBParser(QUIET=True)
97
  structure = parser.get_structure('protein', input_pdb)
98
 
 
99
  output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"
100
 
101
  # Create scores dictionary for easy lookup
 
128
 
129
  return output_pdb
130
 
131
+ def generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, score_type):
132
+ """Generate PyMOL commands based on score type"""
133
+ pymol_commands = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n"
134
+
135
+ pymol_commands += f"""
136
+ # PyMOL Visualization Commands
137
+ fetch {pdb_id}, protein
138
+ hide everything, all
139
+ show cartoon, chain {segment}
140
+ color white, chain {segment}
141
+ """
142
 
143
+ # Define colors for each score bracket
144
+ bracket_colors = {
145
+ "0.0-0.2": "white",
146
+ "0.2-0.4": "lightorange",
147
+ "0.4-0.6": "yelloworange",
148
+ "0.6-0.8": "orange",
149
+ "0.8-1.0": "red"
150
+ }
151
+
152
+ # Add PyMOL commands for each score bracket
153
+ for bracket, residues in residues_by_bracket.items():
154
+ if residues: # Only add commands if there are residues in this bracket
155
+ color = bracket_colors[bracket]
156
+ resi_list = '+'.join(map(str, residues))
157
+ pymol_commands += f"""
158
+ select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}
159
+ show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}
160
+ color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}
161
+ """
162
+ return pymol_commands
163
+
164
+ def generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence, scores, current_time, score_type):
165
+ """Generate results text based on score type"""
166
+ result_str = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n"
167
+ result_str += "Residues by Score Brackets:\n\n"
168
+
169
+ # Add residues for each bracket
170
+ for bracket, residues in residues_by_bracket.items():
171
+ result_str += f"Bracket {bracket}:\n"
172
+ result_str += f"Columns: Residue Name, Residue Number, One-letter Code, {score_type} Score\n"
173
+ result_str += "\n".join([
174
+ f"{res.resname} {res.id[1]} {sequence[i]} {scores[i]:.2f}"
175
+ for i, res in enumerate(protein_residues) if res.id[1] in residues
176
+ ])
177
+ result_str += "\n\n"
178
+
179
+ return result_str
180
+
181
+ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
182
  # Determine if input is a PDB ID or file path
183
  if pdb_id_or_file.endswith('.pdb'):
184
  pdb_path = pdb_id_or_file
 
187
  pdb_id = pdb_id_or_file
188
  pdb_path = fetch_pdb(pdb_id)
189
 
 
 
 
190
  # Determine the file format and choose the appropriate parser
191
  _, ext = os.path.splitext(pdb_path)
192
  parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
193
 
194
+ # Parse the structure file
195
+ structure = parser.get_structure('protein', pdb_path)
 
 
 
196
 
197
  # Extract the specified chain
198
+ chain = structure[0][segment]
 
 
 
199
 
200
  protein_residues = [res for res in chain if is_aa(res)]
201
  sequence = "".join(seq1(res.resname) for res in protein_residues)
202
  sequence_id = [res.id[1] for res in protein_residues]
203
 
 
 
 
 
204
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
205
  with torch.no_grad():
206
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
207
 
208
  # Calculate scores and normalize them
209
+ raw_scores = expit(outputs[:, 1] - outputs[:, 0])
210
+ normalized_scores = normalize_scores(raw_scores)
211
 
212
  # Choose which scores to use based on score_type
213
+ display_scores = normalized_scores if score_type == 'normalized' else raw_scores
214
 
215
  # Zip residues with scores to track the residue ID and score
216
  residue_scores = [(resi, score) for resi, score in zip(sequence_id, display_scores)]
217
 
218
  # Also save both score types for later use
219
+ raw_residue_scores = [(resi, score) for resi, score in zip(sequence_id, raw_scores)]
220
  norm_residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]
 
221
 
222
  # Define the score brackets
223
  score_brackets = {
 
238
  residues_by_bracket[bracket].append(resi)
239
  break
240
 
241
+ # Generate timestamp
242
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
 
243
 
244
+ # Generate result text and PyMOL commands based on score type
245
+ display_score_type = "Normalized" if score_type == 'normalized' else "Raw"
246
+ result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
247
+ display_scores, current_time, display_score_type)
248
+ pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)
249
+
 
 
 
 
250
  # Create chain-specific PDB with scores in B-factor
251
  scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)
252
 
253
  # Molecule visualization with updated script with color mapping
254
+ mol_vis = molecule(pdb_path, residue_scores, segment)
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
+ # Create prediction file
257
+ prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  with open(prediction_file, "w") as f:
259
  f.write(result_str)
260
 
261
+ scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
262
+ os.rename(scored_pdb, scored_pdb_name)
263
+
264
+ return pymol_commands, mol_vis, [prediction_file, scored_pdb_name], raw_residue_scores, norm_residue_scores, pdb_id, segment
265
 
266
  def molecule(input_pdb, residue_scores=None, segment='A'):
267
+ # Read PDB file content
268
+ mol = read_mol(input_pdb)
269
+
 
 
 
 
 
 
 
 
 
270
  # Prepare high-scoring residues script if scores are provided
271
  high_score_script = ""
272
  if residue_scores is not None:
 
462
  Score dependent colorcoding:
463
  - 0.0-0.2: white
464
  - 0.2–0.4: light orange
465
+ - 0.4–0.6: yellow orange
466
+ - 0.6–0.8: orange
467
+ - 0.8–1.0: red
468
  """)
469
  predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
470
  gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
 
475
  norm_scores_state = gr.State(None)
476
  last_pdb_path = gr.State(None)
477
  last_segment = gr.State(None)
478
+ last_pdb_id = gr.State(None)
479
 
480
  def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
481
  selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
 
483
  # First get the actual PDB file path
484
  if mode == "PDB ID":
485
  pdb_path = fetch_pdb(pdb_id) # Get the actual file path
 
 
486
 
487
+ pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, segment = process_pdb(pdb_path, chain_id, selected_score_type)
488
  # Store the actual file path, not just the PDB ID
489
+ return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
490
  elif mode == "Upload File":
491
  _, ext = os.path.splitext(pdb_file.name)
492
  file_path = os.path.join('./', f"{_}{ext}")
 
495
  else:
496
  pdb_path = file_path
497
 
498
+ pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, segment = process_pdb(pdb_path, chain_id, selected_score_type)
499
+ return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
500
+
501
+ def update_visualization_and_files(score_type_val, raw_scores, norm_scores, pdb_path, segment, pdb_id):
502
+ if raw_scores is None or norm_scores is None or pdb_path is None or segment is None or pdb_id is None:
503
+ return None, None, None
 
 
 
 
 
 
504
 
505
  # Choose scores based on radio button selection
506
+ selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
507
+ selected_scores = norm_scores if selected_score_type == 'normalized' else raw_scores
508
 
509
  # Generate visualization with selected scores
510
+ mol_vis = molecule(pdb_path, selected_scores, segment)
511
+
512
+ # Generate PyMOL commands and downloadable files
513
+ # Get structure for residue info
514
+ _, ext = os.path.splitext(pdb_path)
515
+ parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
516
+ structure = parser.get_structure('protein', pdb_path)
517
+ chain = structure[0][segment]
518
+ protein_residues = [res for res in chain if is_aa(res)]
519
+ sequence = "".join(seq1(res.resname) for res in protein_residues)
520
+
521
+ # Define score brackets
522
+ score_brackets = {
523
+ "0.0-0.2": (0.0, 0.2),
524
+ "0.2-0.4": (0.2, 0.4),
525
+ "0.4-0.6": (0.4, 0.6),
526
+ "0.6-0.8": (0.6, 0.8),
527
+ "0.8-1.0": (0.8, 1.0)
528
+ }
529
+
530
+ # Initialize a dictionary to store residues by bracket
531
+ residues_by_bracket = {bracket: [] for bracket in score_brackets}
532
+
533
+ # Categorize residues into brackets
534
+ for resi, score in selected_scores:
535
+ for bracket, (lower, upper) in score_brackets.items():
536
+ if lower <= score < upper:
537
+ residues_by_bracket[bracket].append(resi)
538
+ break
539
+
540
+ # Generate timestamp
541
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
542
+
543
+ # Generate result text and PyMOL commands based on score type
544
+ display_score_type = "Normalized" if selected_score_type == 'normalized' else "Raw"
545
+ scores_array = [score for _, score in selected_scores]
546
+ result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
547
+ scores_array, current_time, display_score_type)
548
+ pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)
549
+
550
+ # Create chain-specific PDB with scores in B-factor
551
+ scored_pdb = create_chain_specific_pdb(pdb_path, segment, selected_scores, protein_residues)
552
+
553
+ # Create prediction file
554
+ prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
555
+ with open(prediction_file, "w") as f:
556
+ f.write(result_str)
557
+
558
+ scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
559
+ os.rename(scored_pdb, scored_pdb_name)
560
+
561
+ return mol_vis, pymol_commands, [prediction_file, scored_pdb_name]
562
 
563
  def fetch_interface(mode, pdb_id, pdb_file):
564
  if mode == "PDB ID":
 
571
  else:
572
  pdb_path= file_path
573
  return pdb_path
 
 
574
 
575
  def toggle_mode(selected_mode):
576
  if selected_mode == "PDB ID":
 
588
  process_interface,
589
  inputs=[mode, pdb_input, pdb_file, segment_input, score_type],
590
  outputs=[predictions_output, molecule_output, download_output,
591
+ raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id]
592
  )
593
 
594
+ # Update visualization, PyMOL commands, and files when score type changes
595
  score_type.change(
596
+ update_visualization_and_files,
597
+ inputs=[score_type, raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id],
598
+ outputs=[molecule_output, predictions_output, download_output]
599
  )
600
 
601
  visualize_btn.click(
 
614
  inputs=[pdb_input, segment_input],
615
  outputs=[predictions_output, molecule_output, download_output]
616
  )
617
+ demo.launch(share=True)
app-Copy1.py CHANGED
@@ -45,8 +45,15 @@ def normalize_scores(scores):
45
 
46
  def read_mol(pdb_path):
47
  """Read PDB file and return its content as a string"""
48
- with open(pdb_path, 'r') as f:
49
- return f.read()
 
 
 
 
 
 
 
50
 
51
  def fetch_structure(pdb_id: str, output_dir: str = ".") -> Optional[str]:
52
  """
@@ -141,7 +148,8 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
141
 
142
  return output_pdb
143
 
144
- def process_pdb(pdb_id_or_file, segment):
 
145
  # Determine if input is a PDB ID or file path
146
  if pdb_id_or_file.endswith('.pdb'):
147
  pdb_path = pdb_id_or_file
@@ -180,14 +188,20 @@ def process_pdb(pdb_id_or_file, segment):
180
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
181
  with torch.no_grad():
182
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
183
-
184
  # Calculate scores and normalize them
185
  scores = expit(outputs[:, 1] - outputs[:, 0])
186
-
187
  normalized_scores = normalize_scores(scores)
188
 
 
 
 
189
  # Zip residues with scores to track the residue ID and score
190
- residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]
 
 
 
 
191
 
192
 
193
  # Define the score brackets
@@ -236,7 +250,7 @@ def process_pdb(pdb_id_or_file, segment):
236
 
237
  pymol_commands += f"""
238
  # PyMOL Visualization Commands
239
- load {os.path.abspath(pdb_path)}, protein
240
  hide everything, all
241
  show cartoon, chain {segment}
242
  color white, chain {segment}
@@ -266,11 +280,21 @@ def process_pdb(pdb_id_or_file, segment):
266
  with open(prediction_file, "w") as f:
267
  f.write(result_str)
268
 
269
- return pymol_commands, mol_vis, [prediction_file,scored_pdb]
 
270
 
271
  def molecule(input_pdb, residue_scores=None, segment='A'):
 
 
 
 
 
 
 
 
 
272
  # More granular scoring for visualization
273
- mol = read_mol(input_pdb) # Read PDB file content
274
 
275
  # Prepare high-scoring residues script if scores are provided
276
  high_score_script = ""
@@ -410,7 +434,6 @@ def molecule(input_pdb, residue_scores=None, segment='A'):
410
  # Return the HTML content within an iframe safely encoded for special characters
411
  return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
412
 
413
- # Gradio UI
414
  with gr.Blocks(css="""
415
  /* Customize Gradio button colors */
416
  #visualize-btn, #predict-btn {
@@ -455,32 +478,71 @@ with gr.Blocks(css="""
455
  info="Choose in which chain to predict binding sites.")
456
  prediction_btn = gr.Button("Predict Binding Site", elem_id="predict-btn")
457
 
 
 
 
 
 
 
 
 
458
  molecule_output = gr.HTML(label="Protein Structure")
459
  explanation_vis = gr.Markdown("""
460
  Score dependent colorcoding:
461
  - 0.0-0.2: white
462
  - 0.2–0.4: light orange
463
  - 0.4–0.6: orange
464
- - 0.6–0.8: orangered
465
- - 0.8–1.0: red
466
  """)
467
  predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
468
  gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
469
  download_output = gr.File(label="Download Files", file_count="multiple")
470
 
471
- def process_interface(mode, pdb_id, pdb_file, chain_id):
 
 
 
 
 
 
 
 
 
472
  if mode == "PDB ID":
473
- return process_pdb(pdb_id, chain_id)
 
 
 
 
 
 
474
  elif mode == "Upload File":
475
  _, ext = os.path.splitext(pdb_file.name)
476
  file_path = os.path.join('./', f"{_}{ext}")
477
  if ext == '.cif':
478
  pdb_path = convert_cif_to_pdb(file_path)
479
  else:
480
- pdb_path= file_path
481
- return process_pdb(pdb_path, chain_id)
 
 
482
  else:
483
- return "Error: Invalid mode selected", None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
 
485
  def fetch_interface(mode, pdb_id, pdb_file):
486
  if mode == "PDB ID":
@@ -488,12 +550,10 @@ with gr.Blocks(css="""
488
  elif mode == "Upload File":
489
  _, ext = os.path.splitext(pdb_file.name)
490
  file_path = os.path.join('./', f"{_}{ext}")
491
- #print(ext)
492
  if ext == '.cif':
493
  pdb_path = convert_cif_to_pdb(file_path)
494
  else:
495
  pdb_path= file_path
496
- #print(pdb_path)
497
  return pdb_path
498
  else:
499
  return "Error: Invalid mode selected"
@@ -512,8 +572,16 @@ with gr.Blocks(css="""
512
 
513
  prediction_btn.click(
514
  process_interface,
515
- inputs=[mode, pdb_input, pdb_file, segment_input],
516
- outputs=[predictions_output, molecule_output, download_output]
 
 
 
 
 
 
 
 
517
  )
518
 
519
  visualize_btn.click(
@@ -532,5 +600,4 @@ with gr.Blocks(css="""
532
  inputs=[pdb_input, segment_input],
533
  outputs=[predictions_output, molecule_output, download_output]
534
  )
535
-
536
  demo.launch(share=True)
 
45
 
46
  def read_mol(pdb_path):
47
  """Read PDB file and return its content as a string"""
48
+ try:
49
+ with open(pdb_path, 'r') as f:
50
+ return f.read()
51
+ except FileNotFoundError:
52
+ print(f"File not found: {pdb_path}")
53
+ raise
54
+ except Exception as e:
55
+ print(f"Error reading file {pdb_path}: {str(e)}")
56
+ raise
57
 
58
  def fetch_structure(pdb_id: str, output_dir: str = ".") -> Optional[str]:
59
  """
 
148
 
149
  return output_pdb
150
 
151
+ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
152
+
153
  # Determine if input is a PDB ID or file path
154
  if pdb_id_or_file.endswith('.pdb'):
155
  pdb_path = pdb_id_or_file
 
188
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
189
  with torch.no_grad():
190
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
191
+
192
  # Calculate scores and normalize them
193
  scores = expit(outputs[:, 1] - outputs[:, 0])
 
194
  normalized_scores = normalize_scores(scores)
195
 
196
+ # Choose which scores to use based on score_type
197
+ display_scores = normalized_scores if score_type == 'normalized' else scores
198
+
199
  # Zip residues with scores to track the residue ID and score
200
+ residue_scores = [(resi, score) for resi, score in zip(sequence_id, display_scores)]
201
+
202
+ # Also save both score types for later use
203
+ raw_residue_scores = [(resi, score) for resi, score in zip(sequence_id, scores)]
204
+ norm_residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]
205
 
206
 
207
  # Define the score brackets
 
250
 
251
  pymol_commands += f"""
252
  # PyMOL Visualization Commands
253
+ fetch {pdb_id}, protein
254
  hide everything, all
255
  show cartoon, chain {segment}
256
  color white, chain {segment}
 
280
  with open(prediction_file, "w") as f:
281
  f.write(result_str)
282
 
283
+ return pymol_commands, mol_vis, [prediction_file, scored_pdb],raw_residue_scores,norm_residue_scores
284
+
285
 
286
  def molecule(input_pdb, residue_scores=None, segment='A'):
287
+ # Check if the file exists
288
+ if not os.path.isfile(input_pdb):
289
+ return f"<p>Error: PDB file not found at {input_pdb}</p>"
290
+
291
+ try:
292
+ # Read PDB file content
293
+ mol = read_mol(input_pdb)
294
+ except Exception as e:
295
+ return f"<p>Error reading PDB file: {str(e)}</p>"
296
  # More granular scoring for visualization
297
+ #mol = read_mol(input_pdb) # Read PDB file content
298
 
299
  # Prepare high-scoring residues script if scores are provided
300
  high_score_script = ""
 
434
  # Return the HTML content within an iframe safely encoded for special characters
435
  return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
436
 
 
437
  with gr.Blocks(css="""
438
  /* Customize Gradio button colors */
439
  #visualize-btn, #predict-btn {
 
478
  info="Choose in which chain to predict binding sites.")
479
  prediction_btn = gr.Button("Predict Binding Site", elem_id="predict-btn")
480
 
481
+ # Add score type selector
482
+ score_type = gr.Radio(
483
+ choices=["Normalized Scores", "Raw Scores"],
484
+ value="Normalized Scores",
485
+ label="Score Visualization Type",
486
+ info="Choose which score type to visualize"
487
+ )
488
+
489
  molecule_output = gr.HTML(label="Protein Structure")
490
  explanation_vis = gr.Markdown("""
491
  Score dependent colorcoding:
492
  - 0.0-0.2: white
493
  - 0.2–0.4: light orange
494
  - 0.4–0.6: orange
495
+ - 0.6–0.8: red
496
+ - 0.8–1.0: firebrick
497
  """)
498
  predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
499
  gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
500
  download_output = gr.File(label="Download Files", file_count="multiple")
501
 
502
+ # Store these as state variables so we can switch between them
503
+ raw_scores_state = gr.State(None)
504
+ norm_scores_state = gr.State(None)
505
+ last_pdb_path = gr.State(None)
506
+ last_segment = gr.State(None)
507
+
508
+ def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
509
+ selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
510
+
511
+ # First get the actual PDB file path
512
  if mode == "PDB ID":
513
+ pdb_path = fetch_pdb(pdb_id) # Get the actual file path
514
+ if not pdb_path:
515
+ return "Failed to fetch PDB file", None, None, None, None, None, None
516
+
517
+ pymol_cmd, mol_vis, files, raw_scores, norm_scores = process_pdb(pdb_path, chain_id, selected_score_type)
518
+ # Store the actual file path, not just the PDB ID
519
+ return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id
520
  elif mode == "Upload File":
521
  _, ext = os.path.splitext(pdb_file.name)
522
  file_path = os.path.join('./', f"{_}{ext}")
523
  if ext == '.cif':
524
  pdb_path = convert_cif_to_pdb(file_path)
525
  else:
526
+ pdb_path = file_path
527
+
528
+ pymol_cmd, mol_vis, files, raw_scores, norm_scores = process_pdb(pdb_path, chain_id, selected_score_type)
529
+ return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id
530
  else:
531
+ return "Error: Invalid mode selected", None, None, None, None, None, None
532
+
533
+ def update_visualization(score_type_val, raw_scores, norm_scores, pdb_path, segment):
534
+ if raw_scores is None or norm_scores is None or pdb_path is None or segment is None:
535
+ return None
536
+
537
+ # Verify the file exists
538
+ if not os.path.exists(pdb_path):
539
+ return f"Error: File not found at {pdb_path}"
540
+
541
+ # Choose scores based on radio button selection
542
+ selected_scores = norm_scores if score_type_val == "Normalized Scores" else raw_scores
543
+
544
+ # Generate visualization with selected scores
545
+ return molecule(pdb_path, selected_scores, segment)
546
 
547
  def fetch_interface(mode, pdb_id, pdb_file):
548
  if mode == "PDB ID":
 
550
  elif mode == "Upload File":
551
  _, ext = os.path.splitext(pdb_file.name)
552
  file_path = os.path.join('./', f"{_}{ext}")
 
553
  if ext == '.cif':
554
  pdb_path = convert_cif_to_pdb(file_path)
555
  else:
556
  pdb_path= file_path
 
557
  return pdb_path
558
  else:
559
  return "Error: Invalid mode selected"
 
572
 
573
  prediction_btn.click(
574
  process_interface,
575
+ inputs=[mode, pdb_input, pdb_file, segment_input, score_type],
576
+ outputs=[predictions_output, molecule_output, download_output,
577
+ raw_scores_state, norm_scores_state, last_pdb_path, last_segment]
578
+ )
579
+
580
+ # Update visualization when score type changes
581
+ score_type.change(
582
+ update_visualization,
583
+ inputs=[score_type, raw_scores_state, norm_scores_state, last_pdb_path, last_segment],
584
+ outputs=[molecule_output]
585
  )
586
 
587
  visualize_btn.click(
 
600
  inputs=[pdb_input, segment_input],
601
  outputs=[predictions_output, molecule_output, download_output]
602
  )
 
603
  demo.launch(share=True)
app.py CHANGED
@@ -27,10 +27,7 @@ from datasets import Dataset
27
 
28
  from scipy.special import expit
29
 
30
-
31
  # Load model and move to device
32
- #checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
33
- #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_cryptic'
34
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50_database'
35
  max_length = 1500
36
  model, tokenizer = load_model(checkpoint, max_length)
@@ -45,45 +42,32 @@ def normalize_scores(scores):
45
 
46
  def read_mol(pdb_path):
47
  """Read PDB file and return its content as a string"""
48
- try:
49
- with open(pdb_path, 'r') as f:
50
- return f.read()
51
- except FileNotFoundError:
52
- print(f"File not found: {pdb_path}")
53
- raise
54
- except Exception as e:
55
- print(f"Error reading file {pdb_path}: {str(e)}")
56
- raise
57
-
58
- def fetch_structure(pdb_id: str, output_dir: str = ".") -> Optional[str]:
59
  """
60
  Fetch the structure file for a given PDB ID. Prioritizes CIF files.
61
  If a structure file already exists locally, it uses that.
62
  """
63
  file_path = download_structure(pdb_id, output_dir)
64
- if file_path:
65
- return file_path
66
- else:
67
- return None
68
 
69
- def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:
70
  """
71
  Attempt to download the structure file in CIF or PDB format.
72
- Returns the path to the downloaded file, or None if download fails.
73
  """
74
  for ext in ['.cif', '.pdb']:
75
  file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
76
  if os.path.exists(file_path):
77
  return file_path
78
  url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
79
- try:
80
- response = requests.get(url, timeout=10)
81
- if response.status_code == 200:
82
- with open(file_path, 'wb') as f:
83
- f.write(response.content)
84
- return file_path
85
- except Exception as e:
86
- print(f"Download error for {pdb_id}{ext}: {e}")
87
  return None
88
 
89
  def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
@@ -100,8 +84,6 @@ def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
100
 
101
  def fetch_pdb(pdb_id):
102
  pdb_path = fetch_structure(pdb_id)
103
- if not pdb_path:
104
- return None
105
  _, ext = os.path.splitext(pdb_path)
106
  if ext == '.cif':
107
  pdb_path = convert_cif_to_pdb(pdb_path)
@@ -111,11 +93,9 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
111
  """
112
  Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores
113
  """
114
- # Read the original PDB file
115
  parser = PDBParser(QUIET=True)
116
  structure = parser.get_structure('protein', input_pdb)
117
 
118
- # Prepare a new structure with only the specified chain and selected residues
119
  output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"
120
 
121
  # Create scores dictionary for easy lookup
@@ -148,8 +128,57 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
148
 
149
  return output_pdb
150
 
151
- def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
 
 
 
 
 
 
 
 
 
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  # Determine if input is a PDB ID or file path
154
  if pdb_id_or_file.endswith('.pdb'):
155
  pdb_path = pdb_id_or_file
@@ -158,51 +187,37 @@ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
158
  pdb_id = pdb_id_or_file
159
  pdb_path = fetch_pdb(pdb_id)
160
 
161
- if not pdb_path:
162
- return "Failed to fetch PDB file", None, None
163
-
164
  # Determine the file format and choose the appropriate parser
165
  _, ext = os.path.splitext(pdb_path)
166
  parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
167
 
168
- try:
169
- # Parse the structure file
170
- structure = parser.get_structure('protein', pdb_path)
171
- except Exception as e:
172
- return f"Error parsing structure file: {e}", None, None
173
 
174
  # Extract the specified chain
175
- try:
176
- chain = structure[0][segment]
177
- except KeyError:
178
- return "Invalid Chain ID", None, None
179
 
180
  protein_residues = [res for res in chain if is_aa(res)]
181
  sequence = "".join(seq1(res.resname) for res in protein_residues)
182
  sequence_id = [res.id[1] for res in protein_residues]
183
 
184
- visualized_sequence = "".join(seq1(res.resname) for res in protein_residues)
185
- if sequence != visualized_sequence:
186
- raise ValueError("The visualized sequence does not match the prediction sequence")
187
-
188
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
189
  with torch.no_grad():
190
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
191
 
192
  # Calculate scores and normalize them
193
- scores = expit(outputs[:, 1] - outputs[:, 0])
194
- normalized_scores = normalize_scores(scores)
195
 
196
  # Choose which scores to use based on score_type
197
- display_scores = normalized_scores if score_type == 'normalized' else scores
198
 
199
  # Zip residues with scores to track the residue ID and score
200
  residue_scores = [(resi, score) for resi, score in zip(sequence_id, display_scores)]
201
 
202
  # Also save both score types for later use
203
- raw_residue_scores = [(resi, score) for resi, score in zip(sequence_id, scores)]
204
  norm_residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]
205
-
206
 
207
  # Define the score brackets
208
  score_brackets = {
@@ -223,79 +238,35 @@ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
223
  residues_by_bracket[bracket].append(resi)
224
  break
225
 
226
- # Preparing the result string
227
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
228
- result_str = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\n\n"
229
- result_str += "Residues by Score Brackets:\n\n"
230
 
231
- # Add residues for each bracket
232
- for bracket, residues in residues_by_bracket.items():
233
- result_str += f"Bracket {bracket}:\n"
234
- result_str += "Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\n"
235
- result_str += "\n".join([
236
- f"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}"
237
- for i, res in enumerate(protein_residues) if res.id[1] in residues
238
- ])
239
- result_str += "\n\n"
240
-
241
  # Create chain-specific PDB with scores in B-factor
242
  scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)
243
 
244
  # Molecule visualization with updated script with color mapping
245
- mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)
246
-
247
- # Improved PyMOL command suggestions
248
- current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
249
- pymol_commands = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\n\n"
250
-
251
- pymol_commands += f"""
252
- # PyMOL Visualization Commands
253
- fetch {pdb_id}, protein
254
- hide everything, all
255
- show cartoon, chain {segment}
256
- color white, chain {segment}
257
- """
258
 
259
- # Define colors for each score bracket
260
- bracket_colors = {
261
- "0.0-0.2": "white",
262
- "0.2-0.4": "lightorange",
263
- "0.4-0.6": "orange",
264
- "0.6-0.8": "red",
265
- "0.8-1.0": "firebrick"
266
- }
267
-
268
- # Add PyMOL commands for each score bracket
269
- for bracket, residues in residues_by_bracket.items():
270
- if residues: # Only add commands if there are residues in this bracket
271
- color = bracket_colors[bracket]
272
- resi_list = '+'.join(map(str, residues))
273
- pymol_commands += f"""
274
- select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}
275
- show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}
276
- color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}
277
- """
278
- # Create prediction and scored PDB files
279
- prediction_file = f"{pdb_id}_binding_site_residues.txt"
280
  with open(prediction_file, "w") as f:
281
  f.write(result_str)
282
 
283
- return pymol_commands, mol_vis, [prediction_file, scored_pdb],raw_residue_scores,norm_residue_scores
284
-
 
 
285
 
286
  def molecule(input_pdb, residue_scores=None, segment='A'):
287
- # Check if the file exists
288
- if not os.path.isfile(input_pdb):
289
- return f"<p>Error: PDB file not found at {input_pdb}</p>"
290
-
291
- try:
292
- # Read PDB file content
293
- mol = read_mol(input_pdb)
294
- except Exception as e:
295
- return f"<p>Error reading PDB file: {str(e)}</p>"
296
- # More granular scoring for visualization
297
- #mol = read_mol(input_pdb) # Read PDB file content
298
-
299
  # Prepare high-scoring residues script if scores are provided
300
  high_score_script = ""
301
  if residue_scores is not None:
@@ -491,9 +462,9 @@ with gr.Blocks(css="""
491
  Score dependent colorcoding:
492
  - 0.0-0.2: white
493
  - 0.2–0.4: light orange
494
- - 0.4–0.6: orange
495
- - 0.6–0.8: red
496
- - 0.8–1.0: firebrick
497
  """)
498
  predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
499
  gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
@@ -504,6 +475,7 @@ with gr.Blocks(css="""
504
  norm_scores_state = gr.State(None)
505
  last_pdb_path = gr.State(None)
506
  last_segment = gr.State(None)
 
507
 
508
  def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
509
  selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
@@ -511,12 +483,10 @@ with gr.Blocks(css="""
511
  # First get the actual PDB file path
512
  if mode == "PDB ID":
513
  pdb_path = fetch_pdb(pdb_id) # Get the actual file path
514
- if not pdb_path:
515
- return "Failed to fetch PDB file", None, None, None, None, None, None
516
 
517
- pymol_cmd, mol_vis, files, raw_scores, norm_scores = process_pdb(pdb_path, chain_id, selected_score_type)
518
  # Store the actual file path, not just the PDB ID
519
- return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id
520
  elif mode == "Upload File":
521
  _, ext = os.path.splitext(pdb_file.name)
522
  file_path = os.path.join('./', f"{_}{ext}")
@@ -525,24 +495,70 @@ with gr.Blocks(css="""
525
  else:
526
  pdb_path = file_path
527
 
528
- pymol_cmd, mol_vis, files, raw_scores, norm_scores = process_pdb(pdb_path, chain_id, selected_score_type)
529
- return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id
530
- else:
531
- return "Error: Invalid mode selected", None, None, None, None, None, None
532
-
533
- def update_visualization(score_type_val, raw_scores, norm_scores, pdb_path, segment):
534
- if raw_scores is None or norm_scores is None or pdb_path is None or segment is None:
535
- return None
536
-
537
- # Verify the file exists
538
- if not os.path.exists(pdb_path):
539
- return f"Error: File not found at {pdb_path}"
540
 
541
  # Choose scores based on radio button selection
542
- selected_scores = norm_scores if score_type_val == "Normalized Scores" else raw_scores
 
543
 
544
  # Generate visualization with selected scores
545
- return molecule(pdb_path, selected_scores, segment)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
 
547
  def fetch_interface(mode, pdb_id, pdb_file):
548
  if mode == "PDB ID":
@@ -555,8 +571,6 @@ with gr.Blocks(css="""
555
  else:
556
  pdb_path= file_path
557
  return pdb_path
558
- else:
559
- return "Error: Invalid mode selected"
560
 
561
  def toggle_mode(selected_mode):
562
  if selected_mode == "PDB ID":
@@ -574,14 +588,14 @@ with gr.Blocks(css="""
574
  process_interface,
575
  inputs=[mode, pdb_input, pdb_file, segment_input, score_type],
576
  outputs=[predictions_output, molecule_output, download_output,
577
- raw_scores_state, norm_scores_state, last_pdb_path, last_segment]
578
  )
579
 
580
- # Update visualization when score type changes
581
  score_type.change(
582
- update_visualization,
583
- inputs=[score_type, raw_scores_state, norm_scores_state, last_pdb_path, last_segment],
584
- outputs=[molecule_output]
585
  )
586
 
587
  visualize_btn.click(
@@ -600,4 +614,4 @@ with gr.Blocks(css="""
600
  inputs=[pdb_input, segment_input],
601
  outputs=[predictions_output, molecule_output, download_output]
602
  )
603
- demo.launch(share=True)
 
27
 
28
  from scipy.special import expit
29
 
 
30
  # Load model and move to device
 
 
31
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50_database'
32
  max_length = 1500
33
  model, tokenizer = load_model(checkpoint, max_length)
 
42
 
43
  def read_mol(pdb_path):
44
  """Read PDB file and return its content as a string"""
45
+ with open(pdb_path, 'r') as f:
46
+ return f.read()
47
+
48
+ def fetch_structure(pdb_id: str, output_dir: str = ".") -> str:
 
 
 
 
 
 
 
49
  """
50
  Fetch the structure file for a given PDB ID. Prioritizes CIF files.
51
  If a structure file already exists locally, it uses that.
52
  """
53
  file_path = download_structure(pdb_id, output_dir)
54
+ return file_path
 
 
 
55
 
56
+ def download_structure(pdb_id: str, output_dir: str) -> str:
57
  """
58
  Attempt to download the structure file in CIF or PDB format.
59
+ Returns the path to the downloaded file.
60
  """
61
  for ext in ['.cif', '.pdb']:
62
  file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
63
  if os.path.exists(file_path):
64
  return file_path
65
  url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
66
+ response = requests.get(url, timeout=10)
67
+ if response.status_code == 200:
68
+ with open(file_path, 'wb') as f:
69
+ f.write(response.content)
70
+ return file_path
 
 
 
71
  return None
72
 
73
  def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
 
84
 
85
  def fetch_pdb(pdb_id):
86
  pdb_path = fetch_structure(pdb_id)
 
 
87
  _, ext = os.path.splitext(pdb_path)
88
  if ext == '.cif':
89
  pdb_path = convert_cif_to_pdb(pdb_path)
 
93
  """
94
  Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores
95
  """
 
96
  parser = PDBParser(QUIET=True)
97
  structure = parser.get_structure('protein', input_pdb)
98
 
 
99
  output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"
100
 
101
  # Create scores dictionary for easy lookup
 
128
 
129
  return output_pdb
130
 
131
+ def generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, score_type):
132
+ """Generate PyMOL commands based on score type"""
133
+ pymol_commands = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n"
134
+
135
+ pymol_commands += f"""
136
+ # PyMOL Visualization Commands
137
+ fetch {pdb_id}, protein
138
+ hide everything, all
139
+ show cartoon, chain {segment}
140
+ color white, chain {segment}
141
+ """
142
 
143
+ # Define colors for each score bracket
144
+ bracket_colors = {
145
+ "0.0-0.2": "white",
146
+ "0.2-0.4": "lightorange",
147
+ "0.4-0.6": "yelloworange",
148
+ "0.6-0.8": "orange",
149
+ "0.8-1.0": "red"
150
+ }
151
+
152
+ # Add PyMOL commands for each score bracket
153
+ for bracket, residues in residues_by_bracket.items():
154
+ if residues: # Only add commands if there are residues in this bracket
155
+ color = bracket_colors[bracket]
156
+ resi_list = '+'.join(map(str, residues))
157
+ pymol_commands += f"""
158
+ select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}
159
+ show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}
160
+ color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}
161
+ """
162
+ return pymol_commands
163
+
164
+ def generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence, scores, current_time, score_type):
165
+ """Generate results text based on score type"""
166
+ result_str = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n"
167
+ result_str += "Residues by Score Brackets:\n\n"
168
+
169
+ # Add residues for each bracket
170
+ for bracket, residues in residues_by_bracket.items():
171
+ result_str += f"Bracket {bracket}:\n"
172
+ result_str += f"Columns: Residue Name, Residue Number, One-letter Code, {score_type} Score\n"
173
+ result_str += "\n".join([
174
+ f"{res.resname} {res.id[1]} {sequence[i]} {scores[i]:.2f}"
175
+ for i, res in enumerate(protein_residues) if res.id[1] in residues
176
+ ])
177
+ result_str += "\n\n"
178
+
179
+ return result_str
180
+
181
+ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
182
  # Determine if input is a PDB ID or file path
183
  if pdb_id_or_file.endswith('.pdb'):
184
  pdb_path = pdb_id_or_file
 
187
  pdb_id = pdb_id_or_file
188
  pdb_path = fetch_pdb(pdb_id)
189
 
 
 
 
190
  # Determine the file format and choose the appropriate parser
191
  _, ext = os.path.splitext(pdb_path)
192
  parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
193
 
194
+ # Parse the structure file
195
+ structure = parser.get_structure('protein', pdb_path)
 
 
 
196
 
197
  # Extract the specified chain
198
+ chain = structure[0][segment]
 
 
 
199
 
200
  protein_residues = [res for res in chain if is_aa(res)]
201
  sequence = "".join(seq1(res.resname) for res in protein_residues)
202
  sequence_id = [res.id[1] for res in protein_residues]
203
 
 
 
 
 
204
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
205
  with torch.no_grad():
206
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
207
 
208
  # Calculate scores and normalize them
209
+ raw_scores = expit(outputs[:, 1] - outputs[:, 0])
210
+ normalized_scores = normalize_scores(raw_scores)
211
 
212
  # Choose which scores to use based on score_type
213
+ display_scores = normalized_scores if score_type == 'normalized' else raw_scores
214
 
215
  # Zip residues with scores to track the residue ID and score
216
  residue_scores = [(resi, score) for resi, score in zip(sequence_id, display_scores)]
217
 
218
  # Also save both score types for later use
219
+ raw_residue_scores = [(resi, score) for resi, score in zip(sequence_id, raw_scores)]
220
  norm_residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]
 
221
 
222
  # Define the score brackets
223
  score_brackets = {
 
238
  residues_by_bracket[bracket].append(resi)
239
  break
240
 
241
+ # Generate timestamp
242
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
 
243
 
244
+ # Generate result text and PyMOL commands based on score type
245
+ display_score_type = "Normalized" if score_type == 'normalized' else "Raw"
246
+ result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
247
+ display_scores, current_time, display_score_type)
248
+ pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)
249
+
 
 
 
 
250
  # Create chain-specific PDB with scores in B-factor
251
  scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)
252
 
253
  # Molecule visualization with updated script with color mapping
254
+ mol_vis = molecule(pdb_path, residue_scores, segment)
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
+ # Create prediction file
257
+ prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  with open(prediction_file, "w") as f:
259
  f.write(result_str)
260
 
261
+ scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
262
+ os.rename(scored_pdb, scored_pdb_name)
263
+
264
+ return pymol_commands, mol_vis, [prediction_file, scored_pdb_name], raw_residue_scores, norm_residue_scores, pdb_id, segment
265
 
266
  def molecule(input_pdb, residue_scores=None, segment='A'):
267
+ # Read PDB file content
268
+ mol = read_mol(input_pdb)
269
+
 
 
 
 
 
 
 
 
 
270
  # Prepare high-scoring residues script if scores are provided
271
  high_score_script = ""
272
  if residue_scores is not None:
 
462
  Score dependent colorcoding:
463
  - 0.0-0.2: white
464
  - 0.2–0.4: light orange
465
+ - 0.4–0.6: yellow orange
466
+ - 0.6–0.8: orange
467
+ - 0.8–1.0: red
468
  """)
469
  predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
470
  gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
 
475
  norm_scores_state = gr.State(None)
476
  last_pdb_path = gr.State(None)
477
  last_segment = gr.State(None)
478
+ last_pdb_id = gr.State(None)
479
 
480
  def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
481
  selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
 
483
  # First get the actual PDB file path
484
  if mode == "PDB ID":
485
  pdb_path = fetch_pdb(pdb_id) # Get the actual file path
 
 
486
 
487
+ pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, segment = process_pdb(pdb_path, chain_id, selected_score_type)
488
  # Store the actual file path, not just the PDB ID
489
+ return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
490
  elif mode == "Upload File":
491
  _, ext = os.path.splitext(pdb_file.name)
492
  file_path = os.path.join('./', f"{_}{ext}")
 
495
  else:
496
  pdb_path = file_path
497
 
498
+ pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, segment = process_pdb(pdb_path, chain_id, selected_score_type)
499
+ return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
500
+
501
+ def update_visualization_and_files(score_type_val, raw_scores, norm_scores, pdb_path, segment, pdb_id):
502
+ if raw_scores is None or norm_scores is None or pdb_path is None or segment is None or pdb_id is None:
503
+ return None, None, None
 
 
 
 
 
 
504
 
505
  # Choose scores based on radio button selection
506
+ selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
507
+ selected_scores = norm_scores if selected_score_type == 'normalized' else raw_scores
508
 
509
  # Generate visualization with selected scores
510
+ mol_vis = molecule(pdb_path, selected_scores, segment)
511
+
512
+ # Generate PyMOL commands and downloadable files
513
+ # Get structure for residue info
514
+ _, ext = os.path.splitext(pdb_path)
515
+ parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
516
+ structure = parser.get_structure('protein', pdb_path)
517
+ chain = structure[0][segment]
518
+ protein_residues = [res for res in chain if is_aa(res)]
519
+ sequence = "".join(seq1(res.resname) for res in protein_residues)
520
+
521
+ # Define score brackets
522
+ score_brackets = {
523
+ "0.0-0.2": (0.0, 0.2),
524
+ "0.2-0.4": (0.2, 0.4),
525
+ "0.4-0.6": (0.4, 0.6),
526
+ "0.6-0.8": (0.6, 0.8),
527
+ "0.8-1.0": (0.8, 1.0)
528
+ }
529
+
530
+ # Initialize a dictionary to store residues by bracket
531
+ residues_by_bracket = {bracket: [] for bracket in score_brackets}
532
+
533
+ # Categorize residues into brackets
534
+ for resi, score in selected_scores:
535
+ for bracket, (lower, upper) in score_brackets.items():
536
+ if lower <= score < upper:
537
+ residues_by_bracket[bracket].append(resi)
538
+ break
539
+
540
+ # Generate timestamp
541
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
542
+
543
+ # Generate result text and PyMOL commands based on score type
544
+ display_score_type = "Normalized" if selected_score_type == 'normalized' else "Raw"
545
+ scores_array = [score for _, score in selected_scores]
546
+ result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
547
+ scores_array, current_time, display_score_type)
548
+ pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)
549
+
550
+ # Create chain-specific PDB with scores in B-factor
551
+ scored_pdb = create_chain_specific_pdb(pdb_path, segment, selected_scores, protein_residues)
552
+
553
+ # Create prediction file
554
+ prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
555
+ with open(prediction_file, "w") as f:
556
+ f.write(result_str)
557
+
558
+ scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
559
+ os.rename(scored_pdb, scored_pdb_name)
560
+
561
+ return mol_vis, pymol_commands, [prediction_file, scored_pdb_name]
562
 
563
  def fetch_interface(mode, pdb_id, pdb_file):
564
  if mode == "PDB ID":
 
571
  else:
572
  pdb_path= file_path
573
  return pdb_path
 
 
574
 
575
  def toggle_mode(selected_mode):
576
  if selected_mode == "PDB ID":
 
588
  process_interface,
589
  inputs=[mode, pdb_input, pdb_file, segment_input, score_type],
590
  outputs=[predictions_output, molecule_output, download_output,
591
+ raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id]
592
  )
593
 
594
+ # Update visualization, PyMOL commands, and files when score type changes
595
  score_type.change(
596
+ update_visualization_and_files,
597
+ inputs=[score_type, raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id],
598
+ outputs=[molecule_output, predictions_output, download_output]
599
  )
600
 
601
  visualize_btn.click(
 
614
  inputs=[pdb_input, segment_input],
615
  outputs=[predictions_output, molecule_output, download_output]
616
  )
617
+ demo.launch(share=True)