ThorbenF commited on
Commit
01ff8b6
·
1 Parent(s): a6b7cf0
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -37,7 +37,7 @@ from scipy.special import expit
37
 
38
  import requests
39
 
40
- import py3Dmol
41
 
42
  #import peft
43
  #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
@@ -46,6 +46,21 @@ import py3Dmol
46
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
47
  max_length = 1500
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # Load model and move to device
50
  model, tokenizer = load_model(checkpoint, max_length)
51
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -53,7 +68,6 @@ model.to(device)
53
  model.eval()
54
 
55
  def create_dataset(tokenizer, seqs, labels, checkpoint):
56
-
57
  tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
58
  dataset = Dataset.from_dict(tokenized)
59
 
@@ -68,7 +82,6 @@ def create_dataset(tokenizer, seqs, labels, checkpoint):
68
  return dataset
69
 
70
  def convert_predictions(input_logits):
71
-
72
  all_probs = []
73
  for logits in input_logits:
74
  logits = logits.reshape(-1, 2)
@@ -78,13 +91,11 @@ def convert_predictions(input_logits):
78
  return np.concatenate(all_probs)
79
 
80
  def normalize_scores(scores):
81
-
82
  min_score = np.min(scores)
83
  max_score = np.max(scores)
84
  return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
85
 
86
  def predict_protein_sequence(test_one_letter_sequence):
87
-
88
  # Sanitize input sequence
89
  test_one_letter_sequence = test_one_letter_sequence.replace("O", "X") \
90
  .replace("B", "X").replace("U", "X") \
@@ -135,58 +146,88 @@ def predict_protein_sequence(test_one_letter_sequence):
135
 
136
  return result_str
137
 
138
- def fetch_and_display_pdb(pdb_id):
139
-
140
  try:
 
 
 
141
  # Fetch the PDB structure from RCSB
142
  pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
143
- response = requests.get(pdb_url)
144
-
145
- if response.status_code != 200:
146
- return "Failed to load PDB structure. Please check the PDB ID."
147
 
148
- pdb_structure = response.text
 
149
 
150
- # Prepare the 3D molecular visualization
151
- visualization = f"""
152
- <div id="container" style="width: 100%; height: 400px; position: relative;"></div>
153
- <script src="https://3dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
154
- <script>
155
- let viewer = $3Dmol.createViewer(document.getElementById("container"));
156
- viewer.addModel(`{pdb_structure}`, "pdb");
157
- viewer.setStyle({{}}, {{"cartoon": {{"color": "spectrum"}}}});
158
- viewer.zoomTo();
159
- viewer.render();
160
- </script>
161
- """
162
- return visualization
163
 
164
  except Exception as e:
165
- return f"Error visualizing PDB: {str(e)}"
 
166
 
167
- def gradio_interface(sequence, pdb_id):
168
-
169
  # Predict binding sites
170
  binding_site_predictions = predict_protein_sequence(sequence)
171
 
172
- # Fetch and visualize PDB structure
173
- pdb_structure_html = fetch_and_display_pdb(pdb_id)
174
 
175
- return binding_site_predictions, pdb_structure_html
176
 
177
  # Create Gradio interface
178
- interface = gr.Interface(
179
- fn=gradio_interface,
180
- inputs=[
181
- gr.Textbox(lines=2, placeholder="Enter protein sequence here...", label="Protein Sequence"),
182
- gr.Textbox(lines=1, placeholder="Enter PDB ID here...", label="PDB ID for 3D Visualization")
183
- ],
184
- outputs=[
185
- gr.Textbox(label="Binding Site Predictions"),
186
- gr.HTML(label="3D Molecular Viewer")
187
- ],
188
- title="Protein Binding Site Prediction and 3D Structure Viewer",
189
- description="Input a protein sequence to predict binding sites and view the protein structure in 3D using its PDB ID.",
190
- )
191
-
192
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  import requests
39
 
40
+ from gradio_molecule3d import Molecule3D
41
 
42
  #import peft
43
  #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
 
46
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
47
  max_length = 1500
48
 
49
+ # Default representations for molecule rendering
50
+ reps = [
51
+ {
52
+ "model": 0,
53
+ "chain": "",
54
+ "resname": "",
55
+ "style": "cartoon",
56
+ "color": "spectrum",
57
+ "residue_range": "",
58
+ "around": 0,
59
+ "byres": False,
60
+ "visible": True
61
+ }
62
+ ]
63
+
64
  # Load model and move to device
65
  model, tokenizer = load_model(checkpoint, max_length)
66
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
68
  model.eval()
69
 
