# -*- coding: utf-8 -*-
'''
@Author : Jiangjie Chen
@Time : 2020/10/15 16:10
@Contact : [email protected]
@Description:
'''

import random

import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, hidden_size, num_labels, hidden_dropout_prob=0.2):
super().__init__()
self.dropout = nn.Dropout(hidden_dropout_prob)
self.out_proj = nn.Linear(hidden_size, num_labels, bias=False)
def forward(self, features, **kwargs):
x = features
x = self.dropout(x)
x = self.out_proj(x)
return x
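
# Minimal usage sketch for ClassificationHead; the sizes below are made up for
# illustration (e.g. hidden_size=768 as in BERT-base).
def _demo_classification_head():
    head = ClassificationHead(hidden_size=768, num_labels=3)
    features = torch.randn(4, 768)   # batch of 4 pooled sentence vectors
    logits = head(features)          # 4 x 3 unnormalized label scores
    assert logits.shape == (4, 3)
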
def temperature_annealing(tau, step):
    # tau == 0 selects a periodic schedule: 10.0 every fifth step, 1.0 otherwise
    if tau == 0.:
        tau = 10. if step % 5 == 0 else 1.
    return tau
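
# Illustration of the annealing behaviour with made-up arguments: a tau of 0
# triggers the periodic schedule, any other tau is passed through unchanged.
def _demo_temperature_annealing():
    assert temperature_annealing(0., 10) == 10.
    assert temperature_annealing(0., 3) == 1.
    assert temperature_annealing(0.5, 3) == 0.5
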
def get_label_embeddings(labels, label_embedding):
'''
:param labels: b x 3
:param label_embedding: 3 x h'
:return: b x h'
'''
emb = torch.einsum('oi,bo->bi', label_embedding, labels)
return emb
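
# Sketch of get_label_embeddings with one-hot labels: the einsum then reduces
# to a plain row lookup in the label embedding table (h' = 8 is arbitrary).
def _demo_get_label_embeddings():
    label_embedding = torch.randn(3, 8)                              # 3 x h'
    labels = F.one_hot(torch.tensor([2, 0]), num_classes=3).float()  # b x 3
    emb = get_label_embeddings(labels, label_embedding)              # b x h'
    assert torch.allclose(emb, label_embedding[torch.tensor([2, 0])])
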
def soft_logic(y_i, mask, tnorm='product'):
    '''
    Soft logic aggregation over the m components:
      a AND b = ab
      a OR b  = 1 - (1 - a)(1 - b)
    :param y_i: b x m x 3
    :param mask: b x m
    :param tnorm: 'product', 'godel' or 'lukas' (Lukasiewicz)
    :return: [b x 3]
    '''
    _sup = y_i[:, :, 2]  # b x m
    _ref = y_i[:, :, 0]  # b x m
    _sup = _sup * mask + (1 - mask)  # padded positions become 1, the identity for conjunction
    _ref = _ref * mask  # padded positions become 0, the identity for disjunction
    if tnorm == 'product':
        p_sup = torch.exp(torch.log(_sup).sum(1))
        p_ref = 1 - torch.exp(torch.log(1 - _ref).sum(1))
elif tnorm == 'godel':
p_sup = _sup.min(-1).values
p_ref = _ref.max(-1).values
elif tnorm == 'lukas':
raise NotImplementedError(tnorm)
else:
raise NotImplementedError(tnorm)
p_nei = 1 - p_sup - p_ref
p_sup = torch.max(p_sup, torch.zeros_like(p_sup))
p_ref = torch.max(p_ref, torch.zeros_like(p_ref))
p_nei = torch.max(p_nei, torch.zeros_like(p_nei))
logical_prob = torch.stack([p_ref, p_nei, p_sup], dim=-1)
assert torch.lt(logical_prob, 0).to(torch.int).sum().tolist() == 0, \
(logical_prob, _sup, _ref)
return logical_prob # b x 3
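
# Worked example for soft_logic under the product t-norm, with one padded
# phrase (values are made up): the mask turns the padded slot into the
# identity element of each aggregation, so it does not affect the result.
def _demo_soft_logic():
    y = torch.tensor([[[0.3, 0.5, 0.2], [0.1, 0.4, 0.5], [0.2, 0.3, 0.5]]])  # b=1, m=3
    mask = torch.tensor([[1, 1, 0]])                                         # third slot is padding
    out = soft_logic(y, mask)
    # p_sup = 0.2 * 0.5 = 0.10; p_ref = 1 - 0.7 * 0.9 = 0.37; p_nei = 0.53
    assert torch.allclose(out, torch.tensor([[0.37, 0.53, 0.10]]), atol=1e-4)
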
def build_pseudo_labels(labels, m_attn):
'''
:param labels: (b,)
:param m_attn: b x m
:return: b x m x 3
'''
mask = torch.gt(m_attn, 1e-16).to(torch.int)
sup_label = torch.tensor(2).to(labels)
nei_label = torch.tensor(1).to(labels)
ref_label = torch.tensor(0).to(labels)
pseudo_labels = []
for idx, label in enumerate(labels):
mm = mask[idx].sum(0)
if label == 2: # SUPPORTS
pseudo_label = F.one_hot(sup_label.repeat(mask.size(1)), num_classes=3).to(torch.float) # TODO: hyperparam
elif label == 0: # REFUTES
num_samples = magic_proportion(mm)
ids = torch.topk(m_attn[idx], k=num_samples).indices
pseudo_label = []
for i in range(mask.size(1)):
if i >= mm:
                    _label = torch.tensor([1/3, 1/3, 1/3], device=labels.device)  # uniform over the 3 labels for padded slots (kept float)
                elif i in ids:
                    _label = F.one_hot(ref_label, num_classes=3).to(torch.float)
                else:
                    if random.random() > 0.5:
                        _label = torch.tensor([0., 0., 1.], device=labels.device)  # SUPPORTS
                    else:
                        _label = torch.tensor([0., 1., 0.], device=labels.device)  # NEI
pseudo_label.append(_label)
pseudo_label = torch.stack(pseudo_label)
else: # NEI
num_samples = magic_proportion(mm)
ids = torch.topk(m_attn[idx], k=num_samples).indices
pseudo_label = sup_label.repeat(mask.size(1))
pseudo_label[ids] = nei_label
pseudo_label = F.one_hot(pseudo_label, num_classes=3).to(torch.float) # TODO: hyperparam
pseudo_labels.append(pseudo_label)
return torch.stack(pseudo_labels)
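
# Sketch of build_pseudo_labels on a toy batch (attention weights are made up):
# the SUPPORTS example gets an all-SUPPORTS pseudo label, while the REFUTES
# example gets REFUTES on its most-attended slot and random SUPPORTS/NEI
# labels elsewhere.
def _demo_build_pseudo_labels():
    labels = torch.tensor([2, 0])                              # (b,) = SUPPORTS, REFUTES
    m_attn = torch.tensor([[0.6, 0.4, 0.0], [0.5, 0.3, 0.2]])  # b x m attention weights
    pseudo = build_pseudo_labels(labels, m_attn)               # b x m x 3
    assert pseudo.shape == (2, 3, 3)
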
def magic_proportion(m, magic_n=5):
    # grows with m: 1 for m in 1..4, 2 for m in 5..9, and so on
    return m // magic_n + 1
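
# Quick check of the proportion rule with made-up counts.
def _demo_magic_proportion():
    assert magic_proportion(4) == 1
    assert magic_proportion(5) == 2
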
def sequence_mask(lengths, max_len=None):
"""
Creates a boolean mask from sequence lengths.
"""
batch_size = lengths.numel()
max_len = max_len or lengths.max()
return (torch.arange(0, max_len, device=lengths.device)
.type_as(lengths)
.repeat(batch_size, 1)
.lt(lengths.unsqueeze(1)))
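
# Small example of sequence_mask: lengths [2, 4] with max_len=5 give a 2 x 5
# mask whose first 2 and 4 positions are kept, respectively.
def _demo_sequence_mask():
    lengths = torch.tensor([2, 4])
    mask = sequence_mask(lengths, max_len=5)
    assert mask.shape == (2, 5)
    assert mask.sum().item() == 6
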
def collapse_w_mask(inputs, mask):
'''
:param inputs: b x L x h
:param mask: b x L
:return: b x h
'''
hidden = inputs.size(-1)
output = inputs * mask.unsqueeze(-1).repeat((1, 1, hidden)) # b x L x h
output = output.sum(-2)
output /= (mask.sum(-1) + 1e-6).unsqueeze(-1).repeat((1, hidden)) # b x h
return output
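
# Sketch of collapse_w_mask: it mean-pools the hidden states of the positions
# selected by the mask (shapes below are arbitrary).
def _demo_collapse_w_mask():
    inputs = torch.randn(2, 4, 8)                      # b x L x h
    mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]])  # b x L
    pooled = collapse_w_mask(inputs, mask)             # b x h
    # the second row keeps a single position, so pooling returns (almost) that vector
    assert torch.allclose(pooled[1], inputs[1, 0], atol=1e-4)
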
def parse_ce_outputs(ce_seq_output, ce_lengths):
    '''
    :param ce_seq_output: b x L1 x h
    :param ce_lengths: e.g. [0,1,1,0,2,2,0,0] (b x L1); id 1 -> c_output, ids >= 2 -> e_output (averaged)
    :return:
        c_output: b x h
        e_output: b x h
    '''
    if ce_lengths.max() == 0:
        b, L1, h = ce_seq_output.size()
        return ce_seq_output.new_zeros(b, h), ce_seq_output.new_zeros(b, h)
masks = []
for mask_id in range(1, ce_lengths.max() + 1):
_m = torch.ones_like(ce_lengths) * mask_id
mask = _m.eq(ce_lengths).to(torch.int)
masks.append(mask)
c_output = collapse_w_mask(ce_seq_output, masks[0])
e_output = torch.stack([collapse_w_mask(ce_seq_output, m)
for m in masks[1:]]).mean(0)
return c_output, e_output
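
# Sketch of parse_ce_outputs on a single sequence whose segment ids mark one
# span with 1 and another with 2 (hidden size 16 is arbitrary).
def _demo_parse_ce_outputs():
    ce_seq_output = torch.randn(1, 8, 16)                  # b x L1 x h
    ce_lengths = torch.tensor([[0, 1, 1, 0, 2, 2, 0, 0]])  # b x L1 segment ids
    c_output, e_output = parse_ce_outputs(ce_seq_output, ce_lengths)
    assert c_output.shape == (1, 16) and e_output.shape == (1, 16)
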
def parse_qa_outputs(qa_seq_output, qa_lengths, k):
    '''
    :param qa_seq_output: b x L2 x h
    :param qa_lengths: e.g. [0,1,1,0,2,2,0,3,0,4,0,5,0,0,0,0] (b x L2); id 1 -> q_output, id 2 -> a_output, ids 3..k+2 -> k_cand_output
    :param k: number of candidate slots
    :return:
        q_output: b x h
        a_output: b x h
        k_cand_output: k x b x h
    '''
    b, L2, h = qa_seq_output.size()
    if qa_lengths.max() == 0:
        return qa_seq_output.new_zeros(b, h), qa_seq_output.new_zeros(b, h), \
               qa_seq_output.new_zeros(k, b, h)
masks = []
for mask_id in range(1, qa_lengths.max() + 1):
_m = torch.ones_like(qa_lengths) * mask_id
mask = _m.eq(qa_lengths).to(torch.int)
masks.append(mask)
q_output = collapse_w_mask(qa_seq_output, masks[0])
a_output = collapse_w_mask(qa_seq_output, masks[1])
k_cand_output = [collapse_w_mask(qa_seq_output, m)
for m in masks[2:2 + k]]
    for _ in range(k - len(k_cand_output)):  # pad missing candidate slots with zeros
        k_cand_output.append(qa_seq_output.new_zeros(b, h))
k_cand_output = torch.stack(k_cand_output, dim=0)
return q_output, a_output, k_cand_output
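
# Sketch of parse_qa_outputs: segment ids 1 and 2 mark the question and the
# answer, ids 3..k+2 mark the k candidates (k=3, hidden size 16 arbitrary).
def _demo_parse_qa_outputs():
    qa_seq_output = torch.randn(1, 16, 16)  # b x L2 x h
    qa_lengths = torch.tensor([[0, 1, 1, 0, 2, 2, 0, 3, 0, 4, 0, 5, 0, 0, 0, 0]])
    q, a, cands = parse_qa_outputs(qa_seq_output, qa_lengths, k=3)
    assert q.shape == (1, 16) and a.shape == (1, 16) and cands.shape == (3, 1, 16)
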
def attention_mask_to_mask(attention_mask):
'''
:param attention_mask: b x m x L
:return: b x m
'''
mask = torch.gt(attention_mask.sum(-1), 0).to(torch.int).sum(-1) # (b,)
mask = sequence_mask(mask, max_len=attention_mask.size(1)).to(torch.int) # (b, m)
return mask
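
# Sketch of attention_mask_to_mask: with 3 slots of which only the first two
# contain tokens, the returned mask keeps exactly those two.
def _demo_attention_mask_to_mask():
    attention_mask = torch.tensor([[[1, 1, 0, 0],
                                    [1, 1, 1, 0],
                                    [0, 0, 0, 0]]])  # b x m x L
    mask = attention_mask_to_mask(attention_mask)    # b x m
    assert mask.tolist() == [[1, 1, 0]]
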
if __name__ == "__main__":
    y = torch.tensor([[[0.3, 0.5, 0.2], [0.1, 0.4, 0.5]]])  # b x m x 3
    mask = torch.tensor([[1, 1]])                            # b x m
    s = soft_logic(y, mask)
    print(s)  # approximately tensor([[0.37, 0.53, 0.10]]) under the product t-norm