# -*- coding: utf-8 -*-

'''
@Author     : Jiangjie Chen
@Time       : 2020/10/15 16:10
@Contact    : [email protected]
@Description: 
'''

import torch
import random
import torch.nn.functional as F
import torch.nn as nn


class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, hidden_size, num_labels, hidden_dropout_prob=0.2):
        super().__init__()
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, num_labels, bias=False)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


def temperature_annealing(tau, step):
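    # with tau == 0., a fixed schedule is used: steps 0, 5, 10, ... get tau = 10.,
    # all other steps get tau = 1.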
    if tau == 0.:
        tau = 10. if step % 5 == 0 else 1.
    return tau


def get_label_embeddings(labels, label_embedding):
    '''
    :param labels: b x 3
    :param label_embedding: 3 x h'
    :return: b x h'
    '''
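    # equivalent to a plain matmul: emb = labels (b x 3) @ label_embedding (3 x h')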
    emb = torch.einsum('oi,bo->bi', label_embedding, labels)
    return emb


def soft_logic(y_i, mask, tnorm='product'):
    '''
    a^b = ab
    avb = 1 - ((1-a)(1-b))
    :param y_i: b x m x 3
    :param mask: b x m
    :param tnorm: 'product' or 'godel' ('lukas', i.e. Lukasiewicz, is not implemented)
    :return: [b x 3]
    '''
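    # label indices follow build_pseudo_labels: 0 = REFUTES, 1 = NEI, 2 = SUPPORTS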
    _sup = y_i[:, :, 2]  # b x m
    _ref = y_i[:, :, 0]  # b x m
    _sup = _sup * mask + (1 - mask)  # padded positions -> 1 (identity for conjunction)
    _ref = _ref * mask               # padded positions -> 0 (identity for disjunction)

    if tnorm == 'product':
        # product t-norm: conjunction = prod(p); disjunction = 1 - prod(1 - p)
        p_sup = torch.exp(torch.log(_sup).sum(1))
        p_ref = 1 - torch.exp(torch.log(1 - _ref).sum(1))
    elif tnorm == 'godel':
        # Goedel t-norm: conjunction = min; disjunction = max
        p_sup = _sup.min(-1).values
        p_ref = _ref.max(-1).values
    elif tnorm == 'lukas':
        raise NotImplementedError(tnorm)
    else:
        raise NotImplementedError(tnorm)

    p_nei = 1 - p_sup - p_ref
    p_sup = torch.max(p_sup, torch.zeros_like(p_sup))
    p_ref = torch.max(p_ref, torch.zeros_like(p_ref))
    p_nei = torch.max(p_nei, torch.zeros_like(p_nei))
    logical_prob = torch.stack([p_ref, p_nei, p_sup], dim=-1)
    assert torch.lt(logical_prob, 0).to(torch.int).sum().tolist() == 0, \
        (logical_prob, _sup, _ref)
    return logical_prob  # b x 3


def build_pseudo_labels(labels, m_attn):
    '''
    :param labels: (b,)
    :param m_attn: b x m
    :return: b x m x 3
    '''
    mask = torch.gt(m_attn, 1e-16).to(torch.int)
    sup_label = torch.tensor(2).to(labels)
    nei_label = torch.tensor(1).to(labels)
    ref_label = torch.tensor(0).to(labels)
    pseudo_labels = []
    for idx, label in enumerate(labels):
        mm = mask[idx].sum(0)
        if label == 2: # SUPPORTS
            pseudo_label = F.one_hot(sup_label.repeat(mask.size(1)), num_classes=3).to(torch.float) # TODO: hyperparam

        elif label == 0: # REFUTES
            # label the top-attended evidence REFUTES, give padding positions a
            # uniform distribution, and randomly set the rest to SUPPORTS or NEI
            num_samples = magic_proportion(mm)
            ids = torch.topk(m_attn[idx], k=num_samples).indices
            pseudo_label = []
            for i in range(mask.size(1)):
                if i >= mm:  # padding position
                    _label = torch.tensor([1 / 3, 1 / 3, 1 / 3], device=labels.device)
                elif i in ids:
                    _label = F.one_hot(ref_label, num_classes=3).to(torch.float)
                else:
                    if random.random() > 0.5:
                        _label = torch.tensor([0., 0., 1.], device=labels.device)
                    else:
                        _label = torch.tensor([0., 1., 0.], device=labels.device)
                pseudo_label.append(_label)
            pseudo_label = torch.stack(pseudo_label)

        else: # NEI
            num_samples = magic_proportion(mm)
            ids = torch.topk(m_attn[idx], k=num_samples).indices
            pseudo_label = sup_label.repeat(mask.size(1))
            pseudo_label[ids] = nei_label
            pseudo_label = F.one_hot(pseudo_label, num_classes=3).to(torch.float) # TODO: hyperparam
        
        pseudo_labels.append(pseudo_label)
    return torch.stack(pseudo_labels)


def magic_proportion(m, magic_n=5):
    # number of evidence sentences to relabel: 1 for m in 1-4, 2 for 5-9, and so on
    return m // magic_n + 1


def sequence_mask(lengths, max_len=None):
    """
    Creates a boolean mask from sequence lengths.
    """
    batch_size = lengths.numel()
    max_len = max_len or lengths.max()
    return (torch.arange(0, max_len, device=lengths.device)
            .type_as(lengths)
            .repeat(batch_size, 1)
            .lt(lengths.unsqueeze(1)))


def collapse_w_mask(inputs, mask):
    '''
    :param inputs: b x L x h
    :param mask: b x L
    :return: b x h
    '''
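    # masked mean pooling: average the rows of `inputs` where mask == 1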
    hidden = inputs.size(-1)
    output = inputs * mask.unsqueeze(-1).repeat((1, 1, hidden))  # b x L x h
    output = output.sum(-2)
    output /= (mask.sum(-1) + 1e-6).unsqueeze(-1).repeat((1, hidden))  # b x h
    return output


def parse_ce_outputs(ce_seq_output, ce_lengths):
    '''
    :param ce_seq_output: b x L1 x h
    :param ce_lengths: e.g. [0,1,1,0,2,2,0,0] (b x L1), a segment id per token
    :return:
        c_output: b x h, pooled over tokens with segment id 1
        e_output: b x h, mean of the pooled outputs for segment ids >= 2
    '''
    if ce_lengths.max() == 0:
        b, L1, h = ce_seq_output.size()
        # no tagged tokens: return zero vectors on the same device as the input
        return (torch.zeros([b, h], device=ce_seq_output.device),
                torch.zeros([b, h], device=ce_seq_output.device))
    masks = []
    for mask_id in range(1, ce_lengths.max() + 1):
        _m = torch.ones_like(ce_lengths) * mask_id
        mask = _m.eq(ce_lengths).to(torch.int)
        masks.append(mask)
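    # masks[j] selects the tokens whose segment id equals j + 1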
    c_output = collapse_w_mask(ce_seq_output, masks[0])
    e_output = torch.stack([collapse_w_mask(ce_seq_output, m)
                            for m in masks[1:]]).mean(0)
    return c_output, e_output


def parse_qa_outputs(qa_seq_output, qa_lengths, k):
    '''
    :param qa_seq_output: b x L2 x h
    :param qa_lengths: e.g. [0,1,1,0,2,2,0,3,0,4,0,5,0,0,0,0] (b x L2), a segment id per token
    :param k: number of candidate segments (segment ids 3 .. k+2 in qa_lengths)
    :return:
        q_output: b x h
        a_output: b x h
        k_cand_output: k x b x h
    '''
    b, L2, h = qa_seq_output.size()
    if qa_lengths.max() == 0:
        # no tagged tokens: return zero vectors on the same device as the input
        return (torch.zeros([b, h], device=qa_seq_output.device),
                torch.zeros([b, h], device=qa_seq_output.device),
                torch.zeros([k, b, h], device=qa_seq_output.device))

    masks = []
    for mask_id in range(1, qa_lengths.max() + 1):
        _m = torch.ones_like(qa_lengths) * mask_id
        mask = _m.eq(qa_lengths).to(torch.int)
        masks.append(mask)

    q_output = collapse_w_mask(qa_seq_output, masks[0])
    a_output = collapse_w_mask(qa_seq_output, masks[1])
    k_cand_output = [collapse_w_mask(qa_seq_output, m)
                     for m in masks[2:2 + k]]
    # pad with zero vectors when fewer than k candidate segments are present
    for _ in range(k - len(k_cand_output)):
        k_cand_output.append(torch.zeros([b, h], device=qa_seq_output.device))
    k_cand_output = torch.stack(k_cand_output, dim=0)

    return q_output, a_output, k_cand_output


def attention_mask_to_mask(attention_mask):
    '''
    :param attention_mask: b x m x L
    :return: b x m
    '''
    mask = torch.gt(attention_mask.sum(-1), 0).to(torch.int).sum(-1)  # (b,)
    mask = sequence_mask(mask, max_len=attention_mask.size(1)).to(torch.int)  # (b, m)
    return mask


if __name__ == "__main__":
    y = torch.tensor([[[0.3, 0.5, 0.2], [0.1, 0.4, 0.5]]])  # b x m x 3
    mask = torch.tensor([[1, 1]])  # b x m
    s = soft_logic(y, mask)
    print(s)
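
    # Additional quick checks for the masking / pooling helpers
    # (illustrative values only, not part of the original tests).
    lengths = torch.tensor([1, 3])
    print(sequence_mask(lengths))  # [[True, False, False], [True, True, True]]

    attn = torch.tensor([[[1, 1, 0], [1, 0, 0], [0, 0, 0]]])  # b x m x L
    print(attention_mask_to_mask(attn))  # [[1, 1, 0]]

    inputs = torch.ones(1, 4, 8)
    tok_mask = torch.tensor([[1, 1, 0, 0]])
    print(collapse_w_mask(inputs, tok_mask).shape)  # torch.Size([1, 8])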