ThorbenFroehlking commited on
Commit
5a58a7d
·
1 Parent(s): 3e01b86
Files changed (2) hide show
  1. .ipynb_checkpoints/app-Copy1-checkpoint.py +197 -112
  2. app-Copy1.py +174 -144
.ipynb_checkpoints/app-Copy1-checkpoint.py CHANGED
@@ -27,11 +27,11 @@ from datasets import Dataset
27
 
28
  from scipy.special import expit
29
 
30
-
31
  # Load model and move to device
32
  #checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
33
  #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_cryptic'
34
- checkpoint = 'ThorbenF/prot_t5_xl_uniref50_database'
 
35
  max_length = 1500
36
  model, tokenizer = load_model(checkpoint, max_length)
37
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -48,35 +48,29 @@ def read_mol(pdb_path):
48
  with open(pdb_path, 'r') as f:
49
  return f.read()
50
 
51
- def fetch_structure(pdb_id: str, output_dir: str = ".") -> Optional[str]:
52
  """
53
  Fetch the structure file for a given PDB ID. Prioritizes CIF files.
54
  If a structure file already exists locally, it uses that.
55
  """
56
  file_path = download_structure(pdb_id, output_dir)
57
- if file_path:
58
- return file_path
59
- else:
60
- return None
61
 
62
- def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:
63
  """
64
  Attempt to download the structure file in CIF or PDB format.
65
- Returns the path to the downloaded file, or None if download fails.
66
  """
67
  for ext in ['.cif', '.pdb']:
68
  file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
69
  if os.path.exists(file_path):
70
  return file_path
71
  url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
72
- try:
73
- response = requests.get(url, timeout=10)
74
- if response.status_code == 200:
75
- with open(file_path, 'wb') as f:
76
- f.write(response.content)
77
- return file_path
78
- except Exception as e:
79
- print(f"Download error for {pdb_id}{ext}: {e}")
80
  return None
81
 
82
  def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
@@ -93,8 +87,6 @@ def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
93
 
94
  def fetch_pdb(pdb_id):
95
  pdb_path = fetch_structure(pdb_id)
96
- if not pdb_path:
97
- return None
98
  _, ext = os.path.splitext(pdb_path)
99
  if ext == '.cif':
100
  pdb_path = convert_cif_to_pdb(pdb_path)
@@ -104,11 +96,9 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
104
  """
105
  Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores
106
  """
107
- # Read the original PDB file
108
  parser = PDBParser(QUIET=True)
109
  structure = parser.get_structure('protein', input_pdb)
110
 
111
- # Prepare a new structure with only the specified chain and selected residues
112
  output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"
113
 
114
  # Create scores dictionary for easy lookup
@@ -141,7 +131,57 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
141
 
142
  return output_pdb
143
 
144
- def process_pdb(pdb_id_or_file, segment):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  # Determine if input is a PDB ID or file path
146
  if pdb_id_or_file.endswith('.pdb'):
147
  pdb_path = pdb_id_or_file
@@ -150,45 +190,37 @@ def process_pdb(pdb_id_or_file, segment):
150
  pdb_id = pdb_id_or_file
151
  pdb_path = fetch_pdb(pdb_id)
152
 
153
- if not pdb_path:
154
- return "Failed to fetch PDB file", None, None
155
-
156
  # Determine the file format and choose the appropriate parser
157
  _, ext = os.path.splitext(pdb_path)
158
  parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
159
 
160
- try:
161
- # Parse the structure file
162
- structure = parser.get_structure('protein', pdb_path)
163
- except Exception as e:
164
- return f"Error parsing structure file: {e}", None, None
165
 
166
  # Extract the specified chain
167
- try:
168
- chain = structure[0][segment]
169
- except KeyError:
170
- return "Invalid Chain ID", None, None
171
 
172
  protein_residues = [res for res in chain if is_aa(res)]
173
  sequence = "".join(seq1(res.resname) for res in protein_residues)
174
  sequence_id = [res.id[1] for res in protein_residues]
175
 
176
- visualized_sequence = "".join(seq1(res.resname) for res in protein_residues)
177
- if sequence != visualized_sequence:
178
- raise ValueError("The visualized sequence does not match the prediction sequence")
179
-
180
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
181
  with torch.no_grad():
182
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
183
-
184
  # Calculate scores and normalize them
185
- scores = expit(outputs[:, 1] - outputs[:, 0])
 
186
 
187
- normalized_scores = normalize_scores(scores)
 
188
 
189
  # Zip residues with scores to track the residue ID and score
190
- residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]
191
-
 
 
 
192
 
193
  # Define the score brackets
194
  score_brackets = {
@@ -209,69 +241,35 @@ def process_pdb(pdb_id_or_file, segment):
209
  residues_by_bracket[bracket].append(resi)
210
  break
211
 
212
- # Preparing the result string
213
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
214
- result_str = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\n\n"
215
- result_str += "Residues by Score Brackets:\n\n"
216
 
217
- # Add residues for each bracket
218
- for bracket, residues in residues_by_bracket.items():
219
- result_str += f"Bracket {bracket}:\n"
220
- result_str += "Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\n"
221
- result_str += "\n".join([
222
- f"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}"
223
- for i, res in enumerate(protein_residues) if res.id[1] in residues
224
- ])
225
- result_str += "\n\n"
226
-
227
  # Create chain-specific PDB with scores in B-factor
228
  scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)
229
 
230
  # Molecule visualization with updated script with color mapping
231
- mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)
232
-
233
- # Improved PyMOL command suggestions
234
- current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
235
- pymol_commands = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\n\n"
236
-
237
- pymol_commands += f"""
238
- # PyMOL Visualization Commands
239
- load {os.path.abspath(pdb_path)}, protein
240
- hide everything, all
241
- show cartoon, chain {segment}
242
- color white, chain {segment}
243
- """
244
-
245
- # Define colors for each score bracket
246
- bracket_colors = {
247
- "0.0-0.2": "white",
248
- "0.2-0.4": "lightorange",
249
- "0.4-0.6": "orange",
250
- "0.6-0.8": "red",
251
- "0.8-1.0": "firebrick"
252
- }
253
 