70
  def create_dataset(tokenizer, seqs, labels, checkpoint):
 
71
  tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
72
  dataset = Dataset.from_dict(tokenized)
73
 
 
82
  return dataset
83
 
84
  def convert_predictions(input_logits):
 
85
  all_probs = []
86
  for logits in input_logits:
87
  logits = logits.reshape(-1, 2)
 
91
  return np.concatenate(all_probs)
92
 
93
  def normalize_scores(scores):
 
94
  min_score = np.min(scores)
95
  max_score = np.max(scores)
96
  return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
97
 
98
  def predict_protein_sequence(test_one_letter_sequence):
 
99
  # Sanitize input sequence
100
  test_one_letter_sequence = test_one_letter_sequence.replace("O", "X") \
101
  .replace("B", "X").replace("U", "X") \
 
146
 
147
  return result_str
148
 
149
+ def fetch_pdb(pdb_id):
 
150
  try:
151
+ # Create a directory to store PDB files if it doesn't exist
152
+ os.makedirs('pdb_files', exist_ok=True)
153
+
154
  # Fetch the PDB structure from RCSB
155
  pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
156
+ pdb_path = f'pdb_files/{pdb_id}.pdb'
 
 
 
157
 
158
+ # Download the file
159
+ response = requests.get(pdb_url)
160
 
161
+ if response.status_code == 200:
162
+ with open(pdb_path, 'wb') as f:
163
+ f.write(response.content)
164
+ return pdb_path
165
+ else:
166
+ return None
 
 
 
 
 
 
 
167
 
168
  except Exception as e:
169
+ print(f"Error fetching PDB: {e}")
170
+ return None
171
 
172
+ def process_input(sequence, pdb_id):
 
173
  # Predict binding sites
174
  binding_site_predictions = predict_protein_sequence(sequence)
175
 
176
+ # Fetch PDB file
177
+ pdb_path = fetch_pdb(pdb_id)
178
 
179
+ return binding_site_predictions, pdb_path
180
 
181
  # Create Gradio interface
182
+ with gr.Blocks() as demo:
183
+ gr.Markdown("# Protein Binding Site Prediction")
184
+
185
+ with gr.Row():
186
+ with gr.Column():
187
+ # Sequence input
188
+ sequence_input = gr.Textbox(
189
+ lines=2,
190
+ placeholder="Enter protein sequence here...",
191
+ label="Protein Sequence"
192
+ )
193
+
194
+ # PDB ID input
195
+ pdb_input = gr.Textbox(
196
+ lines=1,
197
+ placeholder="Enter PDB ID here...",
198
+ label="PDB ID for 3D Visualization"
199
+ )
200
+
201
+ # Predict button
202
+ predict_btn = gr.Button("Predict Binding Sites")
203
+
204
+ with gr.Column():
205
+ # Binding site predictions output
206
+ predictions_output = gr.Textbox(
207
+ label="Binding Site Predictions"
208
+ )
209
+
210
+ # 3D Molecule visualization
211
+ molecule_output = Molecule3D(
212
+ label="Protein Structure",
213
+ reps=reps
214
+ )
215
+
216
+ # Prediction logic
217
+ predict_btn.click(
218
+ process_input,
219
+ inputs=[sequence_input, pdb_input],
220
+ outputs=[predictions_output, molecule_output]
221
+ )
222
+
223
+ # Add some example inputs
224
+ gr.Markdown("## Examples")
225
+ gr.Examples(
226
+ examples=[
227
+ ["MKVLWAALLVTFLAGCQAKVEQAVETEPEPELRQQTEWQSGQRWELALGRFWDYLRWVQTLSEQVQEELLSSQVTQELRALMDETMKELKAYKSELEEQLTPVAEETRARLSKELQAAQARLGADMEDVCGRLVQYRGEVQAMLGQSTEELRVRLASHLRKLRKRLLRDADDLQKRLAVYQAGAREGAERGLSAIRERLGPLVEQGRVRAATVGSLAGQPLQERAQAWGERLRARMEEMGSRTRDRLDEVKEQVAEVRAKLEEQAQQRL", "1ABC"],
228
+ ],
229
+ inputs=[sequence_input, pdb_input],
230
+ outputs=[predictions_output, molecule_output]
231
+ )
232
+
233
+ demo.launch()
.ipynb_checkpoints/requirements-checkpoint.txt CHANGED
@@ -9,4 +9,4 @@ scikit-learn>=0.24.0
9
  sentencepiece
10
  huggingface_hub>=0.15.0
11
  requests
12
- py3Dmol
 
9
  sentencepiece
10
  huggingface_hub>=0.15.0
