lhallee committed (verified)
Commit 7288587 · 1 Parent(s): 69c5916

Update modeling_fastesm.py

Files changed (1)
  1. modeling_fastesm.py +111 -7
modeling_fastesm.py CHANGED
@@ -3,7 +3,7 @@ import torch.nn as nn
 from torch.nn import functional as F
 from typing import Optional, Tuple, Union
 from einops import rearrange
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, PretrainedConfig
 from transformers.modeling_outputs import (
     MaskedLMOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -12,8 +12,6 @@ from transformers.modeling_outputs import (
     TokenClassifierOutput
 )
 from transformers.models.esm.modeling_esm import (
-    RotaryEmbedding,
-    EsmContactPredictionHead,
     EsmIntermediate,
     EsmOutput,
     EsmPooler,
@@ -22,7 +20,108 @@ from transformers.models.esm.modeling_esm import (
     EsmClassificationHead,
     create_position_ids_from_input_ids,
 )
-from .config_fastesm import FastEsmConfig
+
+
+class FastEsmConfig(PretrainedConfig):
+    model_type = "fast_esm"
+
+    def __init__(
+        self,
+        vocab_size=None,
+        mask_token_id=None,
+        pad_token_id=None,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=1026,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        position_embedding_type="absolute",
+        emb_layer_norm_before=None,
+        **kwargs,
+    ):
+        super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.position_embedding_type = position_embedding_type
+        self.emb_layer_norm_before = emb_layer_norm_before
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+        """
+        output = super().to_dict()
+        return output
+
+
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(x, cos, sin):
+    cos = cos[:, :, : x.shape[-2], :]
+    sin = sin[:, :, : x.shape[-2], :]
+
+    return (x * cos) + (rotate_half(x) * sin)
+
+
+class RotaryEmbedding(torch.nn.Module):
+    """
+    Rotary position embeddings based on those in
+    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
+    matrices which depend on their relative positions.
+    """
+
+    def __init__(self, dim: int):
+        super().__init__()
+        # Generate and save the inverse frequency buffer (non trainable)
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+        inv_freq = inv_freq
+        self.register_buffer("inv_freq", inv_freq)
+
+        self._seq_len_cached = None
+        self._cos_cached = None
+        self._sin_cached = None
+
+    def _update_cos_sin_tables(self, x, seq_dimension=2):
+        seq_len = x.shape[seq_dimension]
+
+        # Reset the tables if the sequence length has changed,
+        # or if we're on a new device (possibly due to tracing for instance)
+        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
+            self._seq_len_cached = seq_len
+            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
+            freqs = torch.outer(t, self.inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
+            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
+
+        return self._cos_cached, self._sin_cached
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
+
+        return (
+            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
+            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
+        )
 
 
 class EsmEmbeddings(nn.Module):
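
This hunk replaces the import of FastEsmConfig from config_fastesm (and the transformers-internal RotaryEmbedding) with local definitions, presumably to keep the repository self-contained in a single modeling file. Below is a minimal sketch, not part of the commit, of how the RotaryEmbedding added above behaves; the tensor sizes (batch 2, heads 4, length 8, head dim 16) are arbitrary illustrations.

import torch

rot = RotaryEmbedding(dim=16)            # head dimension must be even for rotate_half
q = torch.randn(2, 4, 8, 16)             # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 4, 8, 16)

q_rot, k_rot = rot(q, k)                 # rotation preserves shapes and dtype
assert q_rot.shape == q.shape and k_rot.shape == k.shape

# The cos/sin tables are cached per sequence length and device, so repeated calls
# at the same length reuse the cache instead of rebuilding it.
_ = rot(q, k)
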
@@ -134,6 +233,10 @@ class EsmSelfAttention(nn.Module):
         if self.position_embedding_type == "rotary":
             query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
 
+        # Ensure all tensors have the same dtype before calling scaled_dot_product_attention
+        #query_layer = query_layer.to(value_layer.dtype)
+        #key_layer = key_layer.to(value_layer.dtype)
+
         context_layer = F.scaled_dot_product_attention(
             query_layer,
             key_layer,
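
The new (commented-out) casts document a constraint of F.scaled_dot_product_attention: query, key, and value must share a dtype, which can break when parts of the pipeline run in different precisions. The following standalone sketch of the failure and of the cast the comments describe is illustrative; the shapes and the float16/float32 mix are not taken from the model.

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 8, 16, dtype=torch.float32)
k = torch.randn(1, 4, 8, 16, dtype=torch.float32)
v = torch.randn(1, 4, 8, 16, dtype=torch.float16)

try:
    F.scaled_dot_product_attention(q, k, v)        # mixed dtypes are expected to raise
except RuntimeError as err:
    print(err)

# Aligning query/key with value's dtype, as the commented lines would do, avoids it.
out = F.scaled_dot_product_attention(q.to(v.dtype), k.to(v.dtype), v)
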
@@ -501,7 +604,7 @@ class FastEsmForTokenClassification(FastEsmPreTrainedModel):
 if __name__ == "__main__":
     """
     Test the hidden state differences between the FastEsmModel and the HF EsmModel.
-    In full precision, the differences are very small, but nonzero due to floating point issues with F.scaled_dot_product_attention.
+    In full precision, the differences are very very small, but nonzero due to floating point issues with F.scaled_dot_product_attention.
     In Pytorch 2.5+ (and linux kernel), this implementation is very fast and uses less memory than the HF implementation.
     """
     import random
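
The reworded docstring line describes a real effect: a fused scaled-dot-product-attention kernel and the unfused softmax formulation agree only up to floating-point error, even in full precision. A small self-contained comparison (sizes arbitrary, not from the test script) shows the kind of discrepancy the test counts.

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

manual = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)

# Typically a tiny value in fp32, but rarely exactly zero.
print((manual - fused).abs().max())
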
@@ -526,8 +629,9 @@ if __name__ == "__main__":
     for model_path in model_paths:
         print(f"Testing {model_path}...")
         tokenizer = EsmTokenizer.from_pretrained(model_path)
-        fast_model = FastEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
-        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
+        config = FastEsmConfig.from_pretrained(model_path)
+        fast_model = FastEsmModel(config).from_pretrained(model_path, torch_dtype=torch.float16).to(device)
+        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False, torch_dtype=torch.float16).to(device)
 
         counts = [0] * len(tolerances)
         for _ in range(seq_count):
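
The loading path now builds a FastEsmConfig explicitly and loads both models in torch.float16 (note that from_pretrained is a classmethod, so the instance constructed from config is effectively discarded). A sketch of what one pass of the comparison loop presumably does with these objects follows; tolerances, counts, tokenizer, device, and the two models come from the surrounding script, and the sequence string is a placeholder.

seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"          # placeholder protein sequence
inputs = tokenizer(seq, return_tensors="pt").to(device)

with torch.no_grad():
    fast_hidden = fast_model(**inputs).last_hidden_state
    ref_hidden = model(**inputs).last_hidden_state

for i, tol in enumerate(tolerances):
    if torch.allclose(fast_hidden, ref_hidden, atol=tol):
        counts[i] += 1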
 