ThorbenF commited on
Commit
a6b7cf0
·
1 Parent(s): 0c18e19
Files changed (2) hide show
  1. .ipynb_checkpoints/app-checkpoint.py +75 -86
  2. app.py +75 -86
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -42,19 +42,22 @@ import py3Dmol
42
  #import peft
43
  #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
44
 
45
- checkpoint='ThorbenF/prot_t5_xl_uniref50'
46
- max_length=1500
 
47
 
48
- model, tokenizer = load_model(checkpoint,max_length)
 
49
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
50
  model.to(device)
51
  model.eval()
52
 
53
- def create_dataset(tokenizer,seqs,labels,checkpoint):
54
 
55
  tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
56
  dataset = Dataset.from_dict(tokenized)
57
 
 
58
  if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
59
  labels = [l[:max_length-2] for l in labels]
60
  else:
@@ -63,128 +66,115 @@ def create_dataset(tokenizer,seqs,labels,checkpoint):
63
  dataset = dataset.add_column("labels", labels)
64
 
65
  return dataset
66
-
67
  def convert_predictions(input_logits):
 
68
  all_probs = []
69
  for logits in input_logits:
70
  logits = logits.reshape(-1, 2)
71
-
72
- # Mask out irrelevant regions
73
- # Compute probabilities for class 1
74
  probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
75
-
76
  all_probs.append(probabilities_class1)
77
 
78
  return np.concatenate(all_probs)
79
 
80
  def normalize_scores(scores):
81
- min_score = np.min(scores)
82
- max_score = np.max(scores)
83
- return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
84
 
 
 
 
 
85
  def predict_protein_sequence(test_one_letter_sequence):
86
- dummy_labels=[np.zeros(len(test_one_letter_sequence))]
87
- # Replace uncommon amino acids with "X"
88
- test_one_letter_sequence = test_one_letter_sequence.replace("O", "X").replace("B", "X").replace("U", "X").replace("Z", "X").replace("J", "X")
89
 
90
- # Add spaces between each amino acid for ProtT5 and ProstT5 models
 
 
 
 
 
91
  if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
92
  test_one_letter_sequence = " ".join(test_one_letter_sequence)
93
 
94
- # Add <AA2fold> for ProstT5 model input format
95
  if "ProstT5" in checkpoint:
96
  test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence
97
-
98
- test_dataset=create_dataset(tokenizer,[test_one_letter_sequence],dummy_labels,checkpoint)
99
 
100
- if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
101
- data_collator = DataCollatorForTokenClassificationESM(tokenizer)
102
- else:
103
- data_collator = DataCollatorForTokenClassification(tokenizer)
104
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
106
 
 
107
  for batch in test_loader:
108
  input_ids = batch['input_ids'].to(device)
109
  attention_mask = batch['attention_mask'].to(device)
110
- labels = batch['labels'] # Ensure to get labels from batch
111
-
112
- outputs = model(input_ids, attention_mask=attention_mask)
113
- logits = outputs.logits.detach().cpu().numpy()
114
 
115
- logits = logits[:, :-1] #remove for prot_t5 the last element, because it is a special token
116
- logits=convert_predictions(logits)
 
117
 
 
118
  normalized_scores = normalize_scores(logits)
119
  test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
120
 
121
  result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(test_one_letter_sequence, normalized_scores)])
122
-
123
 
124
  return result_str
125
 
126
-
127
- #interface = gr.Interface(
128
- # fn=predict_protein_sequence,
129
- # inputs=gr.Textbox(lines=2, placeholder="Enter protein sequence here..."),
130
- # outputs=gr.Textbox(), #gr.JSON(), # Use gr.JSON() for list or array-like outputs
131
- # title="Protein sequence - Binding site prediction",
132
- # description="Enter a protein sequence to predict its possible binding sites.",
133
- #)
134
-
135
- # Launch the app
136
- #interface.launch()
137
-
138
-
139
  def fetch_and_display_pdb(pdb_id):
