yangwang825
commited on
Commit
•
0b91f1b
1
Parent(s):
f40a8ec
Update modeling_wav2vec2_spkreg.py
Browse files- modeling_wav2vec2_spkreg.py +26 -9
modeling_wav2vec2_spkreg.py
CHANGED
@@ -529,6 +529,8 @@ class AMSoftmaxLoss(nn.Module):
|
|
529 |
num_labels: int,
|
530 |
scale: float = 30.0,
|
531 |
margin: float = 0.35,
|
|
|
|
|
532 |
):
|
533 |
"""
|
534 |
Args:
|
@@ -540,13 +542,13 @@ class AMSoftmaxLoss(nn.Module):
|
|
540 |
self.num_labels = num_labels
|
541 |
self.scale = scale
|
542 |
self.margin = margin
|
|
|
|
|
543 |
|
544 |
def forward(
|
545 |
self,
|
546 |
inputs: torch.Tensor,
|
547 |
targets: torch.Tensor,
|
548 |
-
label_smoothing: float = 0.0,
|
549 |
-
reduction: str = "mean"
|
550 |
):
|
551 |
"""
|
552 |
Args:
|
@@ -562,7 +564,9 @@ class AMSoftmaxLoss(nn.Module):
|
|
562 |
psi = cosine - self.margin
|
563 |
one_hot = nn.functional.one_hot(targets, self.num_labels)
|
564 |
outputs = self.scale * torch.where(one_hot.bool(), psi, cosine)
|
565 |
-
loss = F.cross_entropy(
|
|
|
|
|
566 |
return loss
|
567 |
|
568 |
|
@@ -577,7 +581,9 @@ class AAMSoftmaxLoss(nn.Module):
|
|
577 |
num_labels: int,
|
578 |
scale: float = 30.0,
|
579 |
margin: float = 0.35,
|
580 |
-
easy_margin: bool = False
|
|
|
|
|
581 |
):
|
582 |
"""
|
583 |
Args:
|
@@ -591,6 +597,8 @@ class AAMSoftmaxLoss(nn.Module):
|
|
591 |
self.scale = scale
|
592 |
self.margin = margin
|
593 |
self.easy_margin = easy_margin
|
|
|
|
|
594 |
|
595 |
def forward(
|
596 |
self,
|
@@ -627,7 +635,9 @@ class AAMSoftmaxLoss(nn.Module):
|
|
627 |
outputs = (one_hot * phi) + ((1.0 - one_hot) * cosine)
|
628 |
outputs = outputs * self.scale
|
629 |
|
630 |
-
loss = F.cross_entropy(
|
|
|
|
|
631 |
return loss
|
632 |
|
633 |
|
@@ -739,17 +749,24 @@ class Wav2Vec2SpkRegForSequenceClassification(Wav2Vec2SpkRegPreTrainedModel):
|
|
739 |
)
|
740 |
elif self.config.loss_fct == 'additive_margin':
|
741 |
loss_fct = AMSoftmaxLoss(
|
742 |
-
self.config.num_labels,
|
|
|
|
|
|
|
|
|
743 |
)
|
744 |
elif self.config.loss_fct == 'additive_angular_margin':
|
745 |
loss_fct = AAMSoftmaxLoss(
|
746 |
-
self.config.num_labels,
|
|
|
|
|
|
|
|
|
|
|
747 |
)
|
748 |
loss = loss_fct(
|
749 |
logits.view(-1, self.config.num_labels),
|
750 |
labels.view(-1),
|
751 |
-
label_smoothing=self.config.label_smoothing,
|
752 |
-
reduction=self.config.reduction
|
753 |
)
|
754 |
|
755 |
if not return_dict:
|
|
|
529 |
num_labels: int,
|
530 |
scale: float = 30.0,
|
531 |
margin: float = 0.35,
|
532 |
+
label_smoothing: float = 0.0,
|
533 |
+
reduction: str = "mean"
|
534 |
):
|
535 |
"""
|
536 |
Args:
|
|
|
542 |
self.num_labels = num_labels
|
543 |
self.scale = scale
|
544 |
self.margin = margin
|
545 |
+
self.label_smoothing = label_smoothing
|
546 |
+
self.reduction = reduction
|
547 |
|
548 |
def forward(
|
549 |
self,
|
550 |
inputs: torch.Tensor,
|
551 |
targets: torch.Tensor,
|
|
|
|
|
552 |
):
|
553 |
"""
|
554 |
Args:
|
|
|
564 |
psi = cosine - self.margin
|
565 |
one_hot = nn.functional.one_hot(targets, self.num_labels)
|
566 |
outputs = self.scale * torch.where(one_hot.bool(), psi, cosine)
|
567 |
+
loss = F.cross_entropy(
|
568 |
+
outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
|
569 |
+
)
|
570 |
return loss
|
571 |
|
572 |
|
|
|
581 |
num_labels: int,
|
582 |
scale: float = 30.0,
|
583 |
margin: float = 0.35,
|
584 |
+
easy_margin: bool = False,
|
585 |
+
label_smoothing: float = 0.0,
|
586 |
+
reduction: str = "mean"
|
587 |
):
|
588 |
"""
|
589 |
Args:
|
|
|
597 |
self.scale = scale
|
598 |
self.margin = margin
|
599 |
self.easy_margin = easy_margin
|
600 |
+
self.label_smoothing = label_smoothing
|
601 |
+
self.reduction = reduction
|
602 |
|
603 |
def forward(
|
604 |
self,
|
|
|
635 |
outputs = (one_hot * phi) + ((1.0 - one_hot) * cosine)
|
636 |
outputs = outputs * self.scale
|
637 |
|
638 |
+
loss = F.cross_entropy(
|
639 |
+
outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
|
640 |
+
)
|
641 |
return loss
|
642 |
|
643 |
|
|
|
749 |
)
|
750 |
elif self.config.loss_fct == 'additive_margin':
|
751 |
loss_fct = AMSoftmaxLoss(
|
752 |
+
self.config.num_labels,
|
753 |
+
self.config.scale,
|
754 |
+
self.config.margin,
|
755 |
+
label_smoothing=self.config.label_smoothing,
|
756 |
+
reduction=self.config.reduction
|
757 |
)
|
758 |
elif self.config.loss_fct == 'additive_angular_margin':
|
759 |
loss_fct = AAMSoftmaxLoss(
|
760 |
+
self.config.num_labels,
|
761 |
+
self.config.scale,
|
762 |
+
self.config.margin,
|
763 |
+
self.config.easy_margin,
|
764 |
+
label_smoothing=self.config.label_smoothing,
|
765 |
+
reduction=self.config.reduction
|
766 |
)
|
767 |
loss = loss_fct(
|
768 |
logits.view(-1, self.config.num_labels),
|
769 |
labels.view(-1),
|
|
|
|
|
770 |
)
|
771 |
|
772 |
if not return_dict:
|