# pylint: disable=E0611,E0401
import tensorflow.keras.backend as K

# ALPHA = 0.2  # used in FaceNet https://arxiv.org/pdf/1503.03832.pdf
ALPHA = 0.1  # used in Deep Speaker.


def batch_cosine_similarity(x1, x2):
    # https://en.wikipedia.org/wiki/Cosine_similarity
    # 1 = equal direction ; -1 = opposite direction
    dot = K.squeeze(K.batch_dot(x1, x2, axes=1), axis=1)
    # as the embeddings are expected to be L2-normalized (length 1),
    # we don't need to divide by the norms (they are 1)
    return dot
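

# The loss below assumes the embeddings are already L2-normalized.
# A minimal sketch of how one could enforce that before computing the loss
# (hypothetical helper, not part of the original model):
def l2_normalize_embeddings(embeddings):
    # Scale each embedding to unit Euclidean length so the dot products in
    # batch_cosine_similarity are exact cosine similarities.
    return K.l2_normalize(embeddings, axis=1)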


def deep_speaker_loss(y_true, y_pred, alpha=ALPHA):
    # y_true is not used; it is only there to match the Keras loss signature:
    # y_true.shape = (batch_size, embedding_size) [not used]
    # y_pred.shape = (batch_size, embedding_size)
    # EXAMPLE:
    # _____________________________________________________
    # ANCHOR 1 (512,)
    # ANCHOR 2 (512,)
    # POS EX 1 (512,)
    # POS EX 2 (512,)
    # NEG EX 1 (512,)
    # NEG EX 2 (512,)
    # _____________________________________________________
    split = K.shape(y_pred)[0] // 3

    anchor = y_pred[0:split]
    positive_ex = y_pred[split:2 * split]
    negative_ex = y_pred[2 * split:]

    # If the loss does not decrease below ALPHA, the model has not learned anything.
    # Degenerate case: the model always outputs the same vector, so
    # anchor = positive = negative, hence sap = san = 1 and
    # loss = max(1 - 1 + alpha, 0) = alpha.
    # Conversely, if anchor = positive = [1] and negative = [-1],
    # then sap = 1, san = -1 and loss = max(-1 - 1 + 0.1, 0) = max(-1.9, 0) = 0.
    sap = batch_cosine_similarity(anchor, positive_ex)
    san = batch_cosine_similarity(anchor, negative_ex)
    loss = K.maximum(san - sap + alpha, 0.0)
    total_loss = K.mean(loss)
    return total_loss
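

# A minimal usage sketch, assuming a Keras `model` whose output is the embedding
# (`model`, `anchor_inputs`, etc. are hypothetical, not part of this file).
# Each batch fed to fit() must be stacked as [anchors | positives | negatives],
# so the batch size must be divisible by 3, and shuffling must be disabled
# (Keras shuffles by default, which would destroy the triplet layout):
#
#   model.compile(optimizer='adam', loss=deep_speaker_loss)
#   x = np.concatenate([anchor_inputs, positive_inputs, negative_inputs])
#   y_dummy = np.zeros((len(x), 512))  # y_true is ignored by the loss
#   model.fit(x, y_dummy, batch_size=len(x), shuffle=False)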


if __name__ == '__main__':
    import numpy as np

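    # Toy 1-D embeddings: each y_pred row is one embedding,
    # stacked as [anchor, positive, negative].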
    print(deep_speaker_loss(alpha=0.1, y_true=0, y_pred=np.array([[0.9], [1.0], [-1.0]])))
    print(deep_speaker_loss(alpha=1, y_true=0, y_pred=np.array([[0.9], [1.0], [-1.0]])))
    print(deep_speaker_loss(alpha=2, y_true=0, y_pred=np.array([[0.9], [1.0], [-1.0]])))
    print('--------------')
    print(deep_speaker_loss(alpha=2, y_true=0, y_pred=np.array([[0.6], [1.0], [0.0]])))
    print(deep_speaker_loss(alpha=1, y_true=0, y_pred=np.array([[0.6], [1.0], [0.0]])))
    print(deep_speaker_loss(alpha=0.1, y_true=0, y_pred=np.array([[0.6], [1.0], [0.0]])))
    print(deep_speaker_loss(alpha=0.2, y_true=0, y_pred=np.array([[0.6], [1.0], [0.0]])))
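    print('--------------')
    # Sanity checks for the worked examples in the comments above:
    # degenerate case, anchor = positive = negative -> loss = alpha.
    print(deep_speaker_loss(alpha=0.1, y_true=0, y_pred=np.array([[1.0], [1.0], [1.0]])))  # expect ~0.1
    # well-separated case, anchor = positive = [1], negative = [-1] -> loss = 0.
    print(deep_speaker_loss(alpha=0.1, y_true=0, y_pred=np.array([[1.0], [1.0], [-1.0]])))  # expect 0.0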
