import torch
import torch.nn.functional as F
from torch.nn import KLDivLoss

def dot_product_scores(q_vectors, ctx_vectors):
    """
    Calculates the q-ctx dot-product score for every (question, context) pair.

    :param q_vectors: (n_q, dim) tensor of question embeddings
    :param ctx_vectors: (n_ctx, dim) tensor of context embeddings
    :return: (n_q, n_ctx) score matrix
    """
    return torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1))

def cosine_scores(q_vectors, ctx_vectors):
    """
    Calculates the q-ctx cosine score for every (question, context) pair.

    :param q_vectors: (n_q, dim) tensor of question embeddings
    :param ctx_vectors: (n_ctx, dim) tensor of context embeddings
    :return: (n_q, n_ctx) score matrix
    """
    # Broadcast so that every question row is compared against every context
    # row along the embedding dimension.
    return F.cosine_similarity(
        q_vectors.unsqueeze(1), ctx_vectors.unsqueeze(0), dim=-1
    )
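
# A minimal shape sketch (illustration, not part of the original module) for
# the two scoring functions above: both take an (n_q, dim) question matrix and
# an (n_ctx, dim) context matrix and return an (n_q, n_ctx) score matrix.
def _demo_scoring():
    q = torch.randn(2, 8)
    ctx = torch.randn(6, 8)
    assert dot_product_scores(q, ctx).shape == (2, 6)
    assert cosine_scores(q, ctx).shape == (2, 6)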

class BiEncoderNllLoss(object):
    def __init__(self, score_type="dot", kd_alpha=0.5):
        self.score_type = score_type
        self.kd_alpha = kd_alpha
        # Distillation term: with log_target=True, KLDivLoss expects both the
        # student and the teacher inputs as log-probabilities.
        self.kd = KLDivLoss(reduction="batchmean", log_target=True)
        
    def calc(self, q_vectors, ctx_vectors, kd_scores):
        """
        Computes the combined NLL + knowledge-distillation loss for the given
        question and context vectors.

        :param q_vectors: (q_num, dim) question embeddings
        :param ctx_vectors: (q_num + q_num * n_hard, dim) context embeddings,
            positives first (one per question, in question order), followed by
            the hard negatives grouped per question
        :param kd_scores: (q_num, 1 + n_hard) teacher scores over each
            question's own positive and hard-negative contexts
        :return: a tuple of the loss value and the number of correct
            predictions in the batch
        """
        if q_vectors.dim() == 1:
            q_vectors = q_vectors.view(1, -1)

        scores = self.get_scores(q_vectors, ctx_vectors)
        kd_scores = F.log_softmax(kd_scores, dim=1)

        q_num = q_vectors.size(0)
        ctx_num = ctx_vectors.size(0)
        no_hard = ctx_num // q_num - 1  # hard negatives per question
        scores = scores.view(q_num, -1)

        # For each question, gather the scores of its own positive (column i)
        # and its block of hard negatives; these are the columns the teacher
        # distribution in kd_scores ranges over.
        pre_rows = []
        for i in range(q_num):
            ctx_lst = [i] + list(range(q_num + i * no_hard, q_num + (i + 1) * no_hard))
            pre_rows.append(scores[i, ctx_lst])
        pre_scores = torch.stack(pre_rows)

        positive_idx_per_question = list(range(q_num))

        softmax_scores = F.log_softmax(scores, dim=1)
        pre_softmax_scores = F.log_softmax(pre_scores, dim=1)

        # In-batch-negative NLL over all contexts.
        bi_loss = F.nll_loss(
            softmax_scores,
            torch.tensor(positive_idx_per_question).to(softmax_scores.device),
            reduction="mean",
        )

        # KL divergence between the student's and the teacher's distributions
        # over each question's positive + hard-negative contexts.
        kd_loss = self.kd(pre_softmax_scores, kd_scores)

        loss = self.kd_alpha * bi_loss + (1 - self.kd_alpha) * kd_loss

        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (
            max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)
        ).sum()
        return loss, correct_predictions_count

    def get_scores(self, q_vector, ctx_vectors):
        f = self.get_similarity_function()
        return f(q_vector, ctx_vectors)

    def get_similarity_function(self):
        if self.score_type == "dot":
            return dot_product_scores
        else:
            return cosine_scores
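
# A minimal usage sketch (illustration, not part of the original module) for
# BiEncoderNllLoss, assuming the DPR-style batch layout the loss relies on:
# the first q_num rows of ctx_vectors are the positives (one per question, in
# question order) and the remaining rows are the hard negatives, grouped per
# question. The teacher scores here are random stand-ins for real
# cross-encoder scores over each question's positive + hard-negative contexts.
def _demo_nll_loss():
    q_num, n_hard, dim = 2, 3, 8
    q_vectors = torch.randn(q_num, dim)
    ctx_vectors = torch.randn(q_num + q_num * n_hard, dim)
    teacher_scores = torch.randn(q_num, 1 + n_hard)  # hypothetical teacher logits
    loss_fn = BiEncoderNllLoss(score_type="dot", kd_alpha=0.5)
    loss, correct = loss_fn.calc(q_vectors, ctx_vectors, teacher_scores)
    print(f"BiEncoderNllLoss: loss={loss.item():.4f}, correct={int(correct)}/{q_num}")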
    
class BiEncoderDoubleNllLoss(object):
    def __init__(self, score_type="dot", alpha=0.5):
        self.score_type = score_type
        self.alpha = alpha
        
    def calc(self, q_vectors, ctx_vectors):
        """
        Computes a weighted sum of two NLL losses for the given question and
        context vectors: one over all in-batch contexts, and one over the
        in-batch contexts with each question's own hard negatives removed.

        :param q_vectors: (q_num, dim) question embeddings
        :param ctx_vectors: (q_num + q_num * n_hard, dim) context embeddings,
            positives first (one per question, in question order), followed by
            the hard negatives grouped per question
        :return: a tuple of the loss value and the number of correct
            predictions in the batch
        """
        if q_vectors.dim() == 1:
            q_vectors = q_vectors.view(1, -1)

        scores = self.get_scores(q_vectors, ctx_vectors)

        q_num = q_vectors.size(0)
        ctx_num = ctx_vectors.size(0)
        no_hard = ctx_num // q_num - 1  # hard negatives per question
        scores = scores.view(q_num, -1)

        positive_idx_per_question = list(range(q_num))

        # Second score matrix: for each question, drop its own hard negatives
        # and keep the positives plus the other questions' hard negatives.
        # The dropped columns all lie after the positives, so column i still
        # holds question i's positive score.
        rows = []
        for i in range(q_num):
            hard_neg_idx = set(range(q_num + i * no_hard, q_num + (i + 1) * no_hard))
            random_neg = [x for x in range(ctx_num) if x not in hard_neg_idx]
            rows.append(scores[i, random_neg])
        scores2 = torch.stack(rows)

        softmax_scores = F.log_softmax(scores, dim=1)
        softmax_scores2 = F.log_softmax(scores2, dim=1)

        targets = torch.tensor(positive_idx_per_question).to(softmax_scores.device)
        loss1 = F.nll_loss(softmax_scores, targets, reduction="mean")
        loss2 = F.nll_loss(softmax_scores2, targets, reduction="mean")
        loss = self.alpha * loss1 + (1 - self.alpha) * loss2

        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == targets).sum()
        return loss, correct_predictions_count

    def get_scores(self, q_vector, ctx_vectors):
        f = self.get_similarity_function()
        return f(q_vector, ctx_vectors)

    def get_similarity_function(self):
        if self.score_type == "dot":
            return dot_product_scores
        else:
            return cosine_scores
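
# A minimal usage sketch (illustration, not part of the original module) for
# BiEncoderDoubleNllLoss under the same batch layout as above: positives
# first, then per-question hard-negative blocks.
def _demo_double_nll_loss():
    q_num, n_hard, dim = 2, 3, 8
    q_vectors = torch.randn(q_num, dim)
    ctx_vectors = torch.randn(q_num + q_num * n_hard, dim)
    loss_fn = BiEncoderDoubleNllLoss(score_type="dot", alpha=0.5)
    loss, correct = loss_fn.calc(q_vectors, ctx_vectors)
    print(f"BiEncoderDoubleNllLoss: loss={loss.item():.4f}, correct={int(correct)}/{q_num}")


if __name__ == "__main__":
    _demo_scoring()
    _demo_nll_loss()
    _demo_double_nll_loss()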