# -*- coding: utf-8 -*-
'''
@Author : Jiangjie Chen
@Time : 2020/8/18 14:40
@Contact : [email protected]
@Description:
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertPreTrainedModel
from .checker_utils import attention_mask_to_mask, ClassificationHead, soft_logic, build_pseudo_labels, \
    get_label_embeddings, temperature_annealing


class BertChecker(BertPreTrainedModel):
def __init__(self, config, logic_lambda=0.0, prior='nli', m=8, temperature=1):
super().__init__(config)
self.num_labels = config.num_labels
self.hidden_size = config.hidden_size
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self._lambda = logic_lambda
self.prior = prior
self.temperature = temperature
self._step = 0
# general attention
self.linear_self_attn = nn.Linear(self.hidden_size, 1, bias=False)
self.linear_m_attn = nn.Linear(self.hidden_size * 2, 1, bias=False)
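        # variational components: z is a latent m x num_labels matrix of per-question veracity labels;
        # P_theta decodes y from [claim; aggregated local; flattened z], while Q_phi infers z from
        # the per-question features together with the label embedding of y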
self.var_hidden_size = self.hidden_size // 4
z_hid_size = self.num_labels * m
self.linear_P_theta = nn.Linear(self.hidden_size * 2 + z_hid_size, self.var_hidden_size)
y_hid_size = self.var_hidden_size
self.linear_Q_phi = nn.Linear(self.hidden_size * 2 + y_hid_size, self.var_hidden_size)
self.classifier = ClassificationHead(self.var_hidden_size, self.num_labels, config.hidden_dropout_prob) # label embedding for y
self.z_clf = self.classifier
        self.init_weights()

def forward(self, claim_input_ids, claim_attention_mask, claim_token_type_ids,
qa_input_ids_list, qa_attention_mask_list, qa_token_type_ids_list,
nli_labels=None, labels=None):
        '''
        m: num of questions; n: num of evidence; k: num of candidate answers
        :param claim_input_ids: b x L1
        :param claim_attention_mask: b x L1
        :param claim_token_type_ids: b x L1
        :param qa_input_ids_list: b x m x L2
        :param qa_attention_mask_list: b x m x L2
        :param qa_token_type_ids_list: b x m x L2
        :param nli_labels: b x m x 3, NLI logits used as the prior over z when self.prior == 'nli'
        :param labels: (b,), gold claim labels; None at inference time
        :return: loss, (neg_elbo, logic_loss), y, m_attn, (z_softmax, mask)
        '''
self._step += 1
_zero = torch.tensor(0.).to(claim_input_ids.device)
global_output = self.bert(
claim_input_ids,
attention_mask=claim_attention_mask,
token_type_ids=claim_token_type_ids
)[0] # b x L1 x h
global_output = self.self_select(global_output) # b x h
_qa_input_ids_list = qa_input_ids_list.transpose(1, 0) # m x b x L2
_qa_attention_mask_list = qa_attention_mask_list.transpose(1, 0)
_qa_token_type_ids_list = qa_token_type_ids_list.transpose(1, 0)
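        # encode each of the m (question, answer) sequences with the shared BERT encoder
        # and pool each one into a single vector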
local_output_list = []
for _inp, _attn, _token_ids in zip(_qa_input_ids_list, _qa_attention_mask_list, _qa_token_type_ids_list):
_local_output = self.bert(_inp, attention_mask=_attn,
token_type_ids=_token_ids)[0]
_local_output = self.self_select(_local_output)
local_output_list.append(_local_output)
local_outputs = torch.stack(local_output_list, 0) # m x b x h
local_outputs = local_outputs.transpose(1, 0).contiguous() # b x m x h
neg_elbo, loss, logic_loss = _zero, _zero, _zero
mask = attention_mask_to_mask(qa_attention_mask_list)
# b x h, b x m x h -> b x h
local_outputs_w, m_attn = self.local_attn(global_output, local_outputs, mask)
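        # pair every local representation with the claim representation: b x m x 2h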
local_outputs = torch.cat([local_outputs, global_output.unsqueeze(1).repeat(1, local_outputs.size(1), 1)], -1)
if labels is not None:
# Training
# ======================== Q_phi ================================
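            # amortized posterior q_phi(z | x, y*): condition the per-question features on the gold label embedding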
labels_onehot = F.one_hot(labels, num_classes=self.num_labels).to(torch.float)
y_star_emb = get_label_embeddings(labels_onehot, self.classifier.out_proj.weight) # b x h
z = self.Q_phi(local_outputs, y_star_emb)
z_softmax = z.softmax(-1)
# ======================== P_theta ==============================
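            # sample hard (straight-through) z via Gumbel-softmax, then decode the claim label y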
z_gumbel = F.gumbel_softmax(z, tau=temperature_annealing(self.temperature, self._step),
dim=-1, hard=True) # b x m x 3
y = self.P_theta(global_output, local_outputs_w, z_gumbel)
# ======================== soft logic ===========================
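            # aggregate per-question veracity distributions into a claim-level label via soft logic,
            # and penalize the divergence between p_theta's prediction and that aggregate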
mask = mask.to(torch.int)
y_z = soft_logic(z_softmax, mask) # b x 3
logic_loss = F.kl_div(y.log_softmax(-1), y_z)
# ======================== ELBO =================================
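            # negative ELBO = reconstruction term (cross-entropy of y against the gold label)
            # + KL between the posterior over z and the chosen prior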
elbo_neg_p_log = F.cross_entropy(y.view(-1, self.num_labels), labels.view(-1))
if self.prior == 'nli':
prior = nli_labels.softmax(dim=-1)
elif self.prior == 'uniform':
prior = torch.tensor([1 / self.num_labels] * self.num_labels).to(y)
prior = prior.unsqueeze(0).unsqueeze(0).repeat(mask.size(0), mask.size(1), 1)
elif self.prior == 'logic':
prior = build_pseudo_labels(labels, m_attn)
else:
raise NotImplementedError(self.prior)
elbo_kl = F.kl_div(z_softmax.log(), prior)
neg_elbo = elbo_kl + elbo_neg_p_log
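            # interpolate the ELBO objective with the logic consistency loss using weight |lambda|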
loss = (1 - abs(self._lambda)) * neg_elbo + abs(self._lambda) * logic_loss
else:
# Inference
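            # initialize z from the chosen prior, then alternately refine y and z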
if self.prior == 'nli':
z = nli_labels
elif self.prior == 'uniform':
                prior = torch.tensor([1 / self.num_labels] * self.num_labels).to(local_outputs)
z = prior.unsqueeze(0).unsqueeze(0).repeat(mask.size(0), mask.size(1), 1)
else:
z = torch.rand([local_outputs.size(0), local_outputs.size(1), self.num_labels]).to(local_outputs)
z_softmax = z.softmax(-1)
for i in range(3): # N = 3
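                # hard-assign z, decode y with P_theta, then re-infer z with Q_phi from the predicted label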
z = z_softmax.argmax(-1)
z = F.one_hot(z, num_classes=3).to(torch.float)
y = self.P_theta(global_output, local_outputs_w, z)
y = y.softmax(-1)
y_emb = get_label_embeddings(y, self.classifier.out_proj.weight)
z = self.Q_phi(local_outputs, y_emb)
z_softmax = z.softmax(-1)
return (loss, (neg_elbo, logic_loss), y, m_attn, (z_softmax, mask)) # batch first

    def Q_phi(self, X, y):
        '''
        X, y => z
        :param X: b x m x 2h, local features concatenated with the claim representation
        :param y: b x h', label embedding of the (gold or predicted) claim label
        :return: b x m x 3 (ref, nei, sup)
        '''
y_expand = y.unsqueeze(1).repeat(1, X.size(1), 1) # b x m x 3/h'
z_hidden = self.linear_Q_phi(torch.cat([y_expand, X], dim=-1)) # b x m x h'
        z_hidden = torch.tanh(z_hidden)
z = self.z_clf(z_hidden)
return z

    def P_theta(self, X_global, X_local, z):
        '''
        X, z => y*
        :param X_global: b x h, claim representation
        :param X_local: b x h, attention-weighted local representation
        :param z: b x m x 3
        :return: b x 3
        '''
b = z.size(0)
# global classification
_logits = torch.cat([X_local, X_global, z.reshape(b, -1)], dim=-1)
_logits = self.dropout(_logits)
_logits = self.linear_P_theta(_logits)
_logits = torch.tanh(_logits)
y = self.classifier(_logits)
return y

    def self_select(self, h_x):
        '''
        self-attention pooling over token representations
        :param h_x: b x L x h
        :return: b x h
        '''
w = self.dropout(self.linear_self_attn(h_x).squeeze(-1)).softmax(-1)
return torch.einsum('blh,bl->bh', h_x, w)

    def local_attn(self, global_output, local_outputs, mask):
        '''
        attention over the m local (QA) representations, guided by the claim representation
        :param global_output: b x h
        :param local_outputs: b x m x h
        :param mask: b x m
        :return: b x h, b x m
        '''
m = local_outputs.size(1)
scores = self.linear_m_attn(torch.cat([global_output.unsqueeze(1).repeat(1, m, 1),
local_outputs], dim=-1)).squeeze(-1) # b x m
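        # mask out padded/empty questions with a large negative score before the softmax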
mask = 1 - mask
scores = scores.masked_fill(mask.to(torch.bool), -1e16)
attn = F.softmax(scores, -1)
return torch.einsum('bm,bmh->bh', attn, local_outputs), attn