aksell committed
Commit 9f086ee · 1 Parent(s): f7e76de

WIP visualize EC number and attention to it


Visualizes EC tags as spheres on the protein structure.
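
The four EC components are placed as spheres stacked outward from the first residue of the chain, along the direction pointing away from the second residue, spaced two radii apart. A minimal sketch of that placement, distilled from the diff below (the residue coordinates here are made-up example values):

import numpy as np

res_1 = np.array([10.0, 5.0, 3.0])  # CA of the chain's first residue (example values)
res_2 = np.array([11.5, 5.0, 3.0])  # CA of the second residue (example values)
radius = 1.0

# Unit vector pointing from residue 2 towards residue 1, i.e. away from the chain
direction = (res_1 - res_2) / np.linalg.norm(res_1 - res_2)

# Four sphere centres, two radii apart; the first sits on residue 1 itself
coordinates = [res_1 + i * 2 * radius * direction for i in range(4)]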

hexviz/attention.py CHANGED
@@ -6,6 +6,7 @@ import streamlit as st
 import torch
 from Bio.PDB import PDBParser, Polypeptide, Structure
 
+from hexviz.ec_number import ECNumber
 from hexviz.models import (
     ModelType,
     get_prot_bert,
@@ -98,11 +99,39 @@ def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
     return cleaned_sequence, None
 
 
+def remove_special_tokens_and_periods(attentions_tuple, sequence, tokenizer):
+    tokens = tokenizer.tokenize(sequence)
+
+    indices_to_remove = [
+        i
+        for i, token in enumerate(tokens)
+        if token in {".", "<sep>", "<start>", "<end>", "<pad>"}
+    ]
+
+    new_attentions = []
+
+    for attentions in attentions_tuple:
+        # Remove rows and columns corresponding to special tokens and periods
+        for idx in sorted(indices_to_remove, reverse=True):
+            attentions = torch.cat(
+                (attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2
+            )
+            attentions = torch.cat(
+                (attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3
+            )
+
+        # Append the modified attentions tensor to the new_attentions list
+        new_attentions.append(attentions)
+
+    return new_attentions
+
+
 @st.cache
 def get_attention(
     sequence: str,
     model_type: ModelType = ModelType.TAPE_BERT,
     remove_special_tokens: bool = True,
+    ec_number: list[ECNumber] = None,
 ):
     """
     Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
@@ -122,6 +151,10 @@ def get_attention(
 
     elif model_type == ModelType.ZymCTRL:
         tokenizer, model = get_zymctrl()
+
+        if ec_number:
+            sequence = f"{'.'.join([ec.number for ec in ec_number])}<sep><start>{sequence}<end><pad>"
+
         inputs = tokenizer(sequence, return_tensors="pt").input_ids.to(device)
         attention_mask = tokenizer(sequence, return_tensors="pt").attention_mask.to(
             device
@@ -133,6 +166,12 @@ def get_attention(
         )
         attentions = outputs.attentions
 
+        if ec_number:
+            # Remove attention to special tokens and periods separating EC number components
+            attentions = remove_special_tokens_and_periods(
+                attentions, sequence, tokenizer
+            )
+
     # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
     attention_squeezed = [torch.squeeze(attention) for attention in attentions]
     # ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
@@ -202,6 +241,7 @@ def get_attention_pairs(
     threshold: int = 0.2,
     model_type: ModelType = ModelType.TAPE_BERT,
     top_n: int = 2,
+    ec_number: list[ECNumber] | None = None,
 ):
     structure = PDBParser().get_structure("pdb", StringIO(pdb_str))
     if chain_ids:
@@ -213,7 +253,9 @@ def get_attention_pairs(
     top_residues = []
     for chain in chains:
         sequence = get_sequence(chain)
-        attention = get_attention(sequence=sequence, model_type=model_type)
+        attention = get_attention(
+            sequence=sequence, model_type=model_type, ec_number=ec_number
+        )
         attention_unidirectional = unidirectional_avg_filtered(
             attention, layer, head, threshold
         )
@@ -222,8 +264,19 @@ def get_attention_pairs(
         residue_attention = {}
         for attn_value, res_1, res_2 in attention_unidirectional:
             try:
-                coord_1 = chain[res_1]["CA"].coord.tolist()
-                coord_2 = chain[res_2]["CA"].coord.tolist()
+                if not ec_number:
+                    coord_1 = chain[res_1]["CA"].coord.tolist()
+                    coord_2 = chain[res_2]["CA"].coord.tolist()
+                else:
+                    if res_1 < 4:
+                        coord_1 = ec_number[res_1].coordinate
+                    else:
+                        coord_1 = chain[res_1 - 4]["CA"].coord.tolist()
+                    if res_2 < 4:
+                        coord_2 = ec_number[res_2].coordinate
+                    else:
+                        coord_2 = chain[res_2 - 4]["CA"].coord.tolist()
+
             except KeyError:
                 continue
 
@@ -236,7 +289,14 @@ def get_attention_pairs(
         )[:top_n]
 
         for res, attn_sum in top_n_residues:
-            coord = chain[res]["CA"].coord.tolist()
+            if not ec_number:
+                coord = chain[res]["CA"].coord.tolist()
+            else:
+                if res < 4:
+                    # Ignore EC tag chars as these can't be labeled
+                    continue
+                else:
+                    coord = chain[res - 4]["CA"].coord.tolist()
             top_residues.append((attn_sum, coord, chain.id, res))
 
     return attention_pairs, top_residues
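
For ZymCTRL the EC prompt adds periods and special tokens to the input, so remove_special_tokens_and_periods drops the matching rows and columns from every attention tensor before the usual squeezing and stacking. A self-contained sketch of the same trimming using index selection instead of the repeated torch.cat above (shapes and indices are illustrative, not taken from the commit):

import torch

attn = torch.rand(1, 4, 10, 10)     # one layer: [batch, n_heads, seq_len, seq_len]
indices_to_remove = [0, 4, 9]       # e.g. positions of "." and special tokens
keep = [i for i in range(attn.shape[-1]) if i not in indices_to_remove]

trimmed = attn[:, :, keep, :][:, :, :, keep]  # drop rows, then columns
print(trimmed.shape)                # torch.Size([1, 4, 7, 7])
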
hexviz/ec_number.py ADDED
@@ -0,0 +1,11 @@
+class ECNumber:
+    def __init__(self, number, coordinate, color, radius):
+        self.number = number
+        self.coordinate = coordinate
+        self.color = color
+        self.radius = radius
+
+    def __str__(self):
+        return (
+            f"(EC: {self.number}, Coordinate: {self.coordinate}, Color: {self.color})"
+        )
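
Example usage, with placeholder coordinates and radius (the real values are computed in 🧬Attention_Visualization.py below):

from hexviz.ec_number import ECNumber

components = "1.2.3.21".split(".")
colors = ["blue", "green", "orange", "red"]
ec_tag = [
    ECNumber(number=num, coordinate=[2.0 * i, 0.0, 0.0], color=color, radius=1)
    for i, (num, color) in enumerate(zip(components, colors))
]
print(ec_tag[0])  # (EC: 1, Coordinate: [0.0, 0.0, 0.0], Color: blue)
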
hexviz/🧬Attention_Visualization.py CHANGED
@@ -1,3 +1,6 @@
+import re
+
+import numpy as np
 import pandas as pd
 import py3Dmol
 import stmol
@@ -9,10 +12,10 @@ from hexviz.attention import (
     get_attention_pairs,
     get_chains,
 )
+from hexviz.config import URL
+from hexviz.ec_number import ECNumber
 from hexviz.models import Model, ModelType
 from hexviz.view import menu_items, select_model, select_pdb, select_protein
-from hexviz.config import URL
-
 
 st.set_page_config(layout="centered", menu_items=menu_items)
 st.title("Attention Visualization on proteins")
@@ -110,15 +113,60 @@ with right:
     )
     head = head_one - 1
 
-ec_class = ""
+ec_number = ""
 if selected_model.name == ModelType.ZymCTRL:
+    st.sidebar.markdown(
+        """
+        ZymCTRL EC number
+        ---
+        """
+    )
     try:
-        ec_class = structure.header["compound"]["1"]["ec"]
+        ec_number = structure.header["compound"]["1"]["ec"]
     except KeyError:
         pass
-    ec_class = st.sidebar.text_input(
-        "Enzyme classification number fetched from PDB", ec_class
-    )
+    ec_number = st.sidebar.text_input("Enzyme Comission number (EC)", ec_number)
+
+    # Validate EC number
+    if not re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", ec_number):
+        st.sidebar.error(
+            "Please enter a valid Enzyme Commission number in the format of 4 integers separated by periods (e.g., 1.2.3.21)"
+        )
+
+    if ec_number:
+        if selected_chains:
+            all_chains = [
+                ch for ch in structure.get_chains() if ch.id in selected_chains
+            ]
+        else:
+            all_chains = list(structure.get_chains())
+        the_chain = all_chains[0]
+        res_1 = the_chain[1]["CA"].coord.tolist()
+        res_2 = the_chain[2]["CA"].coord.tolist()
+
+        # Calculate the vector from res_1 to res_2
+        vector = [res_2[i] - res_1[i] for i in range(3)]
+
+        # Reverse the vector
+        reverse_vector = [-v for v in vector]
+
+        # Normalize the reverse vector
+        reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(
+            reverse_vector
+        )
+        radius = 1
+        coordinates = [
+            [res_1[j] + i * 2 * radius * reverse_vector_normalized[j] for j in range(3)]
+            for i in range(4)
+        ]
+        colors = ["blue", "green", "orange", "red"]
+        EC_numbers = ec_number.split(".")
+        EC_tag = [
+            ECNumber(number=num, coordinate=coord, color=color, radius=radius)
+            for num, coord, color in zip(EC_numbers, coordinates, colors)
+        ]
+        EC_colored = [f":{color}[{EC.number}]" for EC, color in zip(EC_tag, colors)]
+        st.sidebar.write("Visualized as colored spheres: " + ".".join(EC_colored))
 
 
 attention_pairs, top_residues = get_attention_pairs(
@@ -129,6 +177,7 @@ attention_pairs, top_residues = get_attention_pairs(
     threshold=min_attn,
     model_type=selected_model.name,
     top_n=n_highest_resis,
+    ec_number=EC_tag if ec_number else None,
 )
 
 sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
@@ -169,6 +218,15 @@ def get_3dview(pdb):
             dashed=False,
         )
 
+    if selected_model.name == ModelType.ZymCTRL and ec_number:
+        for EC_num in EC_tag:
+            stmol.add_sphere(
+                xyzview,
+                spcenter=EC_num.coordinate,
+                radius=EC_num.radius,
+                spColor=EC_num.color,
+            )
+
     if label_resi:
         for hl_resi in hl_resi_list:
             xyzview.addResLabels(
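
Because the ZymCTRL prompt prepends the four EC components, the residue indices that come back from get_attention are shifted by four when an EC tag is passed, and indices 0-3 refer to the EC tag spheres. A hypothetical helper, not part of the commit, spelling out that mapping:

def index_to_coordinate(idx, chain, ec_tag=None):
    """Map an attention index to a 3D coordinate, honouring the EC tag offset."""
    if ec_tag and idx < len(ec_tag):
        return ec_tag[idx].coordinate  # one of the four EC tag spheres
    offset = len(ec_tag) if ec_tag else 0
    return chain[idx - offset]["CA"].coord.tolist()  # CA atom of the shifted residue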