Update modeling_fastesm.py
modeling_fastesm.py  CHANGED  (+8 -15)
@@ -233,10 +233,6 @@ class EsmSelfAttention(nn.Module):
         if self.position_embedding_type == "rotary":
             query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
 
-        # Ensure all tensors have the same dtype before calling scaled_dot_product_attention
-        #query_layer = query_layer.to(value_layer.dtype)
-        #key_layer = key_layer.to(value_layer.dtype)
-
         context_layer = F.scaled_dot_product_attention(
             query_layer,
             key_layer,
@@ -422,10 +418,7 @@ class FastEsmModel(FastEsmPreTrainedModel):
             # Expand to (batch_size, 1, seq_length, seq_length)
             extended_attention_mask = attention_mask[:, None, None, :].expand(
                 batch_size, 1, seq_length, seq_length
-            )
-            # Convert mask to float with 0.0 for positions to keep and -inf for masked positions
-            attention_mask = attention_mask.to(dtype=embedding_output.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(embedding_output.dtype).min
+            ).bool()
         else:
             extended_attention_mask = None
 
@@ -608,13 +601,13 @@ if __name__ == "__main__":
     In Pytorch 2.5+ (and linux kernel), this implementation is very fast and uses less memory than the HF implementation.
     """
     import random
-    from transformers import
+    from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer
 
     model_paths = [
         "facebook/esm2_t6_8M_UR50D",
         "facebook/esm2_t12_35M_UR50D",
-        "facebook/esm2_t30_150M_UR50D",
-        "facebook/esm2_t33_650M_UR50D",
+        #"facebook/esm2_t30_150M_UR50D",
+        #"facebook/esm2_t33_650M_UR50D",
     ]
     canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
     length = 64
@@ -630,17 +623,17 @@ if __name__ == "__main__":
         print(f"Testing {model_path}...")
         tokenizer = EsmTokenizer.from_pretrained(model_path)
         config = FastEsmConfig.from_pretrained(model_path)
-        fast_model =
-        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False
+        fast_model = FastEsmForMaskedLM(config).from_pretrained(model_path).to(device)
+        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
 
         counts = [0] * len(tolerances)
         for _ in range(seq_count):
             example_seq = generate_random_sequence(length)
             fast_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
-            fast_output = fast_model(fast_tokens).
+            fast_output = fast_model(fast_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()
 
             model_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
-            model_output = model(model_tokens).
+            model_output = model(model_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()
 
             for i, atol in enumerate(tolerances):
                 if torch.allclose(fast_output, model_output, atol=atol):
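A note on the mask change in the second hunk: F.scaled_dot_product_attention accepts either an additive float mask (0.0 for kept positions and a large negative value for masked ones, which is what the deleted lines built by hand) or a boolean mask where True means "attend to this position". The boolean form also sidesteps the fp16 dtype bookkeeping removed here. A minimal sketch of the equivalence, using toy shapes that are not taken from the model:

import torch
import torch.nn.functional as F

# Toy sizes: (batch, heads, seq_len, head_dim); purely illustrative.
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# Padding mask: 1 = real token, 0 = padding; the second sequence has 3 padded positions.
padding = torch.tensor([[1] * 8, [1] * 5 + [0] * 3])

# Boolean mask, broadcast over heads and query positions: True = attend.
bool_mask = padding[:, None, None, :].expand(2, 1, 8, 8).bool()

# Additive float mask, the form the deleted code produced.
float_mask = (1.0 - padding[:, None, None, :].expand(2, 1, 8, 8).float()) * torch.finfo(q.dtype).min

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

# The two agree for any row that can attend to at least one position
# (a fully masked row gives NaN with a boolean mask but not with a float mask).
print(torch.allclose(out_bool, out_float, atol=1e-6))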
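The docstring kept in the third hunk claims the SDPA path in PyTorch 2.5+ is faster and lighter on memory than the eager HF implementation. That is easy to spot-check locally; the helper below is a generic sketch (the function name and settings are made up here, not part of this file) that times one forward pass and reads peak CUDA memory so the two models can be compared side by side:

import time
import torch

def profile_forward(model, tokens):
    # Time a single no-grad forward pass and report peak GPU memory in MiB (CUDA only).
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        model(tokens)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    peak_mib = torch.cuda.max_memory_allocated() / 2**20
    return elapsed, peak_mib

Run it once on fast_model and once on model from the test loop above (after a warm-up call, and ideally with sequences longer than 64 tokens so the attention cost dominates), then compare the returned timings and peak memory.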