import torch
from torch import nn

from .cross_entropy_model import FBankNetV2

class TripletLoss(nn.Module):

    def __init__(self, margin):
        super().__init__()
        self.cosine_similarity = nn.CosineSimilarity()
        self.margin = margin

    def forward(self, anchor_embeddings, positive_embeddings, negative_embeddings, reduction='mean'):

        # Cosine distance is a measure of dissimilarity: the higher the value, the
        # more dissimilar the two vectors. It is computed as (1 - cosine similarity)
        # and lies in the range [0, 2]. (A worked numeric example follows this class.)
        positive_distance = 1 - self.cosine_similarity(anchor_embeddings, positive_embeddings)
        negative_distance = 1 - self.cosine_similarity(anchor_embeddings, negative_embeddings)

        # Hinge: penalise triplets where the positive is not closer to the anchor
        # than the negative by at least the margin.
        losses = torch.clamp(positive_distance - negative_distance + self.margin, min=0)
        if reduction == 'mean':
            return torch.mean(losses)
        else:
            return torch.sum(losses)
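
# Worked example of the hinge above (illustrative numbers, not from the original
# code): with margin = 0.2, a triplet whose positive distance is 0.1 and negative
# distance is 0.9 contributes max(0.1 - 0.9 + 0.2, 0) = 0, while a harder triplet
# with distances 0.6 and 0.5 contributes max(0.6 - 0.5 + 0.2, 0) = 0.3.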


class FBankTripletLossNet(FBankNetV2):

    def __init__(self, num_layers, margin):
        super().__init__(num_layers=num_layers)
        self.loss_layer = TripletLoss(margin)

    def forward(self, anchor, positive, negative):
        n = anchor.shape[0]

        # The same shared-weight network and linear layer embed all three inputs.
        anchor_out = self.network(anchor)
        anchor_out = anchor_out.reshape(n, -1)
        anchor_out = self.linear_layer(anchor_out)

        positive_out = self.network(positive)
        positive_out = positive_out.reshape(n, -1)
        positive_out = self.linear_layer(positive_out)

        negative_out = self.network(negative)
        negative_out = negative_out.reshape(n, -1)
        negative_out = self.linear_layer(negative_out)

        return anchor_out, positive_out, negative_out

    def loss(self, anchor, positive, negative, reduction='mean'):
        loss_val = self.loss_layer(anchor, positive, negative, reduction)
        return loss_val
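

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training pipeline). The
# TripletLoss demo below is self-contained; the 128-dim embeddings and batch
# size of 4 are illustrative assumptions. Exercising FBankTripletLossNet end to
# end would additionally need FBank inputs whose shape is dictated by
# FBankNetV2 (defined in cross_entropy_model.py, not shown here), so that part
# is omitted.
if __name__ == "__main__":
    criterion = TripletLoss(margin=0.2)
    anchor = torch.randn(4, 128, requires_grad=True)  # assumed embedding size
    positive = torch.randn(4, 128)
    negative = torch.randn(4, 128)

    loss = criterion(anchor, positive, negative)
    loss.backward()  # gradients flow back into the anchor embeddings
    print(f"triplet loss on random embeddings: {loss.item():.4f}")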