# Source: speaker_identify/models/triplet_loss_model.py
# Uploaded by DuyTa using huggingface_hub (commit f831146, verified).
from torch import nn
from abc import abstractmethod
import torch
from torch import nn
from .cross_entropy_model import FBankNetV2
class TripletLoss(nn.Module):
    """Margin-based triplet loss computed on cosine distance.

    For every triplet the loss is
    ``max(d(anchor, positive) - d(anchor, negative) + margin, 0)`` where
    ``d(x, y) = 1 - cosine_similarity(x, y)``.

    Args:
        margin: Value added to the distance difference before it is clipped
            at zero.
    """

    # Reduction modes accepted by ``forward``.
    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, margin):
        super().__init__()
        # Default nn.CosineSimilarity settings: similarity is taken along dim 1.
        self.cosine_similarity = nn.CosineSimilarity()
        self.margin = margin

    def forward(self, anchor_embeddings, positive_embeddings, negative_embeddings, reduction='mean'):
        """Compute the triplet loss for a batch of embeddings.

        Args:
            anchor_embeddings: Embeddings of the anchor samples.
            positive_embeddings: Embeddings that should be close to the anchors.
            negative_embeddings: Embeddings that should be far from the anchors.
            reduction: ``'mean'`` (default) averages the per-triplet losses,
                ``'sum'`` adds them up, ``'none'`` returns them unreduced.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, or the tensor of
            per-triplet losses for ``'none'``.

        Raises:
            ValueError: If ``reduction`` is not one of the supported modes.

        Note:
            Inputs are assumed to be shaped (batch, embedding_dim), since the
            similarity is computed along dim 1 -- confirm against callers.
        """
        # Fail loudly on a misspelled mode instead of silently summing, which
        # would change the loss scale by a factor of the batch size.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"Unsupported reduction {reduction!r}; expected one of {self._REDUCTIONS}"
            )

        # Cosine distance is a measure of dissimilarity: the higher the value,
        # the more dissimilar the two vectors. It is (1 - cosine similarity)
        # and lies in the closed range [0, 2].
        positive_distance = 1 - self.cosine_similarity(anchor_embeddings, positive_embeddings)
        negative_distance = 1 - self.cosine_similarity(anchor_embeddings, negative_embeddings)

        # Hinge at zero; clamp avoids allocating a zero tensor on every call.
        losses = torch.clamp(positive_distance - negative_distance + self.margin, min=0)

        if reduction == 'mean':
            return losses.mean()
        if reduction == 'sum':
            return losses.sum()
        return losses
class FBankTripletLossNet(FBankNetV2):
    """FBankNetV2 variant that embeds (anchor, positive, negative) triplets.

    ``self.network`` and ``self.linear_layer`` are not defined here; they are
    presumably created by ``FBankNetV2.__init__`` -- confirm in
    ``cross_entropy_model``.

    Args:
        num_layers: Forwarded unchanged to ``FBankNetV2.__init__``.
        margin: Margin handed to the ``TripletLoss`` layer.
    """

    def __init__(self, num_layers, margin):
        super().__init__(num_layers=num_layers)
        self.loss_layer = TripletLoss(margin)

    def _embed_branch(self, batch, batch_size):
        """Run one triplet branch: network, flatten to (batch_size, -1), linear layer."""
        features = self.network(batch)
        features = features.reshape(batch_size, -1)
        return self.linear_layer(features)

    def forward(self, anchor, positive, negative):
        """Embed the three members of a triplet batch with shared weights.

        The branches are evaluated in the order anchor, positive, negative.
        All three are flattened using the anchor's batch size, so the inputs
        are expected to share the same leading dimension.

        Args:
            anchor: Batch of anchor inputs.
            positive: Batch of positive inputs.
            negative: Batch of negative inputs.

        Returns:
            Tuple ``(anchor_out, positive_out, negative_out)`` of embeddings.
        """
        batch_size = anchor.shape[0]
        anchor_out = self._embed_branch(anchor, batch_size)
        positive_out = self._embed_branch(positive, batch_size)
        negative_out = self._embed_branch(negative, batch_size)
        return anchor_out, positive_out, negative_out

    def loss(self, anchor, positive, negative, reduction='mean'):
        """Return the triplet loss for already-computed embeddings.

        Args:
            anchor: Anchor embeddings.
            positive: Positive embeddings.
            negative: Negative embeddings.
            reduction: Passed through to the loss layer (default ``'mean'``).

        Returns:
            The value produced by ``self.loss_layer``.
        """
        return self.loss_layer(anchor, positive, negative, reduction)