140
- # Construct the PDB URL
141
- pdb_url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
142
-
143
- # Try fetching the PDB file
144
- response = requests.get(pdb_url)
145
- if response.status_code != 200:
146
- return "Failed to fetch PDB file"
147
-
148
- # Get the structure content as text
149
- structure_text = response.text
150
-
151
- # Create the HTML content with embedded 3Dmol.js
152
- html_content = f"""
153
- <html>
154
- <head>
155
- <script src="https://3Dmol.js.org/build/3Dmol-min.js"></script>
156
- <style>
157
- #viewer {{
158
- width: 800px;
159
- height: 600px;
160
- }}
161
- </style>
162
- </head>
163
- <body>
164
- <div id="viewer"></div>
165
- <script>
166
- const viewer = $3Dmol.createViewer("viewer", {{ backgroundColor: "white" }});
167
- viewer.addModel(`{structure_text}`, "pdb");
168
- viewer.setStyle({}, {{ cartoon: {{ color: "spectrum" }} }});
169
- viewer.zoomTo();
170
- viewer.render();
171
- </script>
172
- </body>
173
- </html>
174
- """
175
- return html_content
176
-
177
-
178
- # Define the Gradio interface
179
  def gradio_interface(sequence, pdb_id):
180
- # Call the prediction function
 
181
  binding_site_predictions = predict_protein_sequence(sequence)
182
 
183
- # Call the PDB structure visualization function
184
  pdb_structure_html = fetch_and_display_pdb(pdb_id)
185
 
186
  return binding_site_predictions, pdb_structure_html
187
 
 
188
  interface = gr.Interface(
189
  fn=gradio_interface,
190
  inputs=[
@@ -193,11 +183,10 @@ interface = gr.Interface(
193
  ],
194
  outputs=[
195
  gr.Textbox(label="Binding Site Predictions"),
196
- gr.HTML(label="3Dmol Viewer") # HTML output to render the 3Dmol viewer
197
  ],
198
  title="Protein Binding Site Prediction and 3D Structure Viewer",
199
  description="Input a protein sequence to predict binding sites and view the protein structure in 3D using its PDB ID.",
200
  )
201
 
202
- # Launch the Gradio app
203
  interface.launch()
 
42
  #import peft
43
  #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
44
 
45
+ # Configuration
46
+ checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
47
+ max_length = 1500
48
 
49
+ # Load model and move to device
50
+ model, tokenizer = load_model(checkpoint, max_length)
51
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
52
  model.to(device)
53
  model.eval()
54
 
55
+ def create_dataset(tokenizer, seqs, labels, checkpoint):
56
 
57
  tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
58
  dataset = Dataset.from_dict(tokenized)
59
 
60
+ # Adjust labels based on checkpoint
61
  if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
62
  labels = [l[:max_length-2] for l in labels]
63
  else:
 
66
  dataset = dataset.add_column("labels", labels)
67
 
68
  return dataset
69
+
70
  def convert_predictions(input_logits):
71
+
72
  all_probs = []
73
  for logits in input_logits:
74
  logits = logits.reshape(-1, 2)
 
 
 
75
  probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
 
76
  all_probs.append(probabilities_class1)
77
 
78
  return np.concatenate(all_probs)
79
 
80
  def normalize_scores(scores):
 
 
 
81
 
82
+ min_score = np.min(scores)
83
+ max_score = np.max(scores)
84
+ return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
85
+
86
  def predict_protein_sequence(test_one_letter_sequence):
 
 
 
87
 
88
+ # Sanitize input sequence
89
+ test_one_letter_sequence = test_one_letter_sequence.replace("O", "X") \
90
+ .replace("B", "X").replace("U", "X") \
91
+ .replace("Z", "X").replace("J", "X")
92
+
93
+ # Prepare sequence for different model types
94
  if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
95
  test_one_letter_sequence = " ".join(test_one_letter_sequence)
96
 
 
97
  if "ProstT5" in checkpoint:
98
  test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence
 
 
99
 
100
+ # Create dummy labels
101
+ dummy_labels = [np.zeros(len(test_one_letter_sequence))]
 
 
102
 
103
+ # Create dataset
104
+ test_dataset = create_dataset(tokenizer,
105
+ [test_one_letter_sequence],
106
+ dummy_labels,
107
+ checkpoint)
108
+
109
+ # Select appropriate data collator
110
+ data_collator = (DataCollatorForTokenClassification(tokenizer)
111
+ if "esm" not in checkpoint and "ProstT5" not in checkpoint
112
+ else DataCollatorForTokenClassification(tokenizer))
113
+
114
+ # Create data loader
115
  test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
116
 
117
+ # Predict
118
  for batch in test_loader:
119
  input_ids = batch['input_ids'].to(device)
120
  attention_mask = batch['attention_mask'].to(device)
121
+
122
+ with torch.no_grad():
123
+ outputs = model(input_ids, attention_mask=attention_mask)
124
+ logits = outputs.logits.detach().cpu().numpy()
125
 
126
+ # Process logits
127
+ logits = logits[:, :-1] # Remove last element for prot_t5
128
+ logits = convert_predictions(logits)
129
 
130
+ # Normalize and format results
131
  normalized_scores = normalize_scores(logits)
132
  test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
133
 
134
  result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(test_one_letter_sequence, normalized_scores)])
 
