lhallee committed · verified
Commit 45ab4a0 · 1 Parent(s): c98ee00

Update modeling_fastesm.py

Files changed (1)
  1. modeling_fastesm.py +87 -4
modeling_fastesm.py CHANGED
@@ -612,12 +612,95 @@ class FastEsmPreTrainedModel(PreTrainedModel):
 
         return embeddings_dict
 
-class FastEsmModel(FastEsmPreTrainedModel):
+
+class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
         self.embeddings = EsmEmbeddings(config)
         self.encoder = EsmEncoder(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        """Forward pass for base model.
+
+        Args:
+            input_ids: Input token IDs
+            attention_mask: Optional attention mask
+            position_ids: Optional position IDs
+            inputs_embeds: Optional input embeddings
+            output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Model outputs including hidden states and optionally attention weights
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+        )
+
+        if attention_mask is not None:
+            extended_attention_mask = attention_mask[:, None, None, :].expand(
+                batch_size, 1, seq_length, seq_length
+            ).bool()
+        else:
+            extended_attention_mask = None
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+        sequence_output = encoder_outputs.last_hidden_state
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class FastEsmModel(FastEsmPreTrainedModel):
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+        self.esm = FAST_ESM_ENCODER(config)
         self.pooler = EsmPooler(config) if add_pooling_layer else None
         # Initialize weights and apply final processing
         self.post_init()
@@ -703,7 +786,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        self.esm = FastEsmModel(config, add_pooling_layer=False)
+        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
         self.lm_head = EsmLMHead(config)
         self.loss_fct = nn.CrossEntropyLoss()
         self.init_weights()
@@ -757,7 +840,7 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config
-        self.esm = FastEsmModel(config, add_pooling_layer=False)
+        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
         self.classifier = EsmClassificationHead(config)
         self.mse = nn.MSELoss()
         self.ce = nn.CrossEntropyLoss()
@@ -818,7 +901,7 @@ class FastEsmForTokenClassification(FastEsmPreTrainedModel):
     def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
-        self.esm = FastEsmModel(config, add_pooling_layer=False)
+        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
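
For reference, a minimal usage sketch of the refactored classes: after this commit the task heads (FastEsmForMaskedLM, FastEsmForSequenceClassification, FastEsmForTokenClassification) wrap the shared FAST_ESM_ENCODER backbone directly, while FastEsmModel composes that same encoder with an optional EsmPooler. The checkpoint id below is a placeholder assumption, not part of this commit; loading with trust_remote_code=True is the standard pattern for custom modeling code on the Hub.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_id = "Synthyra/FastESM2_650"  # placeholder checkpoint id (assumption)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)

# With this change, model.esm is a FAST_ESM_ENCODER (no pooling layer);
# FastEsmModel wraps the same encoder plus an optional EsmPooler.
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
inputs = tokenizer(sequence, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # (batch, sequence_length, vocab_size)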