254
- # Add PyMOL commands for each score bracket
255
- for bracket, residues in residues_by_bracket.items():
256
- if residues: # Only add commands if there are residues in this bracket
257
- color = bracket_colors[bracket]
258
- resi_list = '+'.join(map(str, residues))
259
- pymol_commands += f"""
260
- select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}
261
- show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}
262
- color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}
263
- """
264
- # Create prediction and scored PDB files
265
- prediction_file = f"{pdb_id}_binding_site_residues.txt"
266
  with open(prediction_file, "w") as f:
267
  f.write(result_str)
268
 
269
- return pymol_commands, mol_vis, [prediction_file,scored_pdb]
 
 
 
270
 
271
  def molecule(input_pdb, residue_scores=None, segment='A'):
272
- # More granular scoring for visualization
273
- mol = read_mol(input_pdb) # Read PDB file content
274
-
275
  # Prepare high-scoring residues script if scores are provided
276
  high_score_script = ""
277
  if residue_scores is not None:
@@ -410,7 +408,6 @@ def molecule(input_pdb, residue_scores=None, segment='A'):
410
  # Return the HTML content within an iframe safely encoded for special characters
411
  return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
412
 
413
- # Gradio UI
414
  with gr.Blocks(css="""
415
  /* Customize Gradio button colors */
416
  #visualize-btn, #predict-btn {
@@ -455,32 +452,116 @@ with gr.Blocks(css="""
455
  info="Choose in which chain to predict binding sites.")
456
  prediction_btn = gr.Button("Predict Binding Site", elem_id="predict-btn")
457
 
 
 
 
 
 
 
 
 
458
  molecule_output = gr.HTML(label="Protein Structure")
459
  explanation_vis = gr.Markdown("""
460
  Score dependent colorcoding:
461
  - 0.0-0.2: white
462
  - 0.2–0.4: light orange
463
- - 0.4–0.6: orange
464
- - 0.6–0.8: orangered
465
  - 0.8–1.0: red
466
  """)
467
  predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
468
  gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
469
  download_output = gr.File(label="Download Files", file_count="multiple")
470
 
471
- def process_interface(mode, pdb_id, pdb_file, chain_id):
 
 
 
 
 
 
 
 
 
 
472
  if mode == "PDB ID":
473
- return process_pdb(pdb_id, chain_id)
 
 
 
 
474
  elif mode == "Upload File":
475
  _, ext = os.path.splitext(pdb_file.name)
476
  file_path = os.path.join('./', f"{_}{ext}")
477
  if ext == '.cif':
478
  pdb_path = convert_cif_to_pdb(file_path)
479
  else:
480
- pdb_path= file_path
481
- return process_pdb(pdb_path, chain_id)
482
- else:
483
- return "Error: Invalid mode selected", None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
 
485
  def fetch_interface(mode, pdb_id, pdb_file):
486
  if mode == "PDB ID":