11
  requests
12
+ gradio_molecule3d
app.py CHANGED
@@ -37,7 +37,7 @@ from scipy.special import expit
37
 
38
  import requests
39
 
40
- import py3Dmol
41
 
42
  #import peft
43
  #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
@@ -46,6 +46,21 @@ import py3Dmol
46
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
47
  max_length = 1500
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # Load model and move to device
50
  model, tokenizer = load_model(checkpoint, max_length)
51
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -53,7 +68,6 @@ model.to(device)
53
  model.eval()
54
 
55
  def create_dataset(tokenizer, seqs, labels, checkpoint):
56
-
57
  tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
58
  dataset = Dataset.from_dict(tokenized)
59
 
@@ -68,7 +82,6 @@ def create_dataset(tokenizer, seqs, labels, checkpoint):
68
  return dataset
69
 
70
  def convert_predictions(input_logits):
71
-
72
  all_probs = []
73
  for logits in input_logits:
74
  logits = logits.reshape(-1, 2)
@@ -78,13 +91,11 @@ def convert_predictions(input_logits):
78
  return np.concatenate(all_probs)
79
 
80
  def normalize_scores(scores):
81
-
82
  min_score = np.min(scores)
83
  max_score = np.max(scores)
84
  return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
85
 
86
  def predict_protein_sequence(test_one_letter_sequence):
87
-
88
  # Sanitize input sequence
89
  test_one_letter_sequence = test_one_letter_sequence.replace("O", "X") \
90
  .replace("B", "X").replace("U", "X") \
@@ -135,58 +146,88 @@ def predict_protein_sequence(test_one_letter_sequence):
135
 
136
  return result_str
137
 
138
- def fetch_and_display_pdb(pdb_id):
139
-
140
  try:
 
 
 
141
  # Fetch the PDB structure from RCSB
142
  pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
143
- response = requests.get(pdb_url)
144
-
145
- if response.status_code != 200:
146
- return "Failed to load PDB structure. Please check the PDB ID."
147
 
148
- pdb_structure = response.text
 
149
 
150
- # Prepare the 3D molecular visualization
151
- visualization = f"""
152
- <div id="container" style="width: 100%; height: 400px; position: relative;"></div>
153
- <script src="https://3dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
154
- <script>
155
- let viewer = $3Dmol.createViewer(document.getElementById("container"));
156
- viewer.addModel(`{pdb_structure}`, "pdb");
157
- viewer.setStyle({{}}, {{"cartoon": {{"color": "spectrum"}}}});
158
- viewer.zoomTo();
159
- viewer.render();
160
- </script>
161
- """
162
- return visualization
163
 
164
  except Exception as e:
165
- return f"Error visualizing PDB: {str(e)}"
 
166
 
167
- def gradio_interface(sequence, pdb_id):
168
-
169
  # Predict binding sites
170
  binding_site_predictions = predict_protein_sequence(sequence)
171
 
172
- # Fetch and visualize PDB structure
173
- pdb_structure_html = fetch_and_display_pdb(pdb_id)
174
 
175
- return binding_site_predictions, pdb_structure_html
176
 
177
  # Create Gradio interface
178
- interface = gr.Interface(
179
- fn=gradio_interface,
180
- inputs=[
181
- gr.Textbox(lines=2, placeholder="Enter protein sequence here...", label="Protein Sequence"),
182
- gr.Textbox(lines=1, placeholder="Enter PDB ID here...", label="PDB ID for 3D Visualization")
183
- ],
184
- outputs=[
185
- gr.Textbox(label="Binding Site Predictions"),
186
- gr.HTML(label="3D Molecular Viewer")
187
- ],
188
- title="Protein Binding Site Prediction and 3D Structure Viewer",
189
- description="Input a protein sequence to predict binding sites and view the protein structure in 3D using its PDB ID.",
190
- )
191
-
192
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  import requests
39
 
40
+ from gradio_molecule3d import Molecule3D
41
 
42
  #import peft
43
  #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
 
46
  checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
47
  max_length = 1500
48
 
49
+ # Default representations for molecule rendering
50
+ reps = [
51
+ {
52
+ "model": 0,
53
+ "chain": "",
54
+ "resname": "",
55
+ "style": "cartoon",
56
+ "color": "spectrum",
57
+ "residue_range": "",
58
+ "around": 0,
59
+ "byres": False,
60
+ "visible": True
61
+ }
62
+ ]
63
+
64
  # Load model and move to device
65
  model, tokenizer = load_model(checkpoint, max_length)
66
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
68
  model.eval()
69
 
