aksell committed on
Commit
e13695f
·
1 Parent(s): 214c50a

Don't remove special tokens in identify interesting heads

Browse files
hexviz/attention.py CHANGED
@@ -99,7 +99,11 @@ def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
99
 
100
 
101
  @st.cache
102
- def get_attention(sequence: str, model_type: ModelType = ModelType.TAPE_BERT):
 
 
 
 
103
  """
104
  Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
105
  """
@@ -111,8 +115,9 @@ def get_attention(sequence: str, model_type: ModelType = ModelType.TAPE_BERT):
111
  inputs = torch.tensor(token_idxs).unsqueeze(0)
112
  with torch.no_grad():
113
  attentions = model(inputs)[-1]
 
114
  # Remove attention from <CLS> (first) and <SEP> (last) token
115
- attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
116
  attentions = torch.stack([attention.squeeze(0) for attention in attentions])
117
 
118
  elif model_type == ModelType.ZymCTRL:
@@ -141,8 +146,9 @@ def get_attention(sequence: str, model_type: ModelType = ModelType.TAPE_BERT):
141
  inputs = torch.tensor(token_idxs).unsqueeze(0).to(device)
142
  with torch.no_grad():
143
  attentions = model(inputs, output_attentions=True)[-1]
 
144
  # Remove attention from <CLS> (first) and <SEP> (last) token
145
- attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
146
  attentions = torch.stack([attention.squeeze(0) for attention in attentions])
147
 
148
  elif model_type == ModelType.PROT_T5:
@@ -155,8 +161,9 @@ def get_attention(sequence: str, model_type: ModelType = ModelType.TAPE_BERT):
155
  -1
156
  ] # Do you need an attention mask?
157
 
158
- # Remove attention to <pad> (first) and <extra_id_1>, <extra_id_2> (last) tokens
159
- attentions = [attention[:, :, 3:-3, 3:-3] for attention in attentions]
 
160
  attentions = torch.stack([attention.squeeze(0) for attention in attentions])
161
 
162
  else:
 
99
 
100
 
101
  @st.cache
102
+ def get_attention(
103
+ sequence: str,
104
+ model_type: ModelType = ModelType.TAPE_BERT,
105
+ remove_special_tokens: bool = True,
106
+ ):
107
  """
108
  Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
109
  """
 
115
  inputs = torch.tensor(token_idxs).unsqueeze(0)
116
  with torch.no_grad():
117
  attentions = model(inputs)[-1]
118
+ if remove_special_tokens:
119
  # Remove attention from <CLS> (first) and <SEP> (last) token
120
+ attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
121
  attentions = torch.stack([attention.squeeze(0) for attention in attentions])
122
 
123
  elif model_type == ModelType.ZymCTRL:
 
146
  inputs = torch.tensor(token_idxs).unsqueeze(0).to(device)
147
  with torch.no_grad():
148
  attentions = model(inputs, output_attentions=True)[-1]
149
+ if remove_special_tokens:
150
  # Remove attention from <CLS> (first) and <SEP> (last) token
151
+ attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
152
  attentions = torch.stack([attention.squeeze(0) for attention in attentions])
153
 
154
  elif model_type == ModelType.PROT_T5:
 
161
  -1
162
  ] # Do you need an attention mask?
163
 
164
+ if remove_special_tokens:
165
+ # Remove attention to <pad> (first) and <extra_id_1>, <extra_id_2> (last) tokens
166
+ attentions = [attention[:, :, 3:-3, 3:-3] for attention in attentions]
167
  attentions = torch.stack([attention.squeeze(0) for attention in attentions])
168
 
169
  else:
hexviz/pages/1_🗺️Identify_Interesting_Heads.py CHANGED
@@ -69,7 +69,11 @@ st.markdown(
69
 
70
  # TODO: Decide if you should get attention for the full sequence or just the truncated sequence
71
  # Attention values will change depending on what we do.
72
- attention = get_attention(sequence=truncated_sequence, model_type=selected_model.name)
 
 
 
 
73
  st.write(attention.shape)
74
 
75
  fig = plot_tiled_heatmap(
 
69
 
70
  # TODO: Decide if you should get attention for the full sequence or just the truncated sequence
71
  # Attention values will change depending on what we do.
72
+ attention = get_attention(
73
+ sequence=truncated_sequence,
74
+ model_type=selected_model.name,
75
+ remove_special_tokens=False,
76
+ )
77
  st.write(attention.shape)
78
 
79
  fig = plot_tiled_heatmap(