135
 
136
  return result_str
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def fetch_and_display_pdb(pdb_id):
139
+
140
+ try:
141
+ # Fetch the PDB structure from RCSB
142
+ pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
143
+ response = requests.get(pdb_url)
144
+
145
+ if response.status_code != 200:
146
+ return "Failed to load PDB structure. Please check the PDB ID."
147
+
148
+ pdb_structure = response.text
149
+
150
+ # Prepare the 3D molecular visualization
151
+ visualization = f"""
152
+ <div id="container" style="width: 100%; height: 400px; position: relative;"></div>
153
+ <script src="https://3dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
154
+ <script>
155
+ let viewer = $3Dmol.createViewer(document.getElementById("container"));
156
+ viewer.addModel(`{pdb_structure}`, "pdb");
157
+ viewer.setStyle({{}}, {{"cartoon": {{"color": "spectrum"}}}});
158
+ viewer.zoomTo();
159
+ viewer.render();
160
+ </script>
161
+ """
162
+ return visualization
163
+
164
+ except Exception as e:
165
+ return f"Error visualizing PDB: {str(e)}"
166
+
 
 
 
 
 
 
 
 
 
 
 
167
  def gradio_interface(sequence, pdb_id):
168
+
169
+ # Predict binding sites
170
  binding_site_predictions = predict_protein_sequence(sequence)
171
 
172
+ # Fetch and visualize PDB structure
173
  pdb_structure_html = fetch_and_display_pdb(pdb_id)
174
 
175
  return binding_site_predictions, pdb_structure_html
176
 
177
+ # Create Gradio interface
178
  interface = gr.Interface(
179
  fn=gradio_interface,
180
  inputs=[
 
183
  ],
184
  outputs=[
185
  gr.Textbox(label="Binding Site Predictions"),
186
+ gr.HTML(label="3D Molecular Viewer")
187
  ],
188
  title="Protein Binding Site Prediction and 3D Structure Viewer",
189
  description="Input a protein sequence to predict binding sites and view the protein structure in 3D using its PDB ID.",
190
  )
191
 
 
192
  interface.launch()
app.py CHANGED
@@ -42,19 +42,22 @@ import py3Dmol
42
  #import peft
43
  #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
44
 
45
- checkpoint='ThorbenF/prot_t5_xl_uniref50'
46
- max_length=1500
 
47
 
48
- model, tokenizer = load_model(checkpoint,max_length)
 
49
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
50
  model.to(device)
51
  model.eval()
52
 
53
- def create_dataset(tokenizer,seqs,labels,checkpoint):
54
 
55
  tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
56
  dataset = Dataset.from_dict(tokenized)
57
 
 
58
  if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
59
  labels = [l[:max_length-2] for l in labels]
60
  else:
@@ -63,128 +66,115 @@ def create_dataset(tokenizer,seqs,labels,checkpoint):
63
  dataset = dataset.add_column("labels", labels)
64
 
65
  return dataset
66
-
67
  def convert_predictions(input_logits):
 
68
  all_probs = []
69
  for logits in input_logits:
70
  logits = logits.reshape(-1, 2)
71
-
72
- # Mask out irrelevant regions
73
- # Compute probabilities for class 1
74
  probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
75
-
76
  all_probs.append(probabilities_class1)
77
 
78
  return np.concatenate(all_probs)
79
 
80
  def normalize_scores(scores):
81
- min_score = np.min(scores)
82
- max_score = np.max(scores)
83
- return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
84
 
 
 
 
 
85
  def predict_protein_sequence(test_one_letter_sequence):
86
- dummy_labels=[np.zeros(len(test_one_letter_sequence))]
87
- # Replace uncommon amino acids with "X"
88
- test_one_letter_sequence = test_one_letter_sequence.replace("O", "X").replace("B", "X").replace("U", "X").replace("Z", "X").replace("J", "X")
89
 
90
- # Add spaces between each amino acid for ProtT5 and ProstT5 models
 
 
 
 
 
91
  if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
92
  test_one_letter_sequence = " ".join(test_one_letter_sequence)
93
 
94
- # Add <AA2fold> for ProstT5 model input format
95
  if "ProstT5" in checkpoint:
96
  test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence
