File size: 1,959 Bytes
f831146 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
from torch import nn
from abc import abstractmethod
import torch
from torch import nn
from .cross_entropy_model import FBankNetV2
class TripletLoss(nn.Module):
    """Margin-based triplet loss computed on cosine distance.

    For each triplet the loss is
    ``max(d(anchor, positive) - d(anchor, negative) + margin, 0)``
    where ``d(x, y) = 1 - cosine_similarity(x, y)``.
    """

    def __init__(self, margin):
        """Store the margin and build the cosine-similarity module.

        Args:
            margin: Value added to ``positive_distance - negative_distance``
                before the result is clipped at zero.
        """
        super().__init__()
        # nn.CosineSimilarity() with no arguments compares along dim=1.
        self.cosine_similarity = nn.CosineSimilarity()
        self.margin = margin

    def forward(self, anchor_embeddings, positive_embeddings, negative_embeddings, reduction='mean'):
        """Compute the triplet loss for a batch of embeddings.

        Args:
            anchor_embeddings: Anchor embeddings; assumed to be shaped
                (batch, embedding_dim) since similarity is taken along dim 1.
            positive_embeddings: Embeddings that should be close to the anchors.
            negative_embeddings: Embeddings that should be far from the anchors.
            reduction: ``'mean'`` averages the per-triplet losses, ``'sum'``
                adds them up, ``'none'`` returns them unreduced.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'``, or the per-triplet
            loss tensor for ``'none'``.

        Raises:
            ValueError: If ``reduction`` is not one of the supported values.
        """
        # Validate up front so a typo never silently falls through to a sum.
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                f"{reduction!r} is not a valid value for reduction; "
                "expected 'mean', 'sum' or 'none'"
            )

        # Cosine distance measures dissimilarity: the higher the value, the
        # more dissimilar the two vectors. It is (1 - cosine similarity) and
        # lies in the closed interval [0, 2].
        positive_distance = 1 - self.cosine_similarity(anchor_embeddings, positive_embeddings)
        negative_distance = 1 - self.cosine_similarity(anchor_embeddings, negative_embeddings)

        # Hinge: only triplets that violate the margin contribute.
        losses = torch.max(
            positive_distance - negative_distance + self.margin,
            torch.zeros_like(positive_distance),
        )

        if reduction == 'mean':
            return torch.mean(losses)
        if reduction == 'sum':
            return torch.sum(losses)
        return losses
class FBankTripletLossNet(FBankNetV2):
    """FBankNetV2 variant that embeds triplets and scores them with TripletLoss.

    ``self.network`` and ``self.linear_layer`` are not defined in this class;
    they are expected to be provided by ``FBankNetV2`` (not shown here).
    """

    def __init__(self, num_layers, margin):
        """Build the base network and attach the triplet loss layer.

        Args:
            num_layers: Forwarded to ``FBankNetV2.__init__`` as ``num_layers``.
            margin: Margin handed to ``TripletLoss``.
        """
        super().__init__(num_layers=num_layers)
        self.loss_layer = TripletLoss(margin)

    def forward(self, anchor, positive, negative):
        """Embed the anchor, positive and negative batches, in that order.

        Each input is passed through ``self.network``, flattened to
        ``(batch_size, -1)`` and then through ``self.linear_layer``.

        Returns:
            A tuple ``(anchor_out, positive_out, negative_out)``.
        """
        # The anchor's batch size is used to flatten all three inputs.
        batch_size = anchor.shape[0]
        return tuple(
            self.linear_layer(self.network(batch).reshape(batch_size, -1))
            for batch in (anchor, positive, negative)
        )

    def loss(self, anchor, positive, negative, reduction='mean'):
        """Return the triplet loss of the given embeddings via ``self.loss_layer``."""
        return self.loss_layer(anchor, positive, negative, reduction)