yangwang825 commited on
Commit
d4e8ed4
1 Parent(s): 0b91f1b

Update modeling_wav2vec2_spkreg.py

Browse files
Files changed (1) hide show
  1. modeling_wav2vec2_spkreg.py +1 -1
modeling_wav2vec2_spkreg.py CHANGED
@@ -660,7 +660,7 @@ class Wav2Vec2SpkRegForSequenceClassification(Wav2Vec2SpkRegPreTrainedModel):
660
  self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
661
  elif self.config.loss_fct == 'additive_margin':
662
  self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
663
- elif self.config.loss_fct == 'additive_margin':
664
  self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
665
  else:
666
  raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")
 
660
  self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
661
  elif self.config.loss_fct == 'additive_margin':
662
  self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
663
+ elif self.config.loss_fct == 'additive_angular_margin':
664
  self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
665
  else:
666
  raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")