lhallee committed
Commit 1710ca2 · verified · 1 Parent(s): b4d86f4

Upload modeling_fastesm.py with huggingface_hub

Files changed (1):
  1. modeling_fastesm.py  +10 -10
modeling_fastesm.py CHANGED
@@ -756,8 +756,8 @@ class FastEsmPreTrainedModel(PreTrainedModel):
 
 
 class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
-    def __init__(self, config, add_pooling_layer: Optional[bool] = True):
-        super(FastEsmPreTrainedModel, self).__init__(config)
+    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
+        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
         self.config = config
         self.embeddings = EsmEmbeddings(config)
         self.encoder = EsmEncoder(config)
@@ -864,8 +864,8 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
 
 
 class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
-    def __init__(self, config, add_pooling_layer: Optional[bool] = True):
-        super(FastEsmPreTrainedModel, self).__init__(config)
+    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
+        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
         self.config = config
         self.esm = FAST_ESM_ENCODER(config)
         self.pooler = EsmPooler(config) if add_pooling_layer else None
@@ -942,8 +942,8 @@ class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
 class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
     _tied_weights_keys = ["lm_head.decoder.weight"]
 
-    def __init__(self, config):
-        super(FastEsmPreTrainedModel, self).__init__(config)
+    def __init__(self, config, **kwargs):
+        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
         self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
         self.lm_head = EsmLMHead(config)
         self.loss_fct = nn.CrossEntropyLoss()
@@ -998,8 +998,8 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
 
 
 class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
-    def __init__(self, config):
-        super(FastEsmPreTrainedModel, self).__init__(config)
+    def __init__(self, config, **kwargs):
+        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
         self.num_labels = config.num_labels
         self.config = config
         self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
@@ -1067,8 +1067,8 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
 
 
 class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
-    def __init__(self, config):
-        super(FastEsmPreTrainedModel, self).__init__(config)
+    def __init__(self, config, **kwargs):
+        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
         self.num_labels = config.num_labels
         self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
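
The same constructor rewrite is applied to all five classes: the old `super(FastEsmPreTrainedModel, self).__init__(config)` call resumes the MRO walk after FastEsmPreTrainedModel, whereas the new code invokes FastEsmPreTrainedModel.__init__ directly and forwards **kwargs. The sketch below is a minimal, self-contained illustration of that mechanical difference; the *Stub classes are hypothetical stand-ins, not the real transformers or FastESM implementations, chosen only to mirror the (FastEsmPreTrainedModel, EmbeddingMixin) inheritance shape.

# Hypothetical stubs, not the real transformers / FastESM classes.
class PreTrainedModelStub:
    def __init__(self, config, **kwargs):
        print("PreTrainedModelStub.__init__")
        self.config = config

class FastEsmPreTrainedModelStub(PreTrainedModelStub):
    def __init__(self, config, **kwargs):
        print("FastEsmPreTrainedModelStub.__init__")
        super().__init__(config, **kwargs)

class EmbeddingMixinStub:
    pass

class OldStyle(FastEsmPreTrainedModelStub, EmbeddingMixinStub):
    def __init__(self, config):
        # Old pattern: super(FastEsmPreTrainedModelStub, self) continues the MRO
        # *after* FastEsmPreTrainedModelStub, so its __init__ never runs here.
        super(FastEsmPreTrainedModelStub, self).__init__(config)

class NewStyle(FastEsmPreTrainedModelStub, EmbeddingMixinStub):
    def __init__(self, config, **kwargs):
        # New pattern: call the immediate base initializer explicitly and
        # forward **kwargs, so its __init__ (and its own super chain) runs.
        FastEsmPreTrainedModelStub.__init__(self, config, **kwargs)

OldStyle(config={})  # prints only: PreTrainedModelStub.__init__
NewStyle(config={})  # prints: FastEsmPreTrainedModelStub.__init__, then PreTrainedModelStub.__init__

Running the sketch, OldStyle({}) skips the intermediate initializer, while NewStyle({}) runs both initializers in order and passes extra keyword arguments through, matching the shape of the change in the diff above.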