@@ -488,15 +569,11 @@ with gr.Blocks(css="""
488
  elif mode == "Upload File":
489
  _, ext = os.path.splitext(pdb_file.name)
490
  file_path = os.path.join('./', f"{_}{ext}")
491
- #print(ext)
492
  if ext == '.cif':
493
  pdb_path = convert_cif_to_pdb(file_path)
494
  else:
495
  pdb_path= file_path
496
- #print(pdb_path)
497
  return pdb_path
498
- else:
499
- return "Error: Invalid mode selected"
500
 
501
  def toggle_mode(selected_mode):
502
  if selected_mode == "PDB ID":
@@ -512,8 +589,16 @@ with gr.Blocks(css="""
512
 
513
  prediction_btn.click(
514
  process_interface,
515
- inputs=[mode, pdb_input, pdb_file, segment_input],
516
- outputs=[predictions_output, molecule_output, download_output]
 
 
 
 
 
 
 
 
517
  )
518
 
519
  visualize_btn.click(
@@ -527,10 +612,10 @@ with gr.Blocks(css="""
527
  examples=[
528
  ["7RPZ", "A"],
529
  ["2IWI", "B"],
530
- ["7LCJ", "R"]
 
531
  ],
532
  inputs=[pdb_input, segment_input],
533
  outputs=[predictions_output, molecule_output, download_output]
534
  )
535
-
536
- demo.launch(share=True)
 
27
 
28
  from scipy.special import expit
29
 
 
30
  # Load model and move to device
31
  #checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
32
  #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_cryptic'
33
+ #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_database'
34
+ checkpoint = 'ThorbenF/prot_t5_xl_uniref50_full'
35
  max_length = 1500
36
  model, tokenizer = load_model(checkpoint, max_length)
37
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
48
  with open(pdb_path, 'r') as f:
49
  return f.read()
50
 
51
+ def fetch_structure(pdb_id: str, output_dir: str = ".") -> str:
52
  """
53
  Fetch the structure file for a given PDB ID. Prioritizes CIF files.
54
  If a structure file already exists locally, it uses that.
55
  """
56
  file_path = download_structure(pdb_id, output_dir)
57
+ return file_path
 
 
 
58
 
59
+ def download_structure(pdb_id: str, output_dir: str) -> str:
60
  """
61
  Attempt to download the structure file in CIF or PDB format.
62
+ Returns the path to the downloaded file.
63
  """
64
  for ext in ['.cif', '.pdb']:
65
  file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
66
  if os.path.exists(file_path):
67
  return file_path
68
  url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
69
+ response = requests.get(url, timeout=10)
70
+ if response.status_code == 200:
71
+ with open(file_path, 'wb') as f:
72
+ f.write(response.content)
73
+ return file_path
 
 
 
74
  return None
75
 
76
  def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
 
87
 
88
  def fetch_pdb(pdb_id):
89
  pdb_path = fetch_structure(pdb_id)
 
 
90
  _, ext = os.path.splitext(pdb_path)
91
  if ext == '.cif':
92
  pdb_path = convert_cif_to_pdb(pdb_path)
 
96
  """
97
  Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores
98
  """
 
99
  parser = PDBParser(QUIET=True)
100
  structure = parser.get_structure('protein', input_pdb)
101
 
 
102
  output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"
103
 
104
  # Create scores dictionary for easy lookup
 
131
 
132
  return output_pdb
133
 
134
+ def generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, score_type):
135
+ """Generate PyMOL commands based on score type"""
136
+ pymol_commands = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n"
137
+
138
+ pymol_commands += f"""
139
+ # PyMOL Visualization Commands
140
+ fetch {pdb_id}, protein
141
+ hide everything, all
142
+ show cartoon, chain {segment}
143
+ color white, chain {segment}
144
+ """
145
+
146
+ # Define colors for each score bracket
147
+ bracket_colors = {
148
+ "0.0-0.2": "white",
149
+ "0.2-0.4": "lightorange",
150
+ "0.4-0.6": "yelloworange",
151
+ "0.6-0.8": "orange",
152
+ "0.8-1.0": "red"
153
+ }
154
+
155
+ # Add PyMOL commands for each score bracket
156
+ for bracket, residues in residues_by_bracket.items():
157
+ if residues: # Only add commands if there are residues in this bracket
158
+ color = bracket_colors[bracket]
159
+ resi_list = '+'.join(map(str, residues))
160
+ pymol_commands += f"""
161
+ select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}
162
+ show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}
163
+ color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}
164
+ """
165
+ return pymol_commands
166
+
167
+ def generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence, scores, current_time, score_type):
168
+ """Generate results text based on score type"""
169
+ result_str = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n"
170
+ result_str += "Residues by Score Brackets:\n\n"
171
+
172
+ # Add residues for each bracket
173
+ for bracket, residues in residues_by_bracket.items():
174
+ result_str += f"Bracket {bracket}:\n"
175
+ result_str += f"Columns: Residue Name, Residue Number, One-letter Code, {score_type} Score\n"
176
+ result_str += "\n".join([
177
+ f"{res.resname} {res.id[1]} {sequence[i]} {scores[i]:.2f}"
178
+ for i, res in enumerate(protein_residues) if res.id[1] in residues
179
+ ])
180
+ result_str += "\n\n"
181
+
182
+ return result_str
183
+
184
+ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
185
  # Determine if input is a PDB ID or file path
186
  if pdb_id_or_file.endswith('.pdb'):
187
  pdb_path = pdb_id_or_file
 
190
  pdb_id = pdb_id_or_file
191
  pdb_path = fetch_pdb(pdb_id)
192
 
 
 
 
193
  # Determine the file format and choose the appropriate parser
194
  _, ext = os.path.splitext(pdb_path)
195
  parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
196
 
197
+ # Parse the structure file
198
+ structure = parser.get_structure('protein', pdb_path)
 
 
 
199
 
200
  # Extract the specified chain
201
+ chain = structure[0][segment]
 
 
 
202
 
203
  protein_residues = [res for res in chain if is_aa(res)]
204
  sequence = "".join(seq1(res.resname) for res in protein_residues)
205
  sequence_id = [res.id[1] for res in protein_residues]
206
 
 
 
 
 
207
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
208
  with torch.no_grad():
209
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
210
+
211
  # Calculate scores and normalize them
212
+ raw_scores = expit(outputs[:, 1] - outputs[:, 0])
213
+ normalized_scores = normalize_scores(raw_scores)
214
 
215
+ # Choose which scores to use based on score_type
216
+ display_scores = normalized_scores if score_type == 'normalized' else raw_scores
217
 
218
  # Zip residues with scores to track the residue ID and score
219
+ residue_scores = [(resi, score) for resi, score in zip(sequence_id, display_scores)]
220
+
221
+ # Also save both score types for later use
222
+ raw_residue_scores = [(resi, score) for resi, score in zip(sequence_id, raw_scores)]
223
+ norm_residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]
224
 
225
  # Define the score brackets
226
  score_brackets = {
 
241
  residues_by_bracket[bracket].append(resi)
242
  break
243
 
244
+ # Generate timestamp
245
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
 
246
 
247
+ # Generate result text and PyMOL commands based on score type
248
+ display_score_type = "Normalized" if score_type == 'normalized' else "Raw"
249
+ result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
250
+ display_scores, current_time, display_score_type)
251
+ pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)
252
+
 
 
 
 
253
  # Create chain-specific PDB with scores in B-factor
254
  scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)
255
 
256
  # Molecule visualization with updated script with color mapping
257
+ mol_vis = molecule(pdb_path, residue_scores, segment)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
+ # Create prediction file
260
+ prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
 
 
 
 
 
 
 
 
 
 
261
  with open(prediction_file, "w") as f:
262
  f.write(result_str)
263
 
264
+ scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
265
+ os.rename(scored_pdb, scored_pdb_name)
266
+
267
+ return pymol_commands, mol_vis, [prediction_file, scored_pdb_name], raw_residue_scores, norm_residue_scores, pdb_id, segment
268
 
269
  def molecule(input_pdb, residue_scores=None, segment='A'):
270
+ # Read PDB file content
271
+ mol = read_mol(input_pdb)
272
+
273
  # Prepare high-scoring residues script if scores are provided
274
  high_score_script = ""
275
  if residue_scores is not None:
 
408
  # Return the HTML content within an iframe safely encoded for special characters
409
  return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
410
 
 
411
  with gr.Blocks(css="""
412
  /* Customize Gradio button colors */
413
  #visualize-btn, #predict-btn {
 
452
  info="Choose in which chain to predict binding sites.")
453
  prediction_btn = gr.Button("Predict Binding Site", elem_id="predict-btn")
454
 
455
+ # Add score type selector
456
+ score_type = gr.Radio(
457
+ choices=["Normalized Scores", "Raw Scores"],
458
+ value="Normalized Scores",
459
+ label="Score Visualization Type",
460
+ info="Choose which score type to visualize"
461
+ )
462
+
463
  molecule_output = gr.HTML(label="Protein Structure")
