File size: 8,160 Bytes
813828b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
import torch
import torch.nn.functional as F
from torch import nn
# adapted from https://github.com/cvqluu/GE2E-Loss
class GE2ELoss(nn.Module):
    def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"):
        """
        Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1]

        Accepts an input of size (N, M, D)
            where N is the number of speakers in the batch,
            M is the number of utterances per speaker,
            and D is the dimensionality of the embedding vector (e.g. d-vector)

        Args:
            - init_w (float): defines the initial value of w in Equation (5) of [1]
            - init_b (float): defines the initial value of b in Equation (5) of [1]
            - loss_method (str): "softmax" or "contrast"; selects
              ``embed_loss_softmax`` or ``embed_loss_contrast`` as the per-embedding loss.
        """
        super().__init__()
        # pylint: disable=E1102
        self.w = nn.Parameter(torch.tensor(init_w))
        # pylint: disable=E1102
        self.b = nn.Parameter(torch.tensor(init_b))
        self.loss_method = loss_method

        print(" > Initialized Generalized End-to-End loss")

        assert self.loss_method in ["softmax", "contrast"]

        if self.loss_method == "softmax":
            self.embed_loss = self.embed_loss_softmax
        if self.loss_method == "contrast":
            self.embed_loss = self.embed_loss_contrast

    # pylint: disable=R0201
    def calc_new_centroids(self, dvecs, centroids, spkr, utt):
        """
        Calculates the new centroids excluding the reference utterance.

        The centroid of speaker ``spkr`` is recomputed from all of that speaker's
        utterances except ``utt``. All other speakers' centroids are returned unchanged.

        Args:
            dvecs (Tensor): embeddings of shape (N, M, D).
            centroids (Tensor): per-speaker centroids of shape (N, D).
            spkr (int): index of the reference speaker.
            utt (int): index of the reference utterance of that speaker.

        Returns:
            Tensor: a new (N, D) tensor of centroids with row ``spkr`` replaced.
        """
        excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :])).mean(0)
        # Splice the exclusive centroid in place of row ``spkr`` without mutating ``centroids``.
        return torch.cat((centroids[:spkr], excl.unsqueeze(0), centroids[spkr + 1 :]))

    def calc_cosine_sim(self, dvecs, centroids):
        """
        Make the cosine similarity matrix with dims (N, M, N).

        Entry [j, i, k] is the cosine similarity between utterance i of speaker j and
        the centroid of speaker k. For k == j, that centroid excludes utterance i itself.
        """
        cos_sim_matrix = []
        for spkr_idx, speaker in enumerate(dvecs):
            cs_row = []
            for utt_idx, utterance in enumerate(speaker):
                new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx)
                # vector based cosine similarity for speed: (1, D) @ (D, N) -> (1, N)
                cos_sim = torch.mm(utterance.unsqueeze(0), new_centroids.transpose(0, 1)) / (
                    torch.norm(utterance) * torch.norm(new_centroids, dim=1)
                )
                # NOTE(review): this floors similarities at 1e-6, so negative cosines are
                # mapped to ~0 (kept for parity with the upstream implementation).
                cs_row.append(torch.clamp(cos_sim, 1e-6))
            cos_sim_matrix.append(torch.cat(cs_row, dim=0))
        return torch.stack(cos_sim_matrix)

    # pylint: disable=R0201
    def embed_loss_softmax(self, dvecs, cos_sim_matrix):
        """
        Calculates the loss on each embedding $L(e_{ji})$ by taking softmax.

        Returns an (N, M) tensor. Entry [j, i] is the negative log-softmax, over speakers,
        of the scaled similarity between utterance i of speaker j and its own speaker j.
        """
        N, M, _ = dvecs.shape
        L = []
        for j in range(N):
            L_row = [-F.log_softmax(cos_sim_matrix[j, i], 0)[j] for i in range(M)]
            L.append(torch.stack(L_row))
        return torch.stack(L)

    # pylint: disable=R0201
    def embed_loss_contrast(self, dvecs, cos_sim_matrix):
        """
        Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid.

        Returns an (N, M) tensor. Entry [j, i] is
        ``1 - sigmoid(S[j, i, j]) + max_{k != j} sigmoid(S[j, i, k])``.
        """
        N, M, _ = dvecs.shape
        L = []
        for j in range(N):
            L_row = []
            for i in range(M):
                centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
                # sigmoid similarities to every *other* speaker's centroid
                excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :]))
                L_row.append(1.0 - centroids_sigmoids[j] + torch.max(excl_centroids_sigmoids))
            L.append(torch.stack(L_row))
        return torch.stack(L)

    def forward(self, x, _label=None):
        """
        Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats).

        Returns:
            Tensor: scalar mean of the per-embedding losses.
        """
        assert x.size()[1] >= 2

        centroids = torch.mean(x, 1)
        cos_sim_matrix = self.calc_cosine_sim(x, centroids)
        # torch.clamp is out-of-place, so the clamped value must be used for the
        # w > 0 floor to take effect. Gradients w.r.t. w are zero while self.w is below it.
        w = torch.clamp(self.w, min=1e-6)
        cos_sim_matrix = w * cos_sim_matrix + self.b
        L = self.embed_loss(x, cos_sim_matrix)
        return L.mean()
# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
class AngleProtoLoss(nn.Module):
    """
    Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982

    Accepts an input of size (N, M, D)
        where N is the number of speakers in the batch,
        M is the number of utterances per speaker,
        and D is the dimensionality of the embedding vector

    Args:
        - init_w (float): defines the initial value of w
        - init_b (float): defines the initial value of b
    """

    def __init__(self, init_w=10.0, init_b=-5.0):
        super().__init__()
        # pylint: disable=E1102
        self.w = nn.Parameter(torch.tensor(init_w))
        # pylint: disable=E1102
        self.b = nn.Parameter(torch.tensor(init_b))
        self.criterion = torch.nn.CrossEntropyLoss()
        print(" > Initialized Angular Prototypical loss")

    def forward(self, x, _label=None):
        """
        Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats).

        The first utterance of each speaker is compared against every speaker's anchor,
        where an anchor is the mean of the remaining utterances. Cross-entropy then pushes
        each utterance toward its own speaker's anchor.

        Returns:
            Tensor: scalar cross-entropy loss.
        """
        assert x.size()[1] >= 2

        out_anchor = torch.mean(x[:, 1:, :], 1)
        out_positive = x[:, 0, :]
        num_speakers = out_anchor.size()[0]

        # (N, D, N) vs (N, D, N) along dim 1 -> (N, N); entry [i, k] = cos(positive_i, anchor_k)
        cos_sim_matrix = F.cosine_similarity(
            out_positive.unsqueeze(-1).expand(-1, -1, num_speakers),
            out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2),
        )
        # torch.clamp is out-of-place, so the clamped value must be used for the
        # w > 0 floor to take effect.
        w = torch.clamp(self.w, min=1e-6)
        cos_sim_matrix = cos_sim_matrix * w + self.b
        # the correct "class" for row i is speaker i's own anchor (the diagonal)
        label = torch.arange(num_speakers, device=cos_sim_matrix.device)
        L = self.criterion(cos_sim_matrix, label)
        return L
class SoftmaxLoss(nn.Module):
    """
    Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982

    A linear classifier over speaker embeddings trained with cross-entropy.

    Args:
        - embedding_dim (int): speaker embedding dim
        - n_speakers (int): number of speakers
    """

    def __init__(self, embedding_dim, n_speakers):
        super().__init__()

        self.criterion = torch.nn.CrossEntropyLoss()
        self.fc = nn.Linear(embedding_dim, n_speakers)

        print("Initialised Softmax Loss")

    def forward(self, x, label=None):
        """
        Compute the cross-entropy of the speaker classifier.

        Args:
            x (Tensor): embeddings whose last dimension is ``embedding_dim``; all leading
                dimensions are flattened into one batch.
            label (Tensor): integer speaker ids, one per embedding in ``x`` after flattening.

        Returns:
            Tensor: scalar cross-entropy loss.

        Raises:
            ValueError: if ``label`` is not given.
        """
        if label is None:
            raise ValueError("SoftmaxLoss requires `label` (one speaker id per embedding).")
        # reshape for compatibility: flatten embeddings and labels to matching 1-D batches
        x = x.reshape(-1, x.size()[-1])
        label = label.reshape(-1)

        x = self.fc(x)
        L = self.criterion(x, label)

        return L

    def inference(self, embedding):
        """
        Return the index of the most probable speaker for ``embedding``.

        NOTE(review): ``squeeze(0)`` suggests a single (1, embedding_dim) input — confirm with callers.
        """
        x = self.fc(embedding)
        activations = torch.nn.functional.softmax(x, dim=1).squeeze(0)
        class_id = torch.argmax(activations)
        return class_id
class SoftmaxAngleProtoLoss(nn.Module):
    """
    Softmax + Angular Prototypical loss, as defined in https://arxiv.org/abs/2009.14153

    The result is the plain sum of a SoftmaxLoss (speaker classification over
    ``label``) and an AngleProtoLoss computed on the same embedding batch.

    Args:
        - embedding_dim (int): speaker embedding dim
        - n_speakers (int): number of speakers
        - init_w (float): initial value of the AngleProto scale w
        - init_b (float): initial value of the AngleProto bias b
    """

    def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0):
        super().__init__()
        # registration order (softmax, then angleproto) fixes the parameters() order
        self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
        self.angleproto = AngleProtoLoss(init_w, init_b)
        print("Initialised SoftmaxAnglePrototypical Loss")

    def forward(self, x, label=None):
        """
        Calculates the SoftmaxAnglePrototypical loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats).

        Returns:
            Tensor: softmax term + angular prototypical term.
        """
        # the angular term is evaluated first, before the label-dependent softmax term
        angle_term = self.angleproto(x)
        softmax_term = self.softmax(x, label)
        return softmax_term + angle_term
|