97
-
98
- test_dataset=create_dataset(tokenizer,[test_one_letter_sequence],dummy_labels,checkpoint)
99
 
100
- if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
101
- data_collator = DataCollatorForTokenClassificationESM(tokenizer)
102
- else:
103
- data_collator = DataCollatorForTokenClassification(tokenizer)
104
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
106
 
 
107
  for batch in test_loader:
108
  input_ids = batch['input_ids'].to(device)
109
  attention_mask = batch['attention_mask'].to(device)
110
- labels = batch['labels'] # Ensure to get labels from batch
111
-
112
- outputs = model(input_ids, attention_mask=attention_mask)
113
- logits = outputs.logits.detach().cpu().numpy()
114
 
115
- logits = logits[:, :-1] #remove for prot_t5 the last element, because it is a special token
116
- logits=convert_predictions(logits)
 
117
 
 
118
  normalized_scores = normalize_scores(logits)
119
  test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
120
 
121
  result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(test_one_letter_sequence, normalized_scores)])
122
-
123
 
124
  return result_str
125
 
126
-
127
- #interface = gr.Interface(
128
- # fn=predict_protein_sequence,
129
- # inputs=gr.Textbox(lines=2, placeholder="Enter protein sequence here..."),
130
- # outputs=gr.Textbox(), #gr.JSON(), # Use gr.JSON() for list or array-like outputs
131
- # title="Protein sequence - Binding site prediction",
132
- # description="Enter a protein sequence to predict its possible binding sites.",
133
- #)
134
-
135
- # Launch the app
136
- #interface.launch()
137
-
138
-
139
  def fetch_and_display_pdb(pdb_id):
140
- # Construct the PDB URL
141
- pdb_url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
142
-
143
- # Try fetching the PDB file
144
- response = requests.get(pdb_url)
145
- if response.status_code != 200:
146
- return "Failed to fetch PDB file"
147
-
148
- # Get the structure content as text
149
- structure_text = response.text
150
-
151
- # Create the HTML content with embedded 3Dmol.js
152
- html_content = f"""
153
- <html>
154
- <head>
155
- <script src="https://3Dmol.js.org/build/3Dmol-min.js"></script>
156
- <style>
157
- #viewer {{
158
- width: 800px;
159
- height: 600px;
160
- }}
161
- </style>
162
- </head>
163
- <body>
164
- <div id="viewer"></div>
165
- <script>
166
- const viewer = $3Dmol.createViewer("viewer", {{ backgroundColor: "white" }});
167
- viewer.addModel(`{structure_text}`, "pdb");
168
- viewer.setStyle({}, {{ cartoon: {{ color: "spectrum" }} }});
169
- viewer.zoomTo();
170
- viewer.render();
171
- </script>
172
- </body>
173
- </html>
174
- """
175
- return html_content
176
-
177
-
178
- # Define the Gradio interface
179
  def gradio_interface(sequence, pdb_id):
180
- # Call the prediction function
 
181
  binding_site_predictions = predict_protein_sequence(sequence)
182
 
183
- # Call the PDB structure visualization function
184
  pdb_structure_html = fetch_and_display_pdb(pdb_id)
185
 
186
  return binding_site_predictions, pdb_structure_html
