Upload modeling_fastesm.py with huggingface_hub
modeling_fastesm.py CHANGED (+18 -2)
@@ -53,6 +53,7 @@ class FastEsmConfig(PretrainedConfig):
         layer_norm_eps: float = 1e-12,
         position_embedding_type: str = "absolute",
         emb_layer_norm_before: bool = None,
+        token_dropout: bool = True,
         **kwargs,
     ):
         super().__init__(
@@ -74,6 +75,7 @@ class FastEsmConfig(PretrainedConfig):
         self.position_embedding_type = position_embedding_type
         self.emb_layer_norm_before = emb_layer_norm_before
         self.tie_word_embeddings = False
+        self.token_dropout = token_dropout
 
     def to_dict(self) -> Dict[str, Any]:
         """
@@ -209,6 +211,8 @@ class EsmEmbeddings(nn.Module):
         self.register_buffer(
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )
+        self.token_dropout = config.token_dropout
+        self.mask_token_id = config.mask_token_id
 
     def forward(
         self,
@@ -223,6 +227,18 @@ class EsmEmbeddings(nn.Module):
 
         embeddings = inputs_embeds
 
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        if self.token_dropout:
+            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0)
+            mask_ratio_train = 0.15 * 0.8
+            src_lengths = attention_mask.sum(-1)
+            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
+            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
+                embeddings.dtype
+            )
+
         if self.layer_norm is not None:
             embeddings = self.layer_norm(embeddings)
         if attention_mask is not None:
@@ -300,8 +316,8 @@ class EsmSelfAttention(nn.Module):
             query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
 
         if output_attentions:
-            # Manual attention computation
-            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+            # Manual attention computation - apply scaling here
+            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) * self.scale
             if attention_mask is not None:
                 attention_scores = attention_scores + attention_mask
             attention_probs = F.softmax(attention_scores, dim=-1)
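For reference, the new flag surfaces on the config like any other keyword argument. A minimal usage sketch, assuming modeling_fastesm.py is importable locally and that FastEsmConfig can be constructed with defaults (neither is shown in this diff):

from modeling_fastesm import FastEsmConfig  # assumes the file is on the local path

# token_dropout defaults to True; set it to False to skip the <mask>
# zero-and-rescale step added to EsmEmbeddings.forward.
config = FastEsmConfig(token_dropout=False)
print(config.token_dropout)  # False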
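The token_dropout block reproduces ESM-2's handling of <mask> tokens: embeddings at masked positions are zeroed, then all embeddings are rescaled by (1 - 0.15 * 0.8) / (1 - observed mask ratio), i.e. the expected fraction of <mask> tokens seen during masked-language-model pre-training (15% of positions selected, 80% of those replaced by <mask>) relative to the fraction actually present in the input. A standalone sketch of the same computation on toy tensors (mask_token_id = 32 and all shapes are illustrative assumptions, not values from the file):

import torch

# Hypothetical toy setup -- shapes and mask_token_id are illustrative only.
mask_token_id = 32
input_ids = torch.tensor([[5, 32, 7, 9, 32, 11]])   # (batch, seq_len)
embeddings = torch.randn(1, 6, 16)                  # (batch, seq_len, hidden)
attention_mask = torch.ones_like(input_ids)

# Zero out the embedding wherever the input token is <mask>.
embeddings = embeddings.masked_fill((input_ids == mask_token_id).unsqueeze(-1), 0)

# Rescale so the expected embedding magnitude matches pre-training, where on
# average 15% of tokens were selected for MLM and 80% of those became <mask>.
mask_ratio_train = 0.15 * 0.8
src_lengths = attention_mask.sum(-1)
mask_ratio_observed = (input_ids == mask_token_id).sum(-1).float() / src_lengths
embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]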
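The attention change only affects the output_attentions branch, where scores are computed explicitly rather than through a fused kernel such as torch.nn.functional.scaled_dot_product_attention, which applies the 1/sqrt(head_dim) factor internally; the manual path previously omitted that factor. Assuming self.scale is the standard 1/sqrt(head_dim) (its definition lies outside this diff), the two paths now agree, as this sketch illustrates:

import math
import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, num_heads, seq_len, head_dim).
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)
scale = 1.0 / math.sqrt(q.size(-1))  # assumed meaning of self.scale

# Manual path, as in the output_attentions branch after this commit.
scores = torch.matmul(q, k.transpose(-1, -2)) * scale
probs = F.softmax(scores, dim=-1)
manual_out = torch.matmul(probs, v)

# Fused path; it applies the same 1/sqrt(head_dim) scaling internally.
fused_out = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(manual_out, fused_out, atol=1e-5)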