Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +8 -6
modeling_esm_plusplus.py
CHANGED
@@ -467,8 +467,8 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
|
|
467 |
Implements the base ESM++ architecture with a masked language modeling head.
|
468 |
"""
|
469 |
config_class = ESMplusplusConfig
|
470 |
-
def __init__(self, config: ESMplusplusConfig
|
471 |
-
super().__init__(config
|
472 |
self.config = config
|
473 |
self.vocab_size = config.vocab_size
|
474 |
self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
|
@@ -642,8 +642,9 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
|
|
642 |
|
643 |
Extends the base ESM++ model with a classification head.
|
644 |
"""
|
645 |
-
def __init__(self, config: ESMplusplusConfig,
|
646 |
-
|
|
|
647 |
self.config = config
|
648 |
self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
|
649 |
# Large intermediate projections help with sequence classification tasks (*4)
|
@@ -714,8 +715,9 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
|
|
714 |
|
715 |
Extends the base ESM++ model with a token classification head.
|
716 |
"""
|
717 |
-
def __init__(self, config: ESMplusplusConfig,
|
718 |
-
|
|
|
719 |
self.config = config
|
720 |
self.num_labels = config.num_labels
|
721 |
self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
|
|
|
467 |
Implements the base ESM++ architecture with a masked language modeling head.
|
468 |
"""
|
469 |
config_class = ESMplusplusConfig
|
470 |
+
def __init__(self, config: ESMplusplusConfig):
|
471 |
+
super().__init__(config)
|
472 |
self.config = config
|
473 |
self.vocab_size = config.vocab_size
|
474 |
self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
|
|
|
642 |
|
643 |
Extends the base ESM++ model with a classification head.
|
644 |
"""
|
645 |
+
def __init__(self, config: ESMplusplusConfig, num_labels: int=2):
|
646 |
+
config.num_labels = num_labels
|
647 |
+
super().__init__(config)
|
648 |
self.config = config
|
649 |
self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
|
650 |
# Large intermediate projections help with sequence classification tasks (*4)
|
|
|
715 |
|
716 |
Extends the base ESM++ model with a token classification head.
|
717 |
"""
|
718 |
+
def __init__(self, config: ESMplusplusConfig, num_labels: int=2):
|
719 |
+
config.num_labels = num_labels
|
720 |
+
super().__init__(config)
|
721 |
self.config = config
|
722 |
self.num_labels = config.num_labels
|
723 |
self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
|