Fix removal of special tokens for ZymCTRL

hexviz/attention.py  CHANGED  +29 -26
@@ -91,27 +91,15 @@ def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
     return cleaned_sequence, None
 
 
-def
-
+def remove_tokens(attentions, tokens, tokens_to_remove):
+    indices_to_remove = [i for i, token in enumerate(tokens) if token in tokens_to_remove]
 
-
-
-
+    # Remove rows and columns corresponding to special tokens and periods
+    for idx in sorted(indices_to_remove, reverse=True):
+        attentions = torch.cat((attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2)
+        attentions = torch.cat((attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3)
 
-
-
-    for attentions in attentions_tuple:
-        # Remove rows and columns corresponding to special tokens and periods
-        for idx in sorted(indices_to_remove, reverse=True):
-            attentions = torch.cat((attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2)
-            attentions = torch.cat(
-                (attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3
-            )
-
-        # Append the modified attentions tensor to the new_attentions list
-        new_attentions.append(attentions)
-
-    return new_attentions, [token for i, token in enumerate(tokens) if i not in indices_to_remove]
+    return attentions
 
 
 @st.cache
@@ -131,12 +119,17 @@ def get_attention(
         tokenizer, model = get_tape_bert()
         token_idxs = tokenizer.encode(sequence).tolist()
         inputs = torch.tensor(token_idxs).unsqueeze(0)
+
         with torch.no_grad():
             attentions = model(inputs)[-1]
+
+        tokenized_sequence = tokenizer.convert_ids_to_tokens(token_idxs)
+
         if remove_special_tokens:
             # Remove attention from <CLS> (first) and <SEP> (last) token
             attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
-
+            tokenized_sequence = tokenized_sequence[1:-1]
+
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     elif model_type == ModelType.ZymCTRL:
@@ -151,9 +144,18 @@ def get_attention(
         outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
         attentions = outputs.attentions
 
+        tokenized_sequence = tokenizer.convert_ids_to_tokens(tokenizer.encode(sequence))
+
         if ec_number and remove_special_tokens:
             # Remove attention to special tokens and periods separating EC number components
-
+            tokens_to_remove = [".", "<sep>", "<start>", "<end>", "<pad>"]
+            attentions = [
+                remove_tokens(attention, tokenized_sequence, tokens_to_remove)
+                for attention in attentions
+            ]
+            tokenized_sequence = [
+                token for token in tokenized_sequence if token not in tokens_to_remove
+            ]
 
         # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
         attention_squeezed = [torch.squeeze(attention) for attention in attentions]
@@ -169,10 +171,12 @@ def get_attention(
         with torch.no_grad():
             attentions = model(inputs, output_attentions=True)[-1]
 
+        tokenized_sequence = tokenizer.convert_ids_to_tokens(token_idxs)
         if remove_special_tokens:
             # Remove attention from <CLS> (first) and <SEP> (last) token
             attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
-
+            tokenized_sequence = tokenized_sequence[1:-1]
+
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     elif model_type == ModelType.PROT_T5:
@@ -183,19 +187,18 @@ def get_attention(
         with torch.no_grad():
            attentions = model(inputs, output_attentions=True)[-1]
 
+        tokenized_sequence = tokenizer.convert_ids_to_tokens(token_idxs)
         if remove_special_tokens:
             # Remove attention to </s> (last) token
             attentions = [attention[:, :, :-1, :-1] for attention in attentions]
-
+            tokenized_sequence = tokenized_sequence[:-1]
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     else:
         raise ValueError(f"Model {model_type} not supported")
 
-    input_ids_list = inputs.squeeze().tolist()
-    tokens = tokenizer.convert_ids_to_tokens(input_ids_list)
     # Transfer to CPU to avoid issues with streamlit caching
-    return attentions.cpu(), tokens
+    return attentions.cpu(), tokenized_sequence
 
 
 def unidirectional_avg_filtered(attention, layer, head, threshold):
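
For reference, a minimal, self-contained sketch of what the new remove_tokens helper does. Only the helper body comes from this commit; the token list and tensor shapes below are invented for illustration:

import torch


def remove_tokens(attentions, tokens, tokens_to_remove):
    # Same body as the helper added in this commit.
    indices_to_remove = [i for i, token in enumerate(tokens) if token in tokens_to_remove]
    for idx in sorted(indices_to_remove, reverse=True):
        # Drop the attention row (dim 2) and column (dim 3) for each removed
        # token, iterating from the back so earlier indices stay valid.
        attentions = torch.cat((attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2)
        attentions = torch.cat((attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3)
    return attentions


# Hypothetical ZymCTRL-style input: an EC number "1.2.3.4" with period
# separators, then a short residue sequence.
tokens = ["1", ".", "2", ".", "3", ".", "4", "<sep>", "<start>", "M", "A", "V", "<end>"]
tokens_to_remove = [".", "<sep>", "<start>", "<end>", "<pad>"]
attention = torch.rand(1, 16, len(tokens), len(tokens))  # (batch, heads, seq, seq)

cleaned = remove_tokens(attention, tokens, tokens_to_remove)
kept = [t for t in tokens if t not in tokens_to_remove]
print(cleaned.shape)  # torch.Size([1, 16, 7, 7]) -- one row/column per kept token
print(kept)           # ['1', '2', '3', '4', 'M', 'A', 'V']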
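If the one-torch.cat-per-removed-token loop ever becomes a hotspot, an equivalent single-pass variant can gather the kept rows and columns once per dimension with index_select. A sketch under the same assumptions, not part of this commit:

import torch


def remove_tokens_indexed(attentions, tokens, tokens_to_remove):
    # Equivalent to remove_tokens above, but selects the kept indices in one
    # index_select per dimension instead of one torch.cat per removed token.
    keep = torch.tensor([i for i, t in enumerate(tokens) if t not in tokens_to_remove])
    return attentions.index_select(2, keep).index_select(3, keep)

Either way, the last two attention dimensions stay aligned with the filtered tokenized_sequence that get_attention now returns alongside the attentions.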