187
 
 
188
  interface = gr.Interface(
189
  fn=gradio_interface,
190
  inputs=[
@@ -193,11 +183,10 @@ interface = gr.Interface(
193
  ],
194
  outputs=[
195
  gr.Textbox(label="Binding Site Predictions"),
196
- gr.HTML(label="3Dmol Viewer") # HTML output to render the 3Dmol viewer
197
  ],
198
  title="Protein Binding Site Prediction and 3D Structure Viewer",
199
  description="Input a protein sequence to predict binding sites and view the protein structure in 3D using its PDB ID.",
200
  )
201
 
202
- # Launch the Gradio app
203
  interface.launch()
 
42
  #import peft
43
  #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
44
 
45
+ # Configuration
46
+ checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
47
+ max_length = 1500
48
 
49
+ # Load model and move to device
50
+ model, tokenizer = load_model(checkpoint, max_length)
51
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
52
  model.to(device)
53
  model.eval()
54
 
55
+ def create_dataset(tokenizer, seqs, labels, checkpoint):
56
 
57
  tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
58
  dataset = Dataset.from_dict(tokenized)
59
 
60
+ # Adjust labels based on checkpoint
61
  if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
62
  labels = [l[:max_length-2] for l in labels]
63
  else:
 
66
  dataset = dataset.add_column("labels", labels)
67
 
68
  return dataset
69
+
70
  def convert_predictions(input_logits):
71
+
72
  all_probs = []
73
  for logits in input_logits:
74
  logits = logits.reshape(-1, 2)
 
 
 
75
  probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
 
76
  all_probs.append(probabilities_class1)
77
 
78
  return np.concatenate(all_probs)
79
 
80
  def normalize_scores(scores):
 
 
 
81
 
82
+ min_score = np.min(scores)
83
+ max_score = np.max(scores)
84
+ return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
85
+
86
  def predict_protein_sequence(test_one_letter_sequence):
 
 
 
87
 
88
+ # Sanitize input sequence
89
+ test_one_letter_sequence = test_one_letter_sequence.replace("O", "X") \
90
+ .replace("B", "X").replace("U", "X") \
91
+ .replace("Z", "X").replace("J", "X")
92
+
93
+ # Prepare sequence for different model types
94
  if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
95
  test_one_letter_sequence = " ".join(test_one_letter_sequence)
96
 
 
97
  if "ProstT5" in checkpoint:
98
  test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence
 
 
99
 
100
+ # Create dummy labels
101
+ dummy_labels = [np.zeros(len(test_one_letter_sequence))]
 
 
102
 
103
+ # Create dataset
104
+ test_dataset = create_dataset(tokenizer,
105
+ [test_one_letter_sequence],
106
+ dummy_labels,
107
+ checkpoint)
108
+
109
+ # Select appropriate data collator
110
+ data_collator = (DataCollatorForTokenClassification(tokenizer)
111
+ if "esm" not in checkpoint and "ProstT5" not in checkpoint
112
+ else DataCollatorForTokenClassification(tokenizer))
113
+
114
+ # Create data loader
115
  test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
116
 
117
+ # Predict
118
  for batch in test_loader:
119
  input_ids = batch['input_ids'].to(device)
120
  attention_mask = batch['attention_mask'].to(device)
121
+
122
+ with torch.no_grad():
123
+ outputs = model(input_ids, attention_mask=attention_mask)
124
+ logits = outputs.logits.detach().cpu().numpy()
125
 
126
+ # Process logits
127
+ logits = logits[:, :-1] # Remove last element for prot_t5
128
+ logits = convert_predictions(logits)
129
 
130
+ # Normalize and format results
131
  normalized_scores = normalize_scores(logits)
132
  test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
133
 
134
  result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(test_one_letter_sequence, normalized_scores)])
 
135
 
136
  return result_str
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def fetch_and_display_pdb(pdb_id):
139
+
140
+ try:
141
+ # Fetch the PDB structure from RCSB
142
+ pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
143
+ response = requests.get(pdb_url)
144
+
145
+ if response.status_code != 200:
146
+ return "Failed to load PDB structure. Please check the PDB ID."
147
+
148
+ pdb_structure = response.text
149
+
150
+ # Prepare the 3D molecular visualization
151
+ visualization = f"""
152
+ <div id="container" style="width: 100%; height: 400px; position: relative;"></div>
153
+ <script src="https://3dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
154
+ <script>
155
+ let viewer = $3Dmol.createViewer(document.getElementById("container"));
156
+ viewer.addModel(`{pdb_structure}`, "pdb");
157
+ viewer.setStyle({{}}, {{"cartoon": {{"color": "spectrum"}}}});
158
+ viewer.zoomTo();
159
+ viewer.render();
160
+ </script>
161
+ """
162
+ return visualization
163
+
164
+ except Exception as e:
165
+ return f"Error visualizing PDB: {str(e)}"
166
+
 
 
 
 
 
 
 
 
 
 
 
167
  def gradio_interface(sequence, pdb_id):
168
+
169
+ # Predict binding sites
170
  binding_site_predictions = predict_protein_sequence(sequence)
171
 
172
+ # Fetch and visualize PDB structure
173
  pdb_structure_html = fetch_and_display_pdb(pdb_id)
174
 
175
  return binding_site_predictions, pdb_structure_html
176
 
177
+ # Create Gradio interface
178
  interface = gr.Interface(
179
  fn=gradio_interface,
180
  inputs=[
 
183
  ],
184
  outputs=[
185
  gr.Textbox(label="Binding Site Predictions"),
186
+ gr.HTML(label="3D Molecular Viewer")
187
  ],
188
  title="Protein Binding Site Prediction and 3D Structure Viewer",
189
  description="Input a protein sequence to predict binding sites and view the protein structure in 3D using its PDB ID.",
190
  )
191
 
 
192
  interface.launch()