yangwang825
commited on
Commit
•
9d938b7
1
Parent(s):
9d5c0dd
Update modeling_wav2vec2_spkreg.py
Browse files- modeling_wav2vec2_spkreg.py +16 -7
modeling_wav2vec2_spkreg.py
CHANGED
@@ -578,7 +578,7 @@ class AAMSoftmaxLoss(nn.Module):
|
|
578 |
def __init__(
|
579 |
self,
|
580 |
scale: float = 30.0,
|
581 |
-
margin: float = 0.
|
582 |
easy_margin: bool = False,
|
583 |
label_smoothing: float = 0.0,
|
584 |
reduction: str = "mean"
|
@@ -596,9 +596,6 @@ class AAMSoftmaxLoss(nn.Module):
|
|
596 |
self.easy_margin = easy_margin
|
597 |
self.label_smoothing = label_smoothing
|
598 |
self.reduction = reduction
|
599 |
-
|
600 |
-
self.cos_m = math.cos(self.margin)
|
601 |
-
self.sin_m = math.sin(self.margin)
|
602 |
|
603 |
def forward(
|
604 |
self,
|
@@ -614,11 +611,23 @@ class AAMSoftmaxLoss(nn.Module):
|
|
614 |
"""
|
615 |
_, num_labels = inputs.shape
|
616 |
# `inputs` are the outputs from AngularLinear()
|
617 |
-
|
618 |
-
sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
|
619 |
-
psi = cos_theta * self.cos_m - sin_theta * self.sin_m # cos(theta + m)
|
620 |
# theta = torch.acos(cos_theta)
|
621 |
# psi = torch.cos(theta + self.margin)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
622 |
one_hot = nn.functional.one_hot(targets, num_labels)
|
623 |
outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
|
624 |
loss = F.cross_entropy(
|
|
|
578 |
def __init__(
|
579 |
self,
|
580 |
scale: float = 30.0,
|
581 |
+
margin: float = 0.2,
|
582 |
easy_margin: bool = False,
|
583 |
label_smoothing: float = 0.0,
|
584 |
reduction: str = "mean"
|
|
|
596 |
self.easy_margin = easy_margin
|
597 |
self.label_smoothing = label_smoothing
|
598 |
self.reduction = reduction
|
|
|
|
|
|
|
599 |
|
600 |
def forward(
|
601 |
self,
|
|
|
611 |
"""
|
612 |
_, num_labels = inputs.shape
|
613 |
# `inputs` are the outputs from AngularLinear()
|
614 |
+
epsilon = 1e-6
|
|
|
|
|
615 |
# theta = torch.acos(cos_theta)
|
616 |
# psi = torch.cos(theta + self.margin)
|
617 |
+
cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
|
618 |
+
sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
|
619 |
+
sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)
|
620 |
+
|
621 |
+
cos_m = math.cos(self.margin)
|
622 |
+
sin_m = math.sin(self.margin)
|
623 |
+
psi = cos_theta * cos_m - sin_theta * sin_m # cos(theta + m)
|
624 |
+
|
625 |
+
if self.easy_margin:
|
626 |
+
psi = torch.where(cos_theta > 0, psi, cos_theta)
|
627 |
+
else:
|
628 |
+
# Make the function cos(theta+m) monotonic decreasing while theta in [0°, 180°]
|
629 |
+
psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)
|
630 |
+
|
631 |
one_hot = nn.functional.one_hot(targets, num_labels)
|
632 |
outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
|
633 |
loss = F.cross_entropy(
|