aksell commited on
Commit
1ce487e
·
1 Parent(s): 703d569

Fix removal of special tokens for T5

Browse files

There is only 1 extra token added now.

Files changed (1) hide show
  1. hexviz/attention.py +2 -2
hexviz/attention.py CHANGED
@@ -162,8 +162,8 @@ def get_attention(
162
  ] # Do you need an attention mask?
163
 
164
  if remove_special_tokens:
165
- # Remove attention to <pad> (first) and <extra_id_1>, <extra_id_2> (last) tokens
166
- attentions = [attention[:, :, 3:-3, 3:-3] for attention in attentions]
167
  attentions = torch.stack([attention.squeeze(0) for attention in attentions])
168
 
169
  else:
 
162
  ] # Do you need an attention mask?
163
 
164
  if remove_special_tokens:
165
+ # Remove attention to </s> (last) token
166
+ attentions = [attention[:, :, :-1, :-1] for attention in attentions]
167
  attentions = torch.stack([attention.squeeze(0) for attention in attentions])
168
 
169
  else: