lhallee committed (verified)
Commit 2354b09 · Parent(s): 7bbf228

Upload modeling_fastesm.py with huggingface_hub

Files changed (1): modeling_fastesm.py (+18 -2)
modeling_fastesm.py CHANGED

@@ -53,6 +53,7 @@ class FastEsmConfig(PretrainedConfig):
         layer_norm_eps: float = 1e-12,
         position_embedding_type: str = "absolute",
         emb_layer_norm_before: bool = None,
+        token_dropout: bool = True,
         **kwargs,
     ):
         super().__init__(
@@ -74,6 +75,7 @@ class FastEsmConfig(PretrainedConfig):
         self.position_embedding_type = position_embedding_type
         self.emb_layer_norm_before = emb_layer_norm_before
         self.tie_word_embeddings = False
+        self.token_dropout = token_dropout
 
     def to_dict(self) -> Dict[str, Any]:
         """
@@ -209,6 +211,8 @@ class EsmEmbeddings(nn.Module):
         self.register_buffer(
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )
+        self.token_dropout = config.token_dropout
+        self.mask_token_id = config.mask_token_id
 
     def forward(
         self,
@@ -223,6 +227,18 @@ class EsmEmbeddings(nn.Module):
 
         embeddings = inputs_embeds
 
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        if self.token_dropout:
+            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0)
+            mask_ratio_train = 0.15 * 0.8
+            src_lengths = attention_mask.sum(-1)
+            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
+            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
+                embeddings.dtype
+            )
+
         if self.layer_norm is not None:
             embeddings = self.layer_norm(embeddings)
         if attention_mask is not None:
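This is the standard ESM "token dropout" trick: embeddings at mask positions are zeroed, then every embedding is rescaled by (1 - mask_ratio_train) / (1 - mask_ratio_observed), where mask_ratio_train = 0.15 * 0.8 = 0.12 is the expected fraction of tokens replaced by the mask token during pre-training. A standalone sketch with toy values (the mask token id below is a placeholder; the real value comes from config.mask_token_id):

# Illustrative sketch of the rescaling above, with toy values (not from the repo).
import torch

mask_token_id = 32                                    # placeholder; real value is config.mask_token_id
input_ids = torch.tensor([[5, 32, 7, 32, 9, 11]])     # 2 of 6 positions are masked
attention_mask = torch.ones_like(input_ids)
embeddings = torch.randn(1, 6, 4)                     # (batch, seq_len, hidden)

# Zero the embeddings at mask positions ...
embeddings = embeddings.masked_fill((input_ids == mask_token_id).unsqueeze(-1), 0.0)

# ... then rescale so the expected embedding magnitude matches pre-training,
# where 15% of tokens were selected and 80% of those were replaced by the mask token.
mask_ratio_train = 0.15 * 0.8
src_lengths = attention_mask.sum(-1)                  # real tokens per sequence
mask_ratio_observed = (input_ids == mask_token_id).sum(-1).float() / src_lengths  # 2/6 here
embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]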
@@ -300,8 +316,8 @@ class EsmSelfAttention(nn.Module):
             query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
 
         if output_attentions:
-            # Manual attention computation to get attention weights
-            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+            # Manual attention computation - apply scaling here
+            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) * self.scale
             if attention_mask is not None:
                 attention_scores = attention_scores + attention_mask
             attention_probs = F.softmax(attention_scores, dim=-1)