70
  def create_dataset(tokenizer, seqs, labels, checkpoint):
 
71
  tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
72
  dataset = Dataset.from_dict(tokenized)
73
 
 
82
  return dataset
83
 
84
  def convert_predictions(input_logits):
 
85
  all_probs = []
86
  for logits in input_logits:
87
  logits = logits.reshape(-1, 2)
 
91
  return np.concatenate(all_probs)
92
 
93
  def normalize_scores(scores):
 
94
  min_score = np.min(scores)
95
  max_score = np.max(scores)
96
  return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
97
 
98
  def predict_protein_sequence(test_one_letter_sequence):
 
99
  # Sanitize input sequence
100
  test_one_letter_sequence = test_one_letter_sequence.replace("O", "X") \
101
  .replace("B", "X").replace("U", "X") \
 
146
 
147
  return result_str
148
 
149
+ def fetch_pdb(pdb_id):
 
150
  try:
151
+ # Create a directory to store PDB files if it doesn't exist
152
+ os.makedirs('pdb_files', exist_ok=True)
153
+
154
  # Fetch the PDB structure from RCSB
155
  pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
156
+ pdb_path = f'pdb_files/{pdb_id}.pdb'
 
 
 
157
 
158
+ # Download the file
159
+ response = requests.get(pdb_url)
160
 
161
+ if response.status_code == 200:
162
+ with open(pdb_path, 'wb') as f:
163
+ f.write(response.content)
164
+ return pdb_path
165
+ else:
166
+ return None
 
 
 
 
 
 
 
167
 
168
  except Exception as e:
169
+ print(f"Error fetching PDB: {e}")
170
+ return None
171
 
172
+ def process_input(sequence, pdb_id):
 
173
  # Predict binding sites
174
  binding_site_predictions = predict_protein_sequence(sequence)
175
 
176
+ # Fetch PDB file
177
+ pdb_path = fetch_pdb(pdb_id)
178
 
179
+ return binding_site_predictions, pdb_path
180
 
181
  # Create Gradio interface
182
+ with gr.Blocks() as demo:
183
+ gr.Markdown("# Protein Binding Site Prediction")
184
+
185
+ with gr.Row():
186
+ with gr.Column():
187
+ # Sequence input
188
+ sequence_input = gr.Textbox(
189
+ lines=2,
190
+ placeholder="Enter protein sequence here...",
191
+ label="Protein Sequence"
192
+ )
193
+
194
+ # PDB ID input
195
+ pdb_input = gr.Textbox(
196
+ lines=1,
197
+ placeholder="Enter PDB ID here...",
198
+ label="PDB ID for 3D Visualization"
199
+ )
200
+
201
+ # Predict button
202
+ predict_btn = gr.Button("Predict Binding Sites")
203
+
204
+ with gr.Column():
205
+ # Binding site predictions output
206
+ predictions_output = gr.Textbox(
207
+ label="Binding Site Predictions"
208
+ )
209
+
210
+ # 3D Molecule visualization
211
+ molecule_output = Molecule3D(
212
+ label="Protein Structure",
213
+ reps=reps
214
+ )
215
+
216
+ # Prediction logic
217
+ predict_btn.click(
218
+ process_input,
219
+ inputs=[sequence_input, pdb_input],
220
+ outputs=[predictions_output, molecule_output]
221
+ )
222
+
223
+ # Add some example inputs
224
+ gr.Markdown("## Examples")
225
+ gr.Examples(
226
+ examples=[
227
+ ["MKVLWAALLVTFLAGCQAKVEQAVETEPEPELRQQTEWQSGQRWELALGRFWDYLRWVQTLSEQVQEELLSSQVTQELRALMDETMKELKAYKSELEEQLTPVAEETRARLSKELQAAQARLGADMEDVCGRLVQYRGEVQAMLGQSTEELRVRLASHLRKLRKRLLRDADDLQKRLAVYQAGAREGAERGLSAIRERLGPLVEQGRVRAATVGSLAGQPLQERAQAWGERLRARMEEMGSRTRDRLDEVKEQVAEVRAKLEEQAQQRL", "1ABC"],
228
+ ],
229
+ inputs=[sequence_input, pdb_input],
230
+ outputs=[predictions_output, molecule_output]
231
+ )
232
+
233
+ demo.launch()
requirements.txt CHANGED
@@ -9,4 +9,4 @@ scikit-learn>=0.24.0
9
  sentencepiece
10
  huggingface_hub>=0.15.0
11
  requests
12
- py3Dmol
 
9
  sentencepiece
10
  huggingface_hub>=0.15.0
11
  requests
12
+ gradio_molecule3d