Upload modeling_fastesm.py with huggingface_hub
Browse files- modeling_fastesm.py +10 -10
modeling_fastesm.py
CHANGED
@@ -756,8 +756,8 @@ class FastEsmPreTrainedModel(PreTrainedModel):
|
|
756 |
|
757 |
|
758 |
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
759 |
-
def __init__(self, config, add_pooling_layer: Optional[bool] = True):
|
760 |
-
|
761 |
self.config = config
|
762 |
self.embeddings = EsmEmbeddings(config)
|
763 |
self.encoder = EsmEncoder(config)
|
@@ -864,8 +864,8 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
864 |
|
865 |
|
866 |
class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
|
867 |
-
def __init__(self, config, add_pooling_layer: Optional[bool] = True):
|
868 |
-
|
869 |
self.config = config
|
870 |
self.esm = FAST_ESM_ENCODER(config)
|
871 |
self.pooler = EsmPooler(config) if add_pooling_layer else None
|
@@ -942,8 +942,8 @@ class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
942 |
class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
|
943 |
_tied_weights_keys = ["lm_head.decoder.weight"]
|
944 |
|
945 |
-
def __init__(self, config):
|
946 |
-
|
947 |
self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
|
948 |
self.lm_head = EsmLMHead(config)
|
949 |
self.loss_fct = nn.CrossEntropyLoss()
|
@@ -998,8 +998,8 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
998 |
|
999 |
|
1000 |
class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
1001 |
-
def __init__(self, config):
|
1002 |
-
|
1003 |
self.num_labels = config.num_labels
|
1004 |
self.config = config
|
1005 |
self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
|
@@ -1067,8 +1067,8 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
1067 |
|
1068 |
|
1069 |
class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
1070 |
-
def __init__(self, config):
|
1071 |
-
|
1072 |
self.num_labels = config.num_labels
|
1073 |
self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
|
1074 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
|
|
756 |
|
757 |
|
758 |
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
759 |
+
def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
|
760 |
+
FastEsmPreTrainedModel.__init__(self, config, **kwargs)
|
761 |
self.config = config
|
762 |
self.embeddings = EsmEmbeddings(config)
|
763 |
self.encoder = EsmEncoder(config)
|
|
|
864 |
|
865 |
|
866 |
class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
|
867 |
+
def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
|
868 |
+
FastEsmPreTrainedModel.__init__(self, config, **kwargs)
|
869 |
self.config = config
|
870 |
self.esm = FAST_ESM_ENCODER(config)
|
871 |
self.pooler = EsmPooler(config) if add_pooling_layer else None
|
|
|
942 |
class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
|
943 |
_tied_weights_keys = ["lm_head.decoder.weight"]
|
944 |
|
945 |
+
def __init__(self, config, **kwargs):
|
946 |
+
FastEsmPreTrainedModel.__init__(self, config, **kwargs)
|
947 |
self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
|
948 |
self.lm_head = EsmLMHead(config)
|
949 |
self.loss_fct = nn.CrossEntropyLoss()
|
|
|
998 |
|
999 |
|
1000 |
class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
1001 |
+
def __init__(self, config, **kwargs):
|
1002 |
+
FastEsmPreTrainedModel.__init__(self, config, **kwargs)
|
1003 |
self.num_labels = config.num_labels
|
1004 |
self.config = config
|
1005 |
self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
|
|
|
1067 |
|
1068 |
|
1069 |
class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
1070 |
+
def __init__(self, config, **kwargs):
|
1071 |
+
FastEsmPreTrainedModel.__init__(self, config, **kwargs)
|
1072 |
self.num_labels = config.num_labels
|
1073 |
self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
|
1074 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|