464
  explanation_vis = gr.Markdown("""
465
  Score dependent colorcoding:
466
  - 0.0-0.2: white
467
  - 0.2–0.4: light orange
468
+ - 0.4–0.6: yellow orange
469
+ - 0.6–0.8: orange
470
  - 0.8–1.0: red
471
  """)
472
  predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
473
  gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
474
  download_output = gr.File(label="Download Files", file_count="multiple")
475
 
476
+ # Store these as state variables so we can switch between them
477
+ raw_scores_state = gr.State(None)
478
+ norm_scores_state = gr.State(None)
479
+ last_pdb_path = gr.State(None)
480
+ last_segment = gr.State(None)
481
+ last_pdb_id = gr.State(None)
482
+
483
+ def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
484
+ selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
485
+
486
+ # First get the actual PDB file path
487
  if mode == "PDB ID":
488
+ pdb_path = fetch_pdb(pdb_id) # Get the actual file path
489
+
490
+ pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, segment = process_pdb(pdb_path, chain_id, selected_score_type)
491
+ # Store the actual file path, not just the PDB ID
492
+ return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
493
  elif mode == "Upload File":
494
  _, ext = os.path.splitext(pdb_file.name)
495
  file_path = os.path.join('./', f"{_}{ext}")
496
  if ext == '.cif':
497
  pdb_path = convert_cif_to_pdb(file_path)
498
  else:
499
+ pdb_path = file_path
500
+
501
+ pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, segment = process_pdb(pdb_path, chain_id, selected_score_type)
502
+ return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
503
+
504
+ def update_visualization_and_files(score_type_val, raw_scores, norm_scores, pdb_path, segment, pdb_id):
505
+ if raw_scores is None or norm_scores is None or pdb_path is None or segment is None or pdb_id is None:
506
+ return None, None, None
507
+
508
+ # Choose scores based on radio button selection
509
+ selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
510
+ selected_scores = norm_scores if selected_score_type == 'normalized' else raw_scores
511
+
512
+ # Generate visualization with selected scores
513
+ mol_vis = molecule(pdb_path, selected_scores, segment)
514
+
515
+ # Generate PyMOL commands and downloadable files
516
+ # Get structure for residue info
517
+ _, ext = os.path.splitext(pdb_path)
518
+ parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
519
+ structure = parser.get_structure('protein', pdb_path)
520
+ chain = structure[0][segment]
521
+ protein_residues = [res for res in chain if is_aa(res)]
522
+ sequence = "".join(seq1(res.resname) for res in protein_residues)
523
+
524
+ # Define score brackets
525
+ score_brackets = {
526
+ "0.0-0.2": (0.0, 0.2),
527
+ "0.2-0.4": (0.2, 0.4),
528
+ "0.4-0.6": (0.4, 0.6),
529
+ "0.6-0.8": (0.6, 0.8),
530
+ "0.8-1.0": (0.8, 1.0)
531
+ }
532
+
533
+ # Initialize a dictionary to store residues by bracket
534
+ residues_by_bracket = {bracket: [] for bracket in score_brackets}
535
+
536
+ # Categorize residues into brackets
537
+ for resi, score in selected_scores:
538
+ for bracket, (lower, upper) in score_brackets.items():
539
+ if lower <= score < upper:
540
+ residues_by_bracket[bracket].append(resi)
541
+ break
542
+
543
+ # Generate timestamp
544
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
545
+
546
+ # Generate result text and PyMOL commands based on score type
547
+ display_score_type = "Normalized" if selected_score_type == 'normalized' else "Raw"
548
+ scores_array = [score for _, score in selected_scores]
549
+ result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
550
+ scores_array, current_time, display_score_type)
551
+ pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)
552
+
553
+ # Create chain-specific PDB with scores in B-factor
554
+ scored_pdb = create_chain_specific_pdb(pdb_path, segment, selected_scores, protein_residues)
555
+
556
+ # Create prediction file
557
+ prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
558
+ with open(prediction_file, "w") as f:
559
+ f.write(result_str)
560
+
561
+ scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
562
+ os.rename(scored_pdb, scored_pdb_name)
563
+
564
+ return mol_vis, pymol_commands, [prediction_file, scored_pdb_name]
565
 
566
  def fetch_interface(mode, pdb_id, pdb_file):
567
  if mode == "PDB ID":
 
569
  elif mode == "Upload File":
570
  _, ext = os.path.splitext(pdb_file.name)
571
  file_path = os.path.join('./', f"{_}{ext}")
 
572
  if ext == '.cif':
573
  pdb_path = convert_cif_to_pdb(file_path)
574
  else:
575
  pdb_path= file_path
 
576
  return pdb_path
 
 
577
 
578
  def toggle_mode(selected_mode):
579
  if selected_mode == "PDB ID":
 
589
 
590
  prediction_btn.click(
591
  process_interface,
592
+ inputs=[mode, pdb_input, pdb_file, segment_input, score_type],
593
+ outputs=[predictions_output, molecule_output, download_output,
594
+ raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id]
595
+ )
596
+
597
+ # Update visualization, PyMOL commands, and files when score type changes
598
+ score_type.change(
599
+ update_visualization_and_files,
600
+ inputs=[score_type, raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id],
601
+ outputs=[molecule_output, predictions_output, download_output]
602
  )
603
 
604
  visualize_btn.click(
 
612
  examples=[
613
  ["7RPZ", "A"],
614
  ["2IWI", "B"],
615
+ ["7LCJ", "R"],
616
+ ["4OBE", "A"]
617
  ],
618
  inputs=[pdb_input, segment_input],
619
  outputs=[predictions_output, molecule_output, download_output]
620
  )
621
+ demo.launch(share=True)
 
app-Copy1.py CHANGED
@@ -27,11 +27,11 @@ from datasets import Dataset
27
 
28
  from scipy.special import expit
29
 
30
-
31
  # Load model and move to device
32
  #checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
33
  #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_cryptic'
34
- checkpoint = 'ThorbenF/prot_t5_xl_uniref50_database'
 
