Don't remove special tokens in identify interesting heads
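get_attention now takes a remove_special_tokens flag (default True, which keeps the old behaviour of slicing special-token rows and columns out of the attention tensors), and the Identify Interesting Heads page passes remove_special_tokens=False so that attention from and to special tokens stays visible in the tiled heatmaps.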
--- a/hexviz/attention.py
+++ b/hexviz/attention.py
@@ -99,7 +99,11 @@ def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
 
 
 @st.cache
-def get_attention(sequence: str, model_type: ModelType = ModelType.TAPE_BERT):
+def get_attention(
+    sequence: str,
+    model_type: ModelType = ModelType.TAPE_BERT,
+    remove_special_tokens: bool = True,
+):
     """
     Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
     """
@@ -111,8 +115,9 @@ def get_attention(sequence: str, model_type: ModelType = ModelType.TAPE_BERT):
         inputs = torch.tensor(token_idxs).unsqueeze(0)
         with torch.no_grad():
             attentions = model(inputs)[-1]
+        if remove_special_tokens:
         # Remove attention from <CLS> (first) and <SEP> (last) token
-        attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
+            attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     elif model_type == ModelType.ZymCTRL:
@@ -141,8 +146,9 @@ def get_attention(sequence: str, model_type: ModelType = ModelType.TAPE_BERT):
         inputs = torch.tensor(token_idxs).unsqueeze(0).to(device)
         with torch.no_grad():
             attentions = model(inputs, output_attentions=True)[-1]
+        if remove_special_tokens:
         # Remove attention from <CLS> (first) and <SEP> (last) token
-        attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
+            attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     elif model_type == ModelType.PROT_T5:
@@ -155,8 +161,9 @@ def get_attention(sequence: str, model_type: ModelType = ModelType.TAPE_BERT):
             -1
         ]  # Do you need an attention mask?
 
-        # Remove attention to <pad> (first) and <extra_id_1>, <extra_id_2> (last) tokens
-        attentions = [attention[:, :, 3:-3, 3:-3] for attention in attentions]
+        if remove_special_tokens:
+            # Remove attention to <pad> (first) and <extra_id_1>, <extra_id_2> (last) tokens
+            attentions = [attention[:, :, 3:-3, 3:-3] for attention in attentions]
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     else:
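For reference, a minimal self-contained sketch of what the remove_special_tokens branch does to the attention shapes. The layer count, head count, and random tensors below are made-up illustration values, not hexviz output; real models return one [1, n_heads, seq_len, seq_len] tensor per layer when attentions are requested.

import torch

# Dummy stand-in for a model's attention output: one tensor of shape
# [batch=1, n_heads, seq_len, seq_len] per layer.
n_layers, n_heads, seq_len = 4, 8, 10  # e.g. 8 residues plus <CLS> and <SEP>
attentions = [torch.rand(1, n_heads, seq_len, seq_len) for _ in range(n_layers)]

remove_special_tokens = True
if remove_special_tokens:
    # Slice off the first and last position on both residue axes,
    # i.e. drop attention from and to <CLS> and <SEP>.
    attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]

# Stack the per-layer tensors into [n_layers, n_heads, n_res, n_res].
attentions = torch.stack([attention.squeeze(0) for attention in attentions])

print(attentions.shape)  # torch.Size([4, 8, 8, 8]); [4, 8, 10, 10] with the flag off

The ProtT5 branch slices three positions off each end (3:-3 rather than 1:-1), matching that tokenizer's special-token layout per the in-code comment.
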
--- a/hexviz/pages/1_🗺️Identify_Interesting_Heads.py
+++ b/hexviz/pages/1_🗺️Identify_Interesting_Heads.py
@@ -69,7 +69,11 @@ st.markdown(
 
 # TODO: Decide if you should get attention for the full sequence or just the truncated sequence
 # Attention values will change depending on what we do.
-attention = get_attention(sequence=truncated_sequence, model_type=selected_model.name)
+attention = get_attention(
+    sequence=truncated_sequence,
+    model_type=selected_model.name,
+    remove_special_tokens=False,
+)
 st.write(attention.shape)
 
 fig = plot_tiled_heatmap(
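A hedged usage sketch of the new flag from caller code. The protein sequence is a placeholder, and it assumes ModelType and get_attention are importable from hexviz.attention as used in the page above:

from hexviz.attention import ModelType, get_attention

sequence = "GLSDGEWQLVLNVWGKVEAD"  # placeholder sequence, not from the repo

# Default: special-token rows/columns are removed, so n_res counts residues only.
residues_only = get_attention(sequence=sequence, model_type=ModelType.TAPE_BERT)

# What the Identify Interesting Heads page now does: keep special-token
# positions so heads that attend to <CLS>/<SEP> remain visible in the heatmaps.
with_specials = get_attention(
    sequence=sequence,
    model_type=ModelType.TAPE_BERT,
    remove_special_tokens=False,
)

# For TAPE-BERT the difference is the two special tokens, one at each end.
assert with_specials.shape[-1] == residues_only.shape[-1] + 2

Defaulting the flag to True keeps every existing call site's behaviour unchanged; only the heads page opts out explicitly.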