Update modeling_fastesm.py
modeling_fastesm.py  CHANGED  (+87 -4)
@@ -612,12 +612,95 @@ class FastEsmPreTrainedModel(PreTrainedModel):

        return embeddings_dict

-
+
+class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = EsmEncoder(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        """Forward pass for base model.
+
+        Args:
+            input_ids: Input token IDs
+            attention_mask: Optional attention mask
+            position_ids: Optional position IDs
+            inputs_embeds: Optional input embeddings
+            output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Model outputs including hidden states and optionally attention weights
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+        )
+
+        if attention_mask is not None:
+            extended_attention_mask = attention_mask[:, None, None, :].expand(
+                batch_size, 1, seq_length, seq_length
+            ).bool()
+        else:
+            extended_attention_mask = None
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+        sequence_output = encoder_outputs.last_hidden_state
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class FastEsmModel(FastEsmPreTrainedModel):
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+        self.esm = FAST_ESM_ENCODER(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        # Initialize weights and apply final processing
        self.post_init()
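A side note on the mask handling added in `FAST_ESM_ENCODER.forward`: the 2D padding mask is expanded to a boolean `[batch, 1, seq, seq]` mask before it reaches `EsmEncoder`, presumably so it can be consumed directly by fused attention kernels. Below is a minimal sketch of that expansion, assuming a standard 0/1 padding mask; the tensor names and values are illustrative, not taken from the file.

```python
import torch

# Sketch of the expansion in FAST_ESM_ENCODER.forward (illustrative values).
batch_size, seq_length = 2, 5
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],   # sequence 1: last two positions are padding
    [1, 1, 1, 1, 1],   # sequence 2: no padding
])

# [batch, seq] -> [batch, 1, seq, seq]: the key/padding mask is broadcast
# across the query dimension, so every query attends only to real tokens.
extended = attention_mask[:, None, None, :].expand(
    batch_size, 1, seq_length, seq_length
).bool()

print(extended.shape)      # torch.Size([2, 1, 5, 5])
print(extended[0, 0, 0])   # tensor([ True,  True,  True, False, False])
```

Each query row receives the same key mask, so attention to padded positions is blocked uniformly across the sequence.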
@@ -703,7 +786,7 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
-        self.esm =
+        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()
@@ -757,7 +840,7 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
-        self.esm =
+        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
@@ -818,7 +901,7 @@ class FastEsmForTokenClassification(FastEsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
-        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
+        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
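Finally, a minimal usage sketch of how the refactored classes compose after this change: the task heads now construct the shared `FAST_ESM_ENCODER` backbone as `self.esm`, while `FastEsmModel` wraps the same encoder and keeps the optional pooler. The `FastEsmConfig` name and the small config values below are assumptions for illustration, not taken from this diff.

```python
import torch

# Assumption: the repo's config class is FastEsmConfig; substitute whatever
# config class modeling_fastesm.py actually pairs with.
config = FastEsmConfig(vocab_size=33, hidden_size=64, num_hidden_layers=2,
                       num_attention_heads=4, intermediate_size=128)

# The task heads now build the shared backbone directly...
mlm = FastEsmForMaskedLM(config)
assert isinstance(mlm.esm, FAST_ESM_ENCODER)

# ...while FastEsmModel wraps the same encoder and keeps the optional pooler.
base = FastEsmModel(config, add_pooling_layer=True)
assert isinstance(base.esm, FAST_ESM_ENCODER)

# Run the shared encoder on a toy batch (forward shown in the hunk above).
input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)
out = base.esm(input_ids=input_ids, attention_mask=attention_mask)
print(out.last_hidden_state.shape)  # torch.Size([1, 8, 64])
```

The task heads pass `add_pooling_layer=False`, signaling that they consume the encoder's hidden states directly rather than a pooled output.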
|