35
  max_length = 1500
36
  model, tokenizer = load_model(checkpoint, max_length)
37
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -45,45 +45,32 @@ def normalize_scores(scores):
45
 
46
  def read_mol(pdb_path):
47
  """Read PDB file and return its content as a string"""
48
- try:
49
- with open(pdb_path, 'r') as f:
50
- return f.read()
51
- except FileNotFoundError:
52
- print(f"File not found: {pdb_path}")
53
- raise
54
- except Exception as e:
55
- print(f"Error reading file {pdb_path}: {str(e)}")
56
- raise
57
-
58
- def fetch_structure(pdb_id: str, output_dir: str = ".") -> Optional[str]:
59
  """
60
  Fetch the structure file for a given PDB ID. Prioritizes CIF files.
61
  If a structure file already exists locally, it uses that.
62
  """
63
  file_path = download_structure(pdb_id, output_dir)
64
- if file_path:
65
- return file_path
66
- else:
67
- return None
68
 
69
- def download_structure(pdb_id: str, output_dir: str) -> Optional[str]:
70
  """
71
  Attempt to download the structure file in CIF or PDB format.
72
- Returns the path to the downloaded file, or None if download fails.
73
  """
74
  for ext in ['.cif', '.pdb']:
75
  file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
76
  if os.path.exists(file_path):
77
  return file_path
78
  url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
79
- try:
80
- response = requests.get(url, timeout=10)
81
- if response.status_code == 200:
82
- with open(file_path, 'wb') as f:
83
- f.write(response.content)
84
- return file_path
85
- except Exception as e:
86
- print(f"Download error for {pdb_id}{ext}: {e}")
87
  return None
88
 
89
  def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
@@ -100,8 +87,6 @@ def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
100
 
101
  def fetch_pdb(pdb_id):
102
  pdb_path = fetch_structure(pdb_id)
103
- if not pdb_path:
104
- return None
105
  _, ext = os.path.splitext(pdb_path)
106
  if ext == '.cif':
107
  pdb_path = convert_cif_to_pdb(pdb_path)
@@ -111,11 +96,9 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
111
  """
112
  Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores
113
  """
114
- # Read the original PDB file
115
  parser = PDBParser(QUIET=True)
116
  structure = parser.get_structure('protein', input_pdb)
117
 
118
- # Prepare a new structure with only the specified chain and selected residues
119
  output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"
120
 
121
  # Create scores dictionary for easy lookup
@@ -148,8 +131,69 @@ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: lis
148
 
149
  return output_pdb
150
 
151
- def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
 
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  # Determine if input is a PDB ID or file path
154
  if pdb_id_or_file.endswith('.pdb'):
155
  pdb_path = pdb_id_or_file
@@ -158,51 +202,37 @@ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
158
  pdb_id = pdb_id_or_file
159
  pdb_path = fetch_pdb(pdb_id)
160
 
161
- if not pdb_path:
162
- return "Failed to fetch PDB file", None, None
163
-
164
  # Determine the file format and choose the appropriate parser
165
  _, ext = os.path.splitext(pdb_path)
166
  parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
167
 
168
- try:
169
- # Parse the structure file
170
- structure = parser.get_structure('protein', pdb_path)
171
- except Exception as e:
172
- return f"Error parsing structure file: {e}", None, None
173
 
174
  # Extract the specified chain
175
- try:
176
- chain = structure[0][segment]
177
- except KeyError:
178
- return "Invalid Chain ID", None, None
179
 
180
  protein_residues = [res for res in chain if is_aa(res)]
181
  sequence = "".join(seq1(res.resname) for res in protein_residues)
182
  sequence_id = [res.id[1] for res in protein_residues]
183
 
184
- visualized_sequence = "".join(seq1(res.resname) for res in protein_residues)
185
- if sequence != visualized_sequence:
186
- raise ValueError("The visualized sequence does not match the prediction sequence")
187
-
188
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
189
  with torch.no_grad():
190
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
191
 
192
  # Calculate scores and normalize them
193
- scores = expit(outputs[:, 1] - outputs[:, 0])
194
- normalized_scores = normalize_scores(scores)
195
 
196
  # Choose which scores to use based on score_type
197
- display_scores = normalized_scores if score_type == 'normalized' else scores
198
 
199
  # Zip residues with scores to track the residue ID and score
200
  residue_scores = [(resi, score) for resi, score in zip(sequence_id, display_scores)]
201
 
202
  # Also save both score types for later use
203
- raw_residue_scores = [(resi, score) for resi, score in zip(sequence_id, scores)]
204
  norm_residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]
205
-
206
 
207
  # Define the score brackets
208
  score_brackets = {
@@ -223,79 +253,35 @@ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
223
  residues_by_bracket[bracket].append(resi)
224
  break
225
 
226
- # Preparing the result string
227
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
228
- result_str = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\n\n"
229
- result_str += "Residues by Score Brackets:\n\n"
230
 
231
- # Add residues for each bracket
232
- for bracket, residues in residues_by_bracket.items():
233
- result_str += f"Bracket {bracket}:\n"
234
- result_str += "Columns: Residue Name, Residue Number, One-letter Code, Normalized Score\n"
235
- result_str += "\n".join([
236
- f"{res.resname} {res.id[1]} {sequence[i]} {normalized_scores[i]:.2f}"
237
- for i, res in enumerate(protein_residues) if res.id[1] in residues
238
- ])
239
- result_str += "\n\n"
240
-
241
  # Create chain-specific PDB with scores in B-factor
242
  scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)
243
 
244
  # Molecule visualization with updated script with color mapping
245
- mol_vis = molecule(pdb_path, residue_scores, segment)#, color_map)
246
-
247
- # Improved PyMOL command suggestions
248
- current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
249
- pymol_commands = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\n\n"
250
-
251
- pymol_commands += f"""
252
- # PyMOL Visualization Commands
253
- fetch {pdb_id}, protein
254
- hide everything, all
255
- show cartoon, chain {segment}
256
- color white, chain {segment}
257
- """
258
-
259
- # Define colors for each score bracket
260
- bracket_colors = {
261
- "0.0-0.2": "white",
262
- "0.2-0.4": "lightorange",
263
- "0.4-0.6": "orange",
264
- "0.6-0.8": "red",
265
- "0.8-1.0": "firebrick"
266
- }
267
 
