aksell committed
Commit 2e81150 · 1 Parent(s): 6d3f484

Fix removal of special tokens for ZymCTRL

Files changed (1)
  1. hexviz/attention.py +29 -26
hexviz/attention.py CHANGED
@@ -91,27 +91,15 @@ def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
     return cleaned_sequence, None
 
 
-def remove_special_tokens_and_periods(attentions_tuple, sequence, tokenizer):
-    tokens = tokenizer.tokenize(sequence)
-
-    indices_to_remove = [
-        i for i, token in enumerate(tokens) if token in {".", "<sep>", "<start>", "<end>", "<pad>"}
-    ]
-
-    new_attentions = []
-
-    for attentions in attentions_tuple:
-        # Remove rows and columns corresponding to special tokens and periods
-        for idx in sorted(indices_to_remove, reverse=True):
-            attentions = torch.cat((attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2)
-            attentions = torch.cat(
-                (attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3
-            )
-
-        # Append the modified attentions tensor to the new_attentions list
-        new_attentions.append(attentions)
-
-    return new_attentions, [token for i, token in enumerate(tokens) if i not in indices_to_remove]
+def remove_tokens(attentions, tokens, tokens_to_remove):
+    indices_to_remove = [i for i, token in enumerate(tokens) if token in tokens_to_remove]
+
+    # Remove rows and columns corresponding to special tokens and periods
+    for idx in sorted(indices_to_remove, reverse=True):
+        attentions = torch.cat((attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2)
+        attentions = torch.cat((attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3)
+
+    return attentions
 
 
 @st.cache
@@ -131,12 +119,17 @@ def get_attention(
         tokenizer, model = get_tape_bert()
         token_idxs = tokenizer.encode(sequence).tolist()
         inputs = torch.tensor(token_idxs).unsqueeze(0)
+
         with torch.no_grad():
             attentions = model(inputs)[-1]
+
+        tokenized_sequence = tokenizer.convert_ids_to_tokens(token_idxs)
+
         if remove_special_tokens:
             # Remove attention from <CLS> (first) and <SEP> (last) token
             attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
-            inputs = inputs[:, 1:-1]
+            tokenized_sequence = tokenized_sequence[1:-1]
+
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     elif model_type == ModelType.ZymCTRL:
@@ -151,9 +144,18 @@ def get_attention(
         outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
         attentions = outputs.attentions
 
+        tokenized_sequence = tokenizer.convert_ids_to_tokens(tokenizer.encode(sequence))
+
         if ec_number and remove_special_tokens:
             # Remove attention to special tokens and periods separating EC number components
-            attentions, inputs = remove_special_tokens_and_periods(attentions, sequence, tokenizer)
+            tokens_to_remove = [".", "<sep>", "<start>", "<end>", "<pad>"]
+            attentions = [
+                remove_tokens(attention, tokenized_sequence, tokens_to_remove)
+                for attention in attentions
+            ]
+            tokenized_sequence = [
+                token for token in tokenized_sequence if token not in tokens_to_remove
+            ]
 
         # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
         attention_squeezed = [torch.squeeze(attention) for attention in attentions]
@@ -169,10 +171,12 @@ def get_attention(
         with torch.no_grad():
             attentions = model(inputs, output_attentions=True)[-1]
 
+        tokenized_sequence = tokenizer.convert_ids_to_tokens(token_idxs)
         if remove_special_tokens:
             # Remove attention from <CLS> (first) and <SEP> (last) token
             attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
-            inputs = inputs[:, 1:-1]
+            tokenized_sequence = tokenized_sequence[1:-1]
+
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     elif model_type == ModelType.PROT_T5:
@@ -183,19 +187,18 @@ def get_attention(
         with torch.no_grad():
             attentions = model(inputs, output_attentions=True)[-1]
 
+        tokenized_sequence = tokenizer.convert_ids_to_tokens(token_idxs)
         if remove_special_tokens:
             # Remove attention to </s> (last) token
             attentions = [attention[:, :, :-1, :-1] for attention in attentions]
-            inputs = inputs[:, :-1]
+            tokenized_sequence = tokenized_sequence[:-1]
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     else:
         raise ValueError(f"Model {model_type} not supported")
 
-    input_ids_list = inputs.squeeze().tolist()
-    tokens = tokenizer.convert_ids_to_tokens(input_ids_list)
     # Transfer to CPU to avoid issues with streamlit caching
-    return attentions.cpu(), tokens
+    return attentions.cpu(), tokenized_sequence
 
 
 def unidirectional_avg_filtered(attention, layer, head, threshold):
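
For reference, a minimal standalone sketch of what the reworked helper does. The remove_tokens body is copied from the diff above; the six-token sequence and the random attention tensor are made-up illustrative inputs, not actual ZymCTRL output:

import torch

# Helper copied from the new hexviz/attention.py above.
def remove_tokens(attentions, tokens, tokens_to_remove):
    indices_to_remove = [i for i, token in enumerate(tokens) if token in tokens_to_remove]
    # Remove rows and columns corresponding to special tokens and periods
    for idx in sorted(indices_to_remove, reverse=True):
        attentions = torch.cat((attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2)
        attentions = torch.cat((attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3)
    return attentions

# Made-up example: 6 tokens, one layer's attention of shape (batch=1, heads=2, 6, 6).
tokens = ["<start>", "1", ".", "1", ".", "A"]
attention = torch.rand(1, 2, 6, 6)
tokens_to_remove = [".", "<sep>", "<start>", "<end>", "<pad>"]

pruned = remove_tokens(attention, tokens, tokens_to_remove)
kept = [t for t in tokens if t not in tokens_to_remove]
print(pruned.shape)  # torch.Size([1, 2, 3, 3]): rows/cols for <start> and both "." dropped
print(kept)          # ['1', '1', 'A'], still index-aligned with the pruned matrix

Deleting rows and columns in descending index order keeps each remaining attention entry aligned with the filtered token list, which is why get_attention can now return tokenized_sequence alongside the attention tensors instead of re-deriving tokens from the pruned inputs.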