yangwang825
commited on
Commit
•
d4e8ed4
1
Parent(s):
0b91f1b
Update modeling_wav2vec2_spkreg.py
Browse files
modeling_wav2vec2_spkreg.py
CHANGED
@@ -660,7 +660,7 @@ class Wav2Vec2SpkRegForSequenceClassification(Wav2Vec2SpkRegPreTrainedModel):
|
|
660 |
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
661 |
elif self.config.loss_fct == 'additive_margin':
|
662 |
self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
|
663 |
-
elif self.config.loss_fct == '
|
664 |
self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
|
665 |
else:
|
666 |
raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")
|
|
|
660 |
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
661 |
elif self.config.loss_fct == 'additive_margin':
|
662 |
self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
|
663 |
+
elif self.config.loss_fct == 'additive_angular_margin':
|
664 |
self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
|
665 |
else:
|
666 |
raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")
|