268
- # Add PyMOL commands for each score bracket
269
- for bracket, residues in residues_by_bracket.items():
270
- if residues: # Only add commands if there are residues in this bracket
271
- color = bracket_colors[bracket]
272
- resi_list = '+'.join(map(str, residues))
273
- pymol_commands += f"""
274
- select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}
275
- show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}
276
- color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}
277
- """
278
- # Create prediction and scored PDB files
279
- prediction_file = f"{pdb_id}_binding_site_residues.txt"
280
  with open(prediction_file, "w") as f:
281
  f.write(result_str)
282
 
283
- return pymol_commands, mol_vis, [prediction_file, scored_pdb],raw_residue_scores,norm_residue_scores
284
-
 
 
285
 
286
  def molecule(input_pdb, residue_scores=None, segment='A'):
287
- # Check if the file exists
288
- if not os.path.isfile(input_pdb):
289
- return f"<p>Error: PDB file not found at {input_pdb}</p>"
290
-
291
- try:
292
- # Read PDB file content
293
- mol = read_mol(input_pdb)
294
- except Exception as e:
295
- return f"<p>Error reading PDB file: {str(e)}</p>"
296
- # More granular scoring for visualization
297
- #mol = read_mol(input_pdb) # Read PDB file content
298
-
299
  # Prepare high-scoring residues script if scores are provided
300
  high_score_script = ""
301
  if residue_scores is not None:
@@ -491,9 +477,9 @@ with gr.Blocks(css="""
491
  Score dependent colorcoding:
492
  - 0.0-0.2: white
493
  - 0.2–0.4: light orange
494
- - 0.4–0.6: orange
495
- - 0.6–0.8: red
496
- - 0.8–1.0: firebrick
497
  """)
498
  predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
499
  gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
