lhallee committed
Commit 874ce57 · verified · 1 Parent(s): c5e15fd

Update modeling_fastesm.py

Files changed (1):
  1. modeling_fastesm.py  +8 -15

modeling_fastesm.py CHANGED
@@ -233,10 +233,6 @@ class EsmSelfAttention(nn.Module):
         if self.position_embedding_type == "rotary":
             query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
 
-        # Ensure all tensors have the same dtype before calling scaled_dot_product_attention
-        #query_layer = query_layer.to(value_layer.dtype)
-        #key_layer = key_layer.to(value_layer.dtype)
-
         context_layer = F.scaled_dot_product_attention(
             query_layer,
             key_layer,
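
Note: the lines dropped here were commented-out dtype casts guarding the F.scaled_dot_product_attention call. As a reference point, here is a minimal standalone sketch (shapes are invented, not taken from the model) showing SDPA consuming query/key/value tensors that already share a dtype, which is the situation once the whole model is loaded in a single dtype:

import torch
import torch.nn.functional as F

# Standalone sketch, not part of the model: query, key, and value are created
# in the same dtype, so no explicit casts are needed before calling SDPA.
q = torch.randn(2, 4, 16, 32)   # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape, out.dtype)  # torch.Size([2, 4, 16, 32]) torch.float32
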
@@ -422,10 +418,7 @@ class FastEsmModel(FastEsmPreTrainedModel):
             # Expand to (batch_size, 1, seq_length, seq_length)
             extended_attention_mask = attention_mask[:, None, None, :].expand(
                 batch_size, 1, seq_length, seq_length
-            )
-            # Convert mask to float with 0.0 for positions to keep and -inf for masked positions
-            attention_mask = attention_mask.to(dtype=embedding_output.dtype) # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(embedding_output.dtype).min
+            ).bool()
         else:
             extended_attention_mask = None
 
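
This hunk replaces the hand-built additive float mask (0.0 to keep, a large negative value to drop) with a boolean mask, which F.scaled_dot_product_attention also accepts (True = attend, False = masked). A minimal sketch with invented shapes, illustrating that the two mask conventions produce essentially the same output:

import torch
import torch.nn.functional as F

# Sketch only: shapes are made up, and the padding pattern is arbitrary.
batch, heads, seq, dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)

attention_mask = torch.ones(batch, seq)
attention_mask[:, -2:] = 0  # pretend the last two positions are padding

# Boolean mask, expanded the same way as in the hunk above (True = attend).
bool_mask = attention_mask[:, None, None, :].expand(batch, 1, seq, seq).bool()

# Additive float mask, the convention the removed lines built by hand.
float_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(q.dtype).min

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
print(torch.allclose(out_bool, out_float, atol=1e-6))  # expected: True

Using the boolean form also sidesteps the fp16-compatibility handling that the removed lines performed, since the mask no longer needs to match the embedding dtype.
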
@@ -608,13 +601,13 @@ if __name__ == "__main__":
     In Pytorch 2.5+ (and linux kernel), this implementation is very fast and uses less memory than the HF implementation.
     """
     import random
-    from transformers import EsmModel as TransformersEsmModel, EsmTokenizer
+    from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer
 
     model_paths = [
         "facebook/esm2_t6_8M_UR50D",
         "facebook/esm2_t12_35M_UR50D",
-        "facebook/esm2_t30_150M_UR50D",
-        "facebook/esm2_t33_650M_UR50D",
+        #"facebook/esm2_t30_150M_UR50D",
+        #"facebook/esm2_t33_650M_UR50D",
     ]
     canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
     length = 64
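
The test loop in the next hunk calls a generate_random_sequence helper that is not part of this diff. A hypothetical sketch of what it presumably looks like, built only from the canonical_amino_acids string and the random import shown above (the real definition lives elsewhere in modeling_fastesm.py and may differ):

import random

# Hypothetical helper: draw `length` residues uniformly from the
# 20 canonical amino acids and return them as a plain string.
canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"

def generate_random_sequence(length: int) -> str:
    return "".join(random.choices(canonical_amino_acids, k=length))

print(generate_random_sequence(64))
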
@@ -630,17 +623,17 @@ if __name__ == "__main__":
         print(f"Testing {model_path}...")
         tokenizer = EsmTokenizer.from_pretrained(model_path)
         config = FastEsmConfig.from_pretrained(model_path)
-        fast_model = FastEsmModel(config).from_pretrained(model_path, torch_dtype=torch.float16).to(device)
-        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False, torch_dtype=torch.float16).to(device)
+        fast_model = FastEsmForMaskedLM(config).from_pretrained(model_path).to(device)
+        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
 
         counts = [0] * len(tolerances)
         for _ in range(seq_count):
             example_seq = generate_random_sequence(length)
             fast_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
-            fast_output = fast_model(fast_tokens).last_hidden_state.detach().cpu()
+            fast_output = fast_model(fast_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()
 
             model_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
-            model_output = model(model_tokens).last_hidden_state.detach().cpu()
+            model_output = model(model_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()
 
             for i, atol in enumerate(tolerances):
                 if torch.allclose(fast_output, model_output, atol=atol):
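
Because the comparison now runs through the masked-LM classes, the final encoder hidden state is read via output_hidden_states=True and hidden_states[-1] instead of .last_hidden_state, and agreement is tallied at several tolerances. A standalone sketch of that tolerance bookkeeping, with random tensors standing in for the two models' outputs; tolerances and seq_count are assumed values here, since they are defined outside this hunk:

import torch

# Sketch only: random tensors replace the real FastESM / transformers outputs.
tolerances = [1e-5, 1e-4, 1e-3, 1e-2]  # assumed; defined earlier in the script
seq_count = 10                          # assumed; defined earlier in the script

counts = [0] * len(tolerances)
for _ in range(seq_count):
    fast_output = torch.randn(1, 66, 320)          # arbitrary stand-in shape
    model_output = fast_output + 1e-5 * torch.randn_like(fast_output)
    for i, atol in enumerate(tolerances):
        if torch.allclose(fast_output, model_output, atol=atol):
            counts[i] += 1

for atol, count in zip(tolerances, counts):
    print(f"atol={atol}: {count}/{seq_count} sequences matched")
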
 