lhallee commited on
Commit
eb0e5cb
·
verified ·
1 Parent(s): 3e3b4de

Update modeling_esm_plusplus.py

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +6 -7
modeling_esm_plusplus.py CHANGED
@@ -467,8 +467,8 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
467
  Implements the base ESM++ architecture with a masked language modeling head.
468
  """
469
  config_class = ESMplusplusConfig
470
- def __init__(self, config: ESMplusplusConfig):
471
- super().__init__(config)
472
  self.config = config
473
  self.vocab_size = config.vocab_size
474
  self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
@@ -642,10 +642,10 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
642
 
643
  Extends the base ESM++ model with a classification head.
644
  """
645
- def __init__(self, config: ESMplusplusConfig, num_labels: int=2):
646
- config.num_labels = num_labels
647
- super().__init__(config)
648
  self.config = config
 
649
  self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
650
  # Large intermediate projections help with sequence classification tasks (*4)
651
  self.mse = nn.MSELoss()
@@ -715,8 +715,7 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
715
 
716
  Extends the base ESM++ model with a token classification head.
717
  """
718
- def __init__(self, config: ESMplusplusConfig, num_labels: int=2):
719
- config.num_labels = num_labels
720
  super().__init__(config)
721
  self.config = config
722
  self.num_labels = config.num_labels
 
467
  Implements the base ESM++ architecture with a masked language modeling head.
468
  """
469
  config_class = ESMplusplusConfig
470
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
471
+ super().__init__(config, **kwargs)
472
  self.config = config
473
  self.vocab_size = config.vocab_size
474
  self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
 
642
 
643
  Extends the base ESM++ model with a classification head.
644
  """
645
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
646
+ super().__init__(config, **kwargs)
 
647
  self.config = config
648
+ self.num_labels = config.num_labels
649
  self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
650
  # Large intermediate projections help with sequence classification tasks (*4)
651
  self.mse = nn.MSELoss()
 
715
 
716
  Extends the base ESM++ model with a token classification head.
717
  """
718
+ def __init__(self, config: ESMplusplusConfig):
 
719
  super().__init__(config)
720
  self.config = config
721
  self.num_labels = config.num_labels