Update modeling_fastesm.py
modeling_fastesm.py  CHANGED  (+190 −35)
@@ -19,7 +19,6 @@ from transformers.models.esm.modeling_esm import (
     EsmLMHead,
     EsmSelfOutput,
     EsmClassificationHead,
-    create_position_ids_from_input_ids,
 )
 from tqdm.auto import tqdm
 
@@ -82,6 +81,58 @@ def apply_rotary_pos_emb(x, cos, sin):
     return (x * cos) + (rotate_half(x) * sin)
 
 
+def symmetrize(x):
+    "Make layer symmetric in final two dimensions, used for contact prediction."
+    return x + x.transpose(-1, -2)
+
+
+def average_product_correct(x):
+    "Perform average product correct, used for contact prediction."
+    a1 = x.sum(-1, keepdims=True)
+    a2 = x.sum(-2, keepdims=True)
+    a12 = x.sum((-1, -2), keepdims=True)
+
+    avg = a1 * a2
+    avg.div_(a12)  # in-place to reduce memory
+    normalized = x - avg
+    return normalized
+
+
+class EsmContactPredictionHead(nn.Module):
+    """Performs symmetrization, apc, and computes a logistic regression on the output features"""
+
+    def __init__(
+        self,
+        in_features: int,
+        bias=True,
+        eos_idx: int = 2,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.eos_idx = eos_idx
+        self.regression = nn.Linear(in_features, 1, bias)
+        self.activation = nn.Sigmoid()
+
+    def forward(self, tokens, attentions):
+        # remove eos token attentions
+        eos_mask = tokens.ne(self.eos_idx).to(attentions)
+        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
+        attentions = attentions * eos_mask[:, None, None, :, :]
+        attentions = attentions[..., :-1, :-1]
+        # remove cls token attentions
+        attentions = attentions[..., 1:, 1:]
+        batch_size, layers, heads, seqlen, _ = attentions.size()
+        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
+
+        # features: batch x channels x tokens x tokens (symmetric)
+        attentions = attentions.to(
+            self.regression.weight.device
+        )  # attentions always float32, may need to convert to float16
+        attentions = average_product_correct(symmetrize(attentions))
+        attentions = attentions.permute(0, 2, 3, 1)
+        return self.activation(self.regression(attentions).squeeze(3))
+
+
 class RotaryEmbedding(torch.nn.Module):
     """
     Rotary position embeddings based on those in
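A minimal sketch of the symmetrization plus average-product-correction step added above, run on a random tensor; the shapes (33 layers x 20 heads, sequence length 64) are illustrative assumptions, not values from this commit.

import torch

def symmetrize(x):
    # same definition as in the diff: make the last two dimensions symmetric
    return x + x.transpose(-1, -2)

def average_product_correct(x):
    # same definition as in the diff: subtract the outer product of row/column sums
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)
    avg = a1 * a2
    avg.div_(a12)
    return x - avg

# stand-in for stacked attention maps: (batch, layers * heads, seq_len, seq_len)
attn = torch.rand(1, 33 * 20, 64, 64)
features = average_product_correct(symmetrize(attn))
print(features.shape)                                                    # torch.Size([1, 660, 64, 64])
print(torch.allclose(features, features.transpose(-1, -2), atol=1e-5))   # features stay symmetric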
@@ -207,7 +258,18 @@ class EsmSelfAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
-
+        output_attentions: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """Forward pass for self attention.
+
+        Args:
+            hidden_states: Input tensor
+            attention_mask: Optional attention mask
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Output tensor and optionally attention weights
+        """
         query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
         key_layer = self.transpose_for_scores(self.key(hidden_states))
         value_layer = self.transpose_for_scores(self.value(hidden_states))
@@ -215,15 +277,28 @@ class EsmSelfAttention(nn.Module):
         if self.position_embedding_type == "rotary":
             query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
 
-
-
-            key_layer,
-
-
-
-
-
-
+        if output_attentions:
+            # Manual attention computation to get attention weights
+            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+            if attention_mask is not None:
+                attention_scores = attention_scores + attention_mask
+            attention_probs = F.softmax(attention_scores, dim=-1)
+            if self.dropout_prob > 0:
+                attention_probs = F.dropout(attention_probs, p=self.dropout_prob, training=self.training)
+            context_layer = torch.matmul(attention_probs, value_layer)
+            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
+            return context_layer, attention_probs
+        else:
+            context_layer = F.scaled_dot_product_attention(
+                query_layer,
+                key_layer,
+                value_layer,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout_prob,
+                scale=1.0
+            )
+            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
+            return context_layer
 
 
 class EsmAttention(nn.Module):
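The forward above now branches between an explicit matmul/softmax path (when attention weights are requested) and F.scaled_dot_product_attention; because the query is already multiplied by self.scale, the fused call passes scale=1.0. A hedged sketch of that equivalence on random tensors, with dropout disabled and shapes chosen arbitrarily (head_dim ** -0.5 stands in for self.scale):

import torch
import torch.nn.functional as F

b, h, s, d = 2, 4, 16, 8
q = torch.randn(b, h, s, d) * d ** -0.5   # query pre-scaled, as in the modified forward
k = torch.randn(b, h, s, d)
v = torch.randn(b, h, s, d)

scores = torch.matmul(q, k.transpose(-1, -2))
probs = F.softmax(scores, dim=-1)
manual = torch.matmul(probs, v)

sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, scale=1.0)
print(torch.allclose(manual, sdpa, atol=1e-5))   # True: both paths compute the same context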
@@ -235,15 +310,33 @@ class EsmAttention(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """Forward pass for attention layer.
+
+        Args:
+            hidden_states: Input tensor
+            attention_mask: Optional attention mask
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Output tensor and optionally attention weights
+        """
         hidden_states_ln = self.LayerNorm(hidden_states)
-
+        self_outputs = self.self(
             hidden_states_ln,
             attention_mask,
+            output_attentions,
         )
-
+        if output_attentions:
+            attention_output, attention_weights = self_outputs
+            attention_output = self.output(attention_output, hidden_states)
+            return attention_output, attention_weights
+        else:
+            attention_output = self_outputs
+            return self.output(attention_output, hidden_states)
 
 
 class EsmLayer(nn.Module):
@@ -258,14 +351,35 @@ class EsmLayer(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-
-
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """Forward pass for transformer layer.
+
+        Args:
+            hidden_states: Input tensor
+            attention_mask: Optional attention mask
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Output tensor and optionally attention weights
+        """
+        attention_outputs = self.attention(
             hidden_states,
             attention_mask,
+            output_attentions,
         )
+        if output_attentions:
+            attention_output, attention_weights = attention_outputs
+        else:
+            attention_output = attention_outputs
+            attention_weights = None
+
         layer_output = self.feed_forward_chunk(attention_output)
+
+        if output_attentions:
+            return layer_output, attention_weights
         return layer_output
 
     def feed_forward_chunk(self, attention_output):
@@ -285,27 +399,49 @@ class EsmEncoder(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        output_hidden_states=False,
-
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        output_hidden_states: bool = False,
+        output_attentions: bool = False,
+    ) -> BaseModelOutputWithPastAndCrossAttentions:
+        """Forward pass for transformer encoder.
+
+        Args:
+            hidden_states: Input tensor
+            attention_mask: Optional attention mask
+            output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            BaseModelOutputWithPastAndCrossAttentions containing model outputs
+        """
         all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
         for layer_module in self.layer:
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
+                    output_attentions,
                 )
             else:
-
+                layer_outputs = layer_module(
                     hidden_states,
                     attention_mask,
+                    output_attentions,
                 )
 
+            if output_attentions:
+                hidden_states, attention_weights = layer_outputs
+                all_attentions = all_attentions + (attention_weights,)
+            else:
+                hidden_states = layer_outputs
+
         if self.emb_layer_norm_after:
             hidden_states = self.emb_layer_norm_after(hidden_states)
 
@@ -315,6 +451,7 @@ class EsmEncoder(nn.Module):
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
             hidden_states=all_hidden_states,
+            attentions=all_attentions,
         )
 
 
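The encoder now returns one attention tensor per layer. A hedged sketch of how those per-layer maps could be stacked into the five-dimensional (batch, layers, heads, seq_len, seq_len) tensor that EsmContactPredictionHead.forward unpacks; the stacking call and the shapes are assumptions, not code from this commit.

import torch

# stand-in for outputs.attentions: a tuple with one (batch, heads, seq_len, seq_len) tensor per layer
per_layer = tuple(torch.rand(1, 20, 64, 64) for _ in range(33))
stacked = torch.stack(per_layer, dim=1)   # (batch, layers, heads, seq_len, seq_len)
print(stacked.shape)                      # torch.Size([1, 33, 20, 64, 64])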
@@ -493,18 +630,32 @@ class FastEsmModel(FastEsmPreTrainedModel):
 
     def forward(
         self,
-        input_ids: Optional[torch.
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.
-        inputs_embeds: Optional[torch.
-        output_hidden_states: Optional[bool] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
-
-
+        """Forward pass for base model.
+
+        Args:
+            input_ids: Input token IDs
+            attention_mask: Optional attention mask
+            position_ids: Optional position IDs
+            inputs_embeds: Optional input embeddings
+            output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Model outputs including hidden states and optionally attention weights
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
+
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
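A hypothetical usage sketch of the new forward signature; the checkpoint name is a placeholder, and loading via AutoModel with trust_remote_code is an assumption about how this remote-code model is published, not something stated in the diff.

import torch
from transformers import AutoModel, AutoTokenizer

model_id = "your-org/fast-esm-checkpoint"   # placeholder, not a real checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval()

inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_attentions=True, output_hidden_states=True)

print(out.last_hidden_state.shape)   # (batch, seq_len, hidden_size)
print(len(out.attentions))           # one (batch, heads, seq_len, seq_len) tensor per layer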
@@ -522,10 +673,8 @@ class FastEsmModel(FastEsmPreTrainedModel):
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
         )
-
+
         if attention_mask is not None:
-            # attention_mask shape should be (batch_size, 1, 1, seq_length)
-            # Expand to (batch_size, 1, seq_length, seq_length)
             extended_attention_mask = attention_mask[:, None, None, :].expand(
                 batch_size, 1, seq_length, seq_length
             ).bool()
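A small sketch of the mask expansion above, which turns a (batch_size, seq_length) padding mask into the boolean (batch_size, 1, seq_length, seq_length) attn_mask consumed by scaled_dot_product_attention; the example values are arbitrary.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])   # batch_size=1, seq_length=4, last token padded
batch_size, seq_length = attention_mask.shape
extended_attention_mask = attention_mask[:, None, None, :].expand(
    batch_size, 1, seq_length, seq_length
).bool()
print(extended_attention_mask.shape)   # torch.Size([1, 1, 4, 4])
print(extended_attention_mask[0, 0])   # every query position masks the padded key position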
@@ -536,6 +685,7 @@ class FastEsmModel(FastEsmPreTrainedModel):
             embedding_output,
             attention_mask=extended_attention_mask,
             output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
         )
         sequence_output = encoder_outputs.last_hidden_state
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
@@ -544,6 +694,7 @@ class FastEsmModel(FastEsmPreTrainedModel):
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
         )
 
 
@@ -572,6 +723,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> Union[Tuple, MaskedLMOutput]:
         outputs = self.esm(
             input_ids,
@@ -593,6 +745,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
             loss=loss,
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
     def predict_contacts(self, tokens, attention_mask):
@@ -657,6 +810,7 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
             loss=loss,
             logits=logits,
             hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
 
@@ -701,6 +855,7 @@ class FastEsmForTokenClassification(FastEsmPreTrainedModel):
             loss=loss,
             logits=logits,
             hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
 