lhallee committed · Commit c98ee00 · verified · 1 Parent(s): e22ff6e

Update modeling_fastesm.py

Files changed (1)
  1. modeling_fastesm.py +190 -35
modeling_fastesm.py CHANGED
@@ -19,7 +19,6 @@ from transformers.models.esm.modeling_esm import (
     EsmLMHead,
     EsmSelfOutput,
     EsmClassificationHead,
-    create_position_ids_from_input_ids,
 )
 from tqdm.auto import tqdm
 
@@ -82,6 +81,58 @@ def apply_rotary_pos_emb(x, cos, sin):
     return (x * cos) + (rotate_half(x) * sin)
 
 
+def symmetrize(x):
+    "Make layer symmetric in final two dimensions, used for contact prediction."
+    return x + x.transpose(-1, -2)
+
+
+def average_product_correct(x):
+    "Perform average product correct, used for contact prediction."
+    a1 = x.sum(-1, keepdims=True)
+    a2 = x.sum(-2, keepdims=True)
+    a12 = x.sum((-1, -2), keepdims=True)
+
+    avg = a1 * a2
+    avg.div_(a12)  # in-place to reduce memory
+    normalized = x - avg
+    return normalized
+
+
+class EsmContactPredictionHead(nn.Module):
+    """Performs symmetrization, apc, and computes a logistic regression on the output features"""
+
+    def __init__(
+        self,
+        in_features: int,
+        bias=True,
+        eos_idx: int = 2,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.eos_idx = eos_idx
+        self.regression = nn.Linear(in_features, 1, bias)
+        self.activation = nn.Sigmoid()
+
+    def forward(self, tokens, attentions):
+        # remove eos token attentions
+        eos_mask = tokens.ne(self.eos_idx).to(attentions)
+        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
+        attentions = attentions * eos_mask[:, None, None, :, :]
+        attentions = attentions[..., :-1, :-1]
+        # remove cls token attentions
+        attentions = attentions[..., 1:, 1:]
+        batch_size, layers, heads, seqlen, _ = attentions.size()
+        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
+
+        # features: batch x channels x tokens x tokens (symmetric)
+        attentions = attentions.to(
+            self.regression.weight.device
+        )  # attentions always float32, may need to convert to float16
+        attentions = average_product_correct(symmetrize(attentions))
+        attentions = attentions.permute(0, 2, 3, 1)
+        return self.activation(self.regression(attentions).squeeze(3))
+
+
 class RotaryEmbedding(torch.nn.Module):
     """
     Rotary position embeddings based on those in
@@ -207,7 +258,18 @@ class EsmSelfAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.Tensor]:
+        output_attentions: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """Forward pass for self attention.
+
+        Args:
+            hidden_states: Input tensor
+            attention_mask: Optional attention mask
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Output tensor and optionally attention weights
+        """
         query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
         key_layer = self.transpose_for_scores(self.key(hidden_states))
         value_layer = self.transpose_for_scores(self.value(hidden_states))
@@ -215,15 +277,28 @@ class EsmSelfAttention(nn.Module):
         if self.position_embedding_type == "rotary":
             query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
 
-        context_layer = F.scaled_dot_product_attention(
-            query_layer,
-            key_layer,
-            value_layer,
-            attn_mask=attention_mask,
-            dropout_p=self.dropout_prob,
-            scale=1.0
-        )
-        return rearrange(context_layer, 'b h s d -> b s (h d)')
+        if output_attentions:
+            # Manual attention computation to get attention weights
+            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+            if attention_mask is not None:
+                attention_scores = attention_scores + attention_mask
+            attention_probs = F.softmax(attention_scores, dim=-1)
+            if self.dropout_prob > 0:
+                attention_probs = F.dropout(attention_probs, p=self.dropout_prob, training=self.training)
+            context_layer = torch.matmul(attention_probs, value_layer)
+            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
+            return context_layer, attention_probs
+        else:
+            context_layer = F.scaled_dot_product_attention(
+                query_layer,
+                key_layer,
+                value_layer,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout_prob,
+                scale=1.0
+            )
+            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
+            return context_layer
 
 
 class EsmAttention(nn.Module):
@@ -235,15 +310,33 @@ class EsmAttention(nn.Module):
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """Forward pass for attention layer.
+
+        Args:
+            hidden_states: Input tensor
+            attention_mask: Optional attention mask
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Output tensor and optionally attention weights
+        """
         hidden_states_ln = self.LayerNorm(hidden_states)
-        attention_output = self.self(
+        self_outputs = self.self(
             hidden_states_ln,
             attention_mask,
+            output_attentions,
         )
-        return self.output(attention_output, hidden_states)
+        if output_attentions:
+            attention_output, attention_weights = self_outputs
+            attention_output = self.output(attention_output, hidden_states)
+            return attention_output, attention_weights
+        else:
+            attention_output = self_outputs
+            return self.output(attention_output, hidden_states)
 
 
 class EsmLayer(nn.Module):
@@ -258,14 +351,35 @@ class EsmLayer(nn.Module):
 
    def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-    ):
-        attention_output = self.attention(
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """Forward pass for transformer layer.
+
+        Args:
+            hidden_states: Input tensor
+            attention_mask: Optional attention mask
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Output tensor and optionally attention weights
+        """
+        attention_outputs = self.attention(
             hidden_states,
             attention_mask,
+            output_attentions,
         )
+        if output_attentions:
+            attention_output, attention_weights = attention_outputs
+        else:
+            attention_output = attention_outputs
+            attention_weights = None
+
         layer_output = self.feed_forward_chunk(attention_output)
+
+        if output_attentions:
+            return layer_output, attention_weights
         return layer_output
 
     def feed_forward_chunk(self, attention_output):
@@ -285,27 +399,49 @@ class EsmEncoder(nn.Module):
 
    def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        output_hidden_states=False,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        output_hidden_states: bool = False,
+        output_attentions: bool = False,
+    ) -> BaseModelOutputWithPastAndCrossAttentions:
+        """Forward pass for transformer encoder.
+
+        Args:
+            hidden_states: Input tensor
+            attention_mask: Optional attention mask
+            output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            BaseModelOutputWithPastAndCrossAttentions containing model outputs
+        """
         all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
         for layer_module in self.layer:
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                hidden_states = self._gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
+                    output_attentions,
                 )
             else:
-                hidden_states = layer_module(
+                layer_outputs = layer_module(
                     hidden_states,
                     attention_mask,
+                    output_attentions,
                 )
 
+            if output_attentions:
+                hidden_states, attention_weights = layer_outputs
+                all_attentions = all_attentions + (attention_weights,)
+            else:
+                hidden_states = layer_outputs
+
         if self.emb_layer_norm_after:
             hidden_states = self.emb_layer_norm_after(hidden_states)
 
@@ -315,6 +451,7 @@ class EsmEncoder(nn.Module):
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
             hidden_states=all_hidden_states,
+            attentions=all_attentions,
         )
 
 
@@ -493,18 +630,32 @@ class FastEsmModel(FastEsmPreTrainedModel):
 
     def forward(
         self,
-        input_ids: Optional[torch.Tensor] = None,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        output_hidden_states: Optional[bool] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
-        if output_attentions is not None:
-            raise ValueError("output_attentions is not supported by F.scaled_dot_product_attention")
+        """Forward pass for base model.
+
+        Args:
+            input_ids: Input token IDs
+            attention_mask: Optional attention mask
+            position_ids: Optional position IDs
+            inputs_embeds: Optional input embeddings
+            output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Model outputs including hidden states and optionally attention weights
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -522,10 +673,8 @@ class FastEsmModel(FastEsmPreTrainedModel):
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
         )
-        # Prepare attention mask
+
         if attention_mask is not None:
-            # attention_mask shape should be (batch_size, 1, 1, seq_length)
-            # Expand to (batch_size, 1, seq_length, seq_length)
             extended_attention_mask = attention_mask[:, None, None, :].expand(
                 batch_size, 1, seq_length, seq_length
             ).bool()
@@ -536,6 +685,7 @@ class FastEsmModel(FastEsmPreTrainedModel):
             embedding_output,
             attention_mask=extended_attention_mask,
             output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
         )
         sequence_output = encoder_outputs.last_hidden_state
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
@@ -544,6 +694,7 @@ class FastEsmModel(FastEsmPreTrainedModel):
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
         )
 
 
@@ -572,6 +723,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> Union[Tuple, MaskedLMOutput]:
         outputs = self.esm(
             input_ids,
@@ -593,6 +745,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
             loss=loss,
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
     def predict_contacts(self, tokens, attention_mask):
@@ -657,6 +810,7 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
             loss=loss,
             logits=logits,
            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
 
@@ -701,6 +855,7 @@ class FastEsmForTokenClassification(FastEsmPreTrainedModel):
             loss=loss,
             logits=logits,
            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
 
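Usage note (not part of the commit): a minimal sketch of how the output_attentions path added above might be exercised. It assumes a FastESM checkpoint published with this modeling code (the repo id below is a placeholder) and loading via AutoModel with trust_remote_code=True. When output_attentions=True, EsmSelfAttention takes the manual softmax branch instead of F.scaled_dot_product_attention, so per-layer attention weights are materialized and surfaced on the outputs; stacking them yields the (batch, layers, heads, seq_len, seq_len) tensor that EsmContactPredictionHead consumes.

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "Synthyra/FastESM2_650"  # placeholder checkpoint id, swap in a real FastESM repo
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

inputs = tokenizer("MKTAYIAKQR", return_tensors="pt")  # toy protein sequence
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# outputs.attentions is a tuple with one (batch, heads, seq_len, seq_len) tensor per layer
attentions = torch.stack(outputs.attentions, dim=1)  # (batch, layers, heads, seq_len, seq_len)
print(attentions.shape)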