lhallee commited on
Commit
3e3b4de
·
verified ·
1 Parent(s): 5c163ed

Update modeling_esm_plusplus.py

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +8 -6
modeling_esm_plusplus.py CHANGED
@@ -467,8 +467,8 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
467
  Implements the base ESM++ architecture with a masked language modeling head.
468
  """
469
  config_class = ESMplusplusConfig
470
- def __init__(self, config: ESMplusplusConfig, **kwargs):
471
- super().__init__(config, **kwargs)
472
  self.config = config
473
  self.vocab_size = config.vocab_size
474
  self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
@@ -642,8 +642,9 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
642
 
643
  Extends the base ESM++ model with a classification head.
644
  """
645
- def __init__(self, config: ESMplusplusConfig, **kwargs):
646
- super().__init__(config, **kwargs)
 
647
  self.config = config
648
  self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
649
  # Large intermediate projections help with sequence classification tasks (*4)
@@ -714,8 +715,9 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
714
 
715
  Extends the base ESM++ model with a token classification head.
716
  """
717
- def __init__(self, config: ESMplusplusConfig, **kwargs):
718
- super().__init__(config, **kwargs)
 
719
  self.config = config
720
  self.num_labels = config.num_labels
721
  self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
 
467
  Implements the base ESM++ architecture with a masked language modeling head.
468
  """
469
  config_class = ESMplusplusConfig
470
+ def __init__(self, config: ESMplusplusConfig):
471
+ super().__init__(config)
472
  self.config = config
473
  self.vocab_size = config.vocab_size
474
  self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
 
642
 
643
  Extends the base ESM++ model with a classification head.
644
  """
645
+ def __init__(self, config: ESMplusplusConfig, num_labels: int=2):
646
+ config.num_labels = num_labels
647
+ super().__init__(config)
648
  self.config = config
649
  self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
650
  # Large intermediate projections help with sequence classification tasks (*4)
 
715
 
716
  Extends the base ESM++ model with a token classification head.
717
  """
718
+ def __init__(self, config: ESMplusplusConfig, num_labels: int=2):
719
+ config.num_labels = num_labels
720
+ super().__init__(config)
721
  self.config = config
722
  self.num_labels = config.num_labels
723
  self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)