@@ -504,6 +490,7 @@ with gr.Blocks(css="""
504
  norm_scores_state = gr.State(None)
505
  last_pdb_path = gr.State(None)
506
  last_segment = gr.State(None)
 
507
 
508
  def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
509
  selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
@@ -511,12 +498,10 @@ with gr.Blocks(css="""
511
  # First get the actual PDB file path
512
  if mode == "PDB ID":
513
  pdb_path = fetch_pdb(pdb_id) # Get the actual file path
514
- if not pdb_path:
515
- return "Failed to fetch PDB file", None, None, None, None, None, None
516
 
517
- pymol_cmd, mol_vis, files, raw_scores, norm_scores = process_pdb(pdb_path, chain_id, selected_score_type)
518
  # Store the actual file path, not just the PDB ID
519
- return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id
520
  elif mode == "Upload File":
521
  _, ext = os.path.splitext(pdb_file.name)
522
  file_path = os.path.join('./', f"{_}{ext}")
@@ -525,24 +510,70 @@ with gr.Blocks(css="""
525
  else:
526
  pdb_path = file_path
527
 
528
- pymol_cmd, mol_vis, files, raw_scores, norm_scores = process_pdb(pdb_path, chain_id, selected_score_type)
529
- return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id
530
- else:
531
- return "Error: Invalid mode selected", None, None, None, None, None, None
532
-
533
- def update_visualization(score_type_val, raw_scores, norm_scores, pdb_path, segment):
534
- if raw_scores is None or norm_scores is None or pdb_path is None or segment is None:
535
- return None
536
-
537
- # Verify the file exists
538
- if not os.path.exists(pdb_path):
539
- return f"Error: File not found at {pdb_path}"
540
 
541
  # Choose scores based on radio button selection
542
- selected_scores = norm_scores if score_type_val == "Normalized Scores" else raw_scores
 
543
 
544
  # Generate visualization with selected scores
545
- return molecule(pdb_path, selected_scores, segment)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
 
547
  def fetch_interface(mode, pdb_id, pdb_file):
548
  if mode == "PDB ID":
@@ -555,8 +586,6 @@ with gr.Blocks(css="""
555
  else:
556
  pdb_path= file_path
557
  return pdb_path
558
- else:
559
- return "Error: Invalid mode selected"
560
 
561
  def toggle_mode(selected_mode):
562
  if selected_mode == "PDB ID":
@@ -574,14 +603,14 @@ with gr.Blocks(css="""
574
  process_interface,
575
  inputs=[mode, pdb_input, pdb_file, segment_input, score_type],
576
  outputs=[predictions_output, molecule_output, download_output,
577
- raw_scores_state, norm_scores_state, last_pdb_path, last_segment]
578
  )
579
 
580
- # Update visualization when score type changes
581
  score_type.change(
582
- update_visualization,
583
- inputs=[score_type, raw_scores_state, norm_scores_state, last_pdb_path, last_segment],
584
- outputs=[molecule_output]
585
  )
586
 
587
  visualize_btn.click(
@@ -595,9 +624,10 @@ with gr.Blocks(css="""
595
  examples=[
596
  ["7RPZ", "A"],
597
  ["2IWI", "B"],
598
- ["7LCJ", "R"]
 
599
  ],
600
  inputs=[pdb_input, segment_input],
601
  outputs=[predictions_output, molecule_output, download_output]
602
  )
603
- demo.launch(share=True)
 
27
 
28
  from scipy.special import expit
29
 
 
30
  # Load model and move to device
31
  #checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
32
  #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_cryptic'
33
+ #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_database'
34
+ checkpoint = 'ThorbenF/prot_t5_xl_uniref50_full'
35
  max_length = 1500
36
  model, tokenizer = load_model(checkpoint, max_length)
37
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
45
 
46
  def read_mol(pdb_path):
47
  """Read PDB file and return its content as a string"""
48
+ with open(pdb_path, 'r') as f:
49
+ return f.read()
50
+
51
+ def fetch_structure(pdb_id: str, output_dir: str = ".") -> str:
 
 
 
 
 
 
 
52
  """
53
  Fetch the structure file for a given PDB ID. Prioritizes CIF files.
54
  If a structure file already exists locally, it uses that.
55
  """
56
  file_path = download_structure(pdb_id, output_dir)
57
+ return file_path
 
 
 
58
 
59
+ def download_structure(pdb_id: str, output_dir: str) -> str:
60
  """
61
  Attempt to download the structure file in CIF or PDB format.
62
+ Returns the path to the downloaded file.
63
  """
64
  for ext in ['.cif', '.pdb']:
65
  file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
66
  if os.path.exists(file_path):
67
  return file_path
68
  url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
69
+ response = requests.get(url, timeout=10)
70
+ if response.status_code == 200:
71
+ with open(file_path, 'wb') as f:
72
+ f.write(response.content)
73
+ return file_path
 
 
 
74
  return None
75
 
76
  def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
 
87
 
88
  def fetch_pdb(pdb_id):
89
  pdb_path = fetch_structure(pdb_id)
 
 
90
  _, ext = os.path.splitext(pdb_path)
91
  if ext == '.cif':
92
  pdb_path = convert_cif_to_pdb(pdb_path)
 
96
  """
97
  Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores
98
  """
 
99
  parser = PDBParser(QUIET=True)
100
  structure = parser.get_structure('protein', input_pdb)
101
 
 
102
  output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"
103
 
104
  # Create scores dictionary for easy lookup
 
131
 
132
  return output_pdb
133
 
134
+ def generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, score_type):
135
+ """Generate PyMOL commands based on score type"""
136
+ pymol_commands = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n"
137
 
138
+ pymol_commands += f"""
139
+ # PyMOL Visualization Commands
140
+ fetch {pdb_id}, protein
141
+ hide everything, all
142
+ show cartoon, chain {segment}
143
+ color white, chain {segment}
144
+ """
145
+
146
+ # Define colors for each score bracket
147
+ bracket_colors = {
148
+ "0.0-0.2": "white",
149
+ "0.2-0.4": "lightorange",
150
+ "0.4-0.6": "yelloworange",
151
+ "0.6-0.8": "orange",
152
+ "0.8-1.0": "red"
153
+ }
154
+
155
+ # Add PyMOL commands for each score bracket
156
+ for bracket, residues in residues_by_bracket.items():
157
+ if residues: # Only add commands if there are residues in this bracket
158
+ color = bracket_colors[bracket]
159
+ resi_list = '+'.join(map(str, residues))
160
+ pymol_commands += f"""
161
+ select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}
162
+ show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}
163
+ color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}
164
+ """
165
+ return pymol_commands
166
+
167
+ def generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence, scores, current_time, score_type):
168
+ """Generate results text based on score type"""
169
+ result_str = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n"
170
+ result_str += "Residues by Score Brackets:\n\n"
171
+
172
+ # Add residues for each bracket
173
+ for bracket, residues in residues_by_bracket.items():
174
+ result_str += f"Bracket {bracket}:\n"
175
+ result_str += f"Columns: Residue Name, Residue Number, One-letter Code, {score_type} Score\n"
176
+ result_str += "\n".join([
177
+ f"{res.resname} {res.id[1]} {sequence[i]} {scores[i]:.2f}"
178
+ for i, res in enumerate(protein_residues) if res.id[1] in residues
179
+ ])
180
+ result_str += "\n\n"
181
+
182
+ return result_str
183
+
184
+ def predict_util(sequence):
185
+ input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
186
+ with torch.no_grad():
187
+ outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
188
+
189
+ # Calculate scores and normalize them
190
+ raw_scores = expit(outputs[:, 1] - outputs[:, 0])
191
+ normalized_scores = normalize_scores(raw_scores)
192
+
193
+ return raw_scores,normalized_scores
194
+
195
+
196
+ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
197
  # Determine if input is a PDB ID or file path
198
  if pdb_id_or_file.endswith('.pdb'):
199
  pdb_path = pdb_id_or_file
 
202
  pdb_id = pdb_id_or_file
203
  pdb_path = fetch_pdb(pdb_id)
204
 
 
 
 
205
  # Determine the file format and choose the appropriate parser
206
  _, ext = os.path.splitext(pdb_path)
207
  parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
208
 
209
+ # Parse the structure file
210
+ structure = parser.get_structure('protein', pdb_path)
 
 
 
211
 
212
  # Extract the specified chain
213
+ chain = structure[0][segment]
 
 
 
214
 
215
  protein_residues = [res for res in chain if is_aa(res)]
216
  sequence = "".join(seq1(res.resname) for res in protein_residues)
217
  sequence_id = [res.id[1] for res in protein_residues]
218
 
 
 
 
 
219
  input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
220
  with torch.no_grad():
221
  outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
222
 
223
  # Calculate scores and normalize them
224
+ raw_scores = expit(outputs[:, 1] - outputs[:, 0])
225
+ normalized_scores = normalize_scores(raw_scores)
226
 
227
  # Choose which scores to use based on score_type
228
+ display_scores = normalized_scores if score_type == 'normalized' else raw_scores
229
 
230
  # Zip residues with scores to track the residue ID and score
231
  residue_scores = [(resi, score) for resi, score in zip(sequence_id, display_scores)]
232
 
233
  # Also save both score types for later use
234
+ raw_residue_scores = [(resi, score) for resi, score in zip(sequence_id, raw_scores)]
235
  norm_residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]
 
236
 
237
  # Define the score brackets
238
  score_brackets = {
 
253
  residues_by_bracket[bracket].append(resi)
254
  break
255
 
256
+ # Generate timestamp
257
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
 
258
 
259
+ # Generate result text and PyMOL commands based on score type
260
+ display_score_type = "Normalized" if score_type == 'normalized' else "Raw"
261
+ result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
262
+ display_scores, current_time, display_score_type)
263
+ pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)
264
+
 
 
 
 
265
  # Create chain-specific PDB with scores in B-factor
