yangwang825 commited on
Commit
9d938b7
1 Parent(s): 9d5c0dd

Update modeling_wav2vec2_spkreg.py

Browse files
Files changed (1) hide show
  1. modeling_wav2vec2_spkreg.py +16 -7
modeling_wav2vec2_spkreg.py CHANGED
@@ -578,7 +578,7 @@ class AAMSoftmaxLoss(nn.Module):
578
  def __init__(
579
  self,
580
  scale: float = 30.0,
581
- margin: float = 0.35,
582
  easy_margin: bool = False,
583
  label_smoothing: float = 0.0,
584
  reduction: str = "mean"
@@ -596,9 +596,6 @@ class AAMSoftmaxLoss(nn.Module):
596
  self.easy_margin = easy_margin
597
  self.label_smoothing = label_smoothing
598
  self.reduction = reduction
599
-
600
- self.cos_m = math.cos(self.margin)
601
- self.sin_m = math.sin(self.margin)
602
 
603
  def forward(
604
  self,
@@ -614,11 +611,23 @@ class AAMSoftmaxLoss(nn.Module):
614
  """
615
  _, num_labels = inputs.shape
616
  # `inputs` are the outputs from AngularLinear()
617
- cos_theta = torch.clamp(inputs, -1.0, 1.0)
618
- sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
619
- psi = cos_theta * self.cos_m - sin_theta * self.sin_m # cos(theta + m)
620
  # theta = torch.acos(cos_theta)
621
  # psi = torch.cos(theta + self.margin)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
  one_hot = nn.functional.one_hot(targets, num_labels)
623
  outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
624
  loss = F.cross_entropy(
 
578
  def __init__(
579
  self,
580
  scale: float = 30.0,
581
+ margin: float = 0.2,
582
  easy_margin: bool = False,
583
  label_smoothing: float = 0.0,
584
  reduction: str = "mean"
 
596
  self.easy_margin = easy_margin
597
  self.label_smoothing = label_smoothing
598
  self.reduction = reduction
 
 
 
599
 
600
  def forward(
601
  self,
 
611
  """
612
  _, num_labels = inputs.shape
613
  # `inputs` are the outputs from AngularLinear()
614
+ epsilon = 1e-6
 
 
615
  # theta = torch.acos(cos_theta)
616
  # psi = torch.cos(theta + self.margin)
617
+ cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
618
+ sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
619
+ sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)
620
+
621
+ cos_m = math.cos(self.margin)
622
+ sin_m = math.sin(self.margin)
623
+ psi = cos_theta * cos_m - sin_theta * sin_m # cos(theta + m)
624
+
625
+ if self.easy_margin:
626
+ psi = torch.where(cos_theta > 0, psi, cos_theta)
627
+ else:
628
+ # Make the function cos(theta+m) monotonic decreasing while theta in [0°, 180°]
629
+ psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)
630
+
631
  one_hot = nn.functional.one_hot(targets, num_labels)
632
  outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
633
  loss = F.cross_entropy(