266
  scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)
267
 
268
  # Molecule visualization with updated script with color mapping
269
+ mol_vis = molecule(pdb_path, residue_scores, segment)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
+ # Create prediction file
272
+ prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
 
 
 
 
 
 
 
 
 
 
273
  with open(prediction_file, "w") as f:
274
  f.write(result_str)
275
 
276
+ scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
277
+ os.rename(scored_pdb, scored_pdb_name)
278
+
279
+ return pymol_commands, mol_vis, [prediction_file, scored_pdb_name], raw_residue_scores, norm_residue_scores, pdb_id, segment
280
 
281
  def molecule(input_pdb, residue_scores=None, segment='A'):
282
+ # Read PDB file content
283
+ mol = read_mol(input_pdb)
284
+
 
 
 
 
 
 
 
 
 
285
  # Prepare high-scoring residues script if scores are provided
286
  high_score_script = ""
287
  if residue_scores is not None:
 
477
  Score dependent colorcoding:
478
  - 0.0-0.2: white
479
  - 0.2–0.4: light orange
480
+ - 0.4–0.6: yellow orange
481
+ - 0.6–0.8: orange
482
+ - 0.8–1.0: red
483
  """)
484
  predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
485
  gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
 
490
  norm_scores_state = gr.State(None)
491
  last_pdb_path = gr.State(None)
492
  last_segment = gr.State(None)
493
+ last_pdb_id = gr.State(None)
494
 
495
  def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
496
  selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
 
498
  # First get the actual PDB file path
499
  if mode == "PDB ID":
500
  pdb_path = fetch_pdb(pdb_id) # Get the actual file path
 
 
501
 
502
+ pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, segment = process_pdb(pdb_path, chain_id, selected_score_type)
503
  # Store the actual file path, not just the PDB ID
504
+ return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
505
  elif mode == "Upload File":
506
  _, ext = os.path.splitext(pdb_file.name)
507
  file_path = os.path.join('./', f"{_}{ext}")
 
510
  else:
511
  pdb_path = file_path
512
 
513
+ pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, segment = process_pdb(pdb_path, chain_id, selected_score_type)
514
+ return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
515
+
516
+ def update_visualization_and_files(score_type_val, raw_scores, norm_scores, pdb_path, segment, pdb_id):
517
+ if raw_scores is None or norm_scores is None or pdb_path is None or segment is None or pdb_id is None:
518
+ return None, None, None
 
 
 
 
 
 
519
 
520
  # Choose scores based on radio button selection
521
+ selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
522
+ selected_scores = norm_scores if selected_score_type == 'normalized' else raw_scores
523
 
524
  # Generate visualization with selected scores
525
+ mol_vis = molecule(pdb_path, selected_scores, segment)
526
+
527
+ # Generate PyMOL commands and downloadable files
528
+ # Get structure for residue info
529
+ _, ext = os.path.splitext(pdb_path)
530
+ parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
531
+ structure = parser.get_structure('protein', pdb_path)
532
+ chain = structure[0][segment]
533
+ protein_residues = [res for res in chain if is_aa(res)]
534
+ sequence = "".join(seq1(res.resname) for res in protein_residues)
535
+
536
+ # Define score brackets
537
+ score_brackets = {
538
+ "0.0-0.2": (0.0, 0.2),
539
+ "0.2-0.4": (0.2, 0.4),
540
+ "0.4-0.6": (0.4, 0.6),
541
+ "0.6-0.8": (0.6, 0.8),
542
+ "0.8-1.0": (0.8, 1.0)
543
+ }
544
+
545
+ # Initialize a dictionary to store residues by bracket
546
+ residues_by_bracket = {bracket: [] for bracket in score_brackets}
547
+
548
+ # Categorize residues into brackets
549
+ for resi, score in selected_scores:
550
+ for bracket, (lower, upper) in score_brackets.items():
551
+ if lower <= score < upper:
552
+ residues_by_bracket[bracket].append(resi)
553
+ break
554
+
555
+ # Generate timestamp
556
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
557
+
558
+ # Generate result text and PyMOL commands based on score type
559
+ display_score_type = "Normalized" if selected_score_type == 'normalized' else "Raw"
560
+ scores_array = [score for _, score in selected_scores]
561
+ result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
562
+ scores_array, current_time, display_score_type)
563
+ pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)
564
+
565
+ # Create chain-specific PDB with scores in B-factor
566
+ scored_pdb = create_chain_specific_pdb(pdb_path, segment, selected_scores, protein_residues)
567
+
568
+ # Create prediction file
569
+ prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
570
+ with open(prediction_file, "w") as f:
571
+ f.write(result_str)
572
+
573
+ scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
574
+ os.rename(scored_pdb, scored_pdb_name)
575
+
576
+ return mol_vis, pymol_commands, [prediction_file, scored_pdb_name]
577
 
578
  def fetch_interface(mode, pdb_id, pdb_file):
579
  if mode == "PDB ID":
 
586
  else:
587
  pdb_path= file_path
588
  return pdb_path
 
 
589
 
590
  def toggle_mode(selected_mode):
591
  if selected_mode == "PDB ID":
 
603
  process_interface,
604
  inputs=[mode, pdb_input, pdb_file, segment_input, score_type],
605
  outputs=[predictions_output, molecule_output, download_output,
606
+ raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id]
607
  )
608
 
609
+ # Update visualization, PyMOL commands, and files when score type changes
610
  score_type.change(
611
+ update_visualization_and_files,
612
+ inputs=[score_type, raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id],
613
+ outputs=[molecule_output, predictions_output, download_output]
614
  )
615
 
616
  visualize_btn.click(
 
624
  examples=[
625
  ["7RPZ", "A"],
626
  ["2IWI", "B"],
627
+ ["7LCJ", "R"],
628
+ ["4OBE", "A"]
629
  ],
630
  inputs=[pdb_input, segment_input],
631
  outputs=[predictions_output, molecule_output, download_output]
632
  )
633
+ demo.launch(share=True)