# -*- coding: utf-8 -*-
'''
@Author : Jiangjie Chen
@Time : 2020/8/18 14:40
@Contact : [email protected]
@Description:
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertPreTrainedModel

from .checker_utils import attention_mask_to_mask, ClassificationHead, soft_logic, build_pseudo_labels, \
    get_label_embeddings, temperature_annealing


class BertChecker(BertPreTrainedModel):
    def __init__(self, config, logic_lambda=0.0, prior='nli', m=8, temperature=1):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.hidden_size = config.hidden_size
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self._lambda = logic_lambda
        self.prior = prior
        self.temperature = temperature
        self._step = 0

        # general attention
        self.linear_self_attn = nn.Linear(self.hidden_size, 1, bias=False)
        self.linear_m_attn = nn.Linear(self.hidden_size * 2, 1, bias=False)

        self.var_hidden_size = self.hidden_size // 4

        z_hid_size = self.num_labels * m
        self.linear_P_theta = nn.Linear(self.hidden_size * 2 + z_hid_size, self.var_hidden_size)

        y_hid_size = self.var_hidden_size
        self.linear_Q_phi = nn.Linear(self.hidden_size * 2 + y_hid_size, self.var_hidden_size)

        self.classifier = ClassificationHead(self.var_hidden_size, self.num_labels, config.hidden_dropout_prob)  # label embedding for y
        self.z_clf = self.classifier

        self.init_weights()
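
    # A note on the structure (inferred from forward() below): the per-question
    # veracity labels z are treated as latent variables. linear_Q_phi feeds the
    # approximate posterior q_phi(z | x, y), linear_P_theta feeds the decoder
    # p_theta(y | x, z), and the shared ClassificationHead doubles as a label
    # embedding table (its out_proj.weight is reused by get_label_embeddings).
    # Training minimises a negative ELBO plus a soft-logic consistency term
    # weighted by `logic_lambda`.
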
    def forward(self, claim_input_ids, claim_attention_mask, claim_token_type_ids,
                qa_input_ids_list, qa_attention_mask_list, qa_token_type_ids_list,
                nli_labels=None, labels=None):
        '''
        m: num of questions; n: num of evidence; k: num of candidate answers
        :param claim_input_ids: b x L1
        :param claim_attention_mask: b x L1
        :param claim_token_type_ids: b x L1
        :param qa_input_ids_list: b x m x L2
        :param qa_attention_mask_list: b x m x L2
        :param qa_token_type_ids_list: b x m x L2
        :param nli_labels: b x m x num_labels, NLI logits used as the prior over z when self.prior == 'nli'
        :param labels: (b,)
        :return: loss, (neg_elbo, logic_loss), y logits (b x num_labels), m_attn (b x m), (z_softmax, mask)
        '''
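        # Outline: (1) encode the claim and each QA pair with the shared BERT
        # encoder and pool each sequence with self-attention; (2) attend from
        # the claim to the m QA representations; (3) in training, sample the
        # latent per-question labels z via Q_phi and Gumbel-softmax and
        # optimise the ELBO plus the soft-logic term; at inference, initialise
        # z from the prior and alternate between P_theta and Q_phi.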
        self._step += 1
        _zero = torch.tensor(0.).to(claim_input_ids.device)

        global_output = self.bert(
            claim_input_ids,
            attention_mask=claim_attention_mask,
            token_type_ids=claim_token_type_ids
        )[0]  # b x L1 x h
        global_output = self.self_select(global_output)  # b x h

        _qa_input_ids_list = qa_input_ids_list.transpose(1, 0)  # m x b x L2
        _qa_attention_mask_list = qa_attention_mask_list.transpose(1, 0)
        _qa_token_type_ids_list = qa_token_type_ids_list.transpose(1, 0)

        local_output_list = []
        for _inp, _attn, _token_ids in zip(_qa_input_ids_list, _qa_attention_mask_list, _qa_token_type_ids_list):
            _local_output = self.bert(_inp, attention_mask=_attn,
                                      token_type_ids=_token_ids)[0]
            _local_output = self.self_select(_local_output)
            local_output_list.append(_local_output)
        local_outputs = torch.stack(local_output_list, 0)  # m x b x h
        local_outputs = local_outputs.transpose(1, 0).contiguous()  # b x m x h

        neg_elbo, loss, logic_loss = _zero, _zero, _zero
        mask = attention_mask_to_mask(qa_attention_mask_list)

        # b x h, b x m x h -> b x h
        local_outputs_w, m_attn = self.local_attn(global_output, local_outputs, mask)
        local_outputs = torch.cat([local_outputs, global_output.unsqueeze(1).repeat(1, local_outputs.size(1), 1)], -1)
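
        # local_outputs now holds [QA; claim] features (b x m x 2h) consumed by
        # Q_phi, while local_outputs_w is the claim-attended evidence summary
        # (b x h) consumed by P_theta.
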
        if labels is not None:
            # Training
            # ======================== Q_phi ================================
            labels_onehot = F.one_hot(labels, num_classes=self.num_labels).to(torch.float)
            y_star_emb = get_label_embeddings(labels_onehot, self.classifier.out_proj.weight)  # b x h
            z = self.Q_phi(local_outputs, y_star_emb)
            z_softmax = z.softmax(-1)

            # ======================== P_theta ==============================
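            # Sample z with the straight-through Gumbel-softmax (hard=True):
            # the decoder sees one-hot per-question labels while gradients flow
            # through the relaxed softmax; the temperature is annealed with the
            # training step via temperature_annealing.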
            z_gumbel = F.gumbel_softmax(z, tau=temperature_annealing(self.temperature, self._step),
                                        dim=-1, hard=True)  # b x m x 3
            y = self.P_theta(global_output, local_outputs_w, z_gumbel)

            # ======================== soft logic ===========================
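            # soft_logic (from checker_utils) aggregates the per-question label
            # distributions into a claim-level distribution y_z; the KL term
            # pulls the neural prediction y towards this logical aggregate.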
            mask = mask.to(torch.int)
            y_z = soft_logic(z_softmax, mask)  # b x 3
            logic_loss = F.kl_div(y.log_softmax(-1), y_z)

            # ======================== ELBO =================================
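            # Negative ELBO = reconstruction term (cross-entropy of y against
            # the gold label) + KL between q_phi(z | x, y*) and the chosen
            # prior over z: 'nli' logits, a uniform distribution, or pseudo
            # labels built from the attention weights ('logic').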
            elbo_neg_p_log = F.cross_entropy(y.view(-1, self.num_labels), labels.view(-1))
            if self.prior == 'nli':
                prior = nli_labels.softmax(dim=-1)
            elif self.prior == 'uniform':
                prior = torch.tensor([1 / self.num_labels] * self.num_labels).to(y)
                prior = prior.unsqueeze(0).unsqueeze(0).repeat(mask.size(0), mask.size(1), 1)
            elif self.prior == 'logic':
                prior = build_pseudo_labels(labels, m_attn)
            else:
                raise NotImplementedError(self.prior)
            elbo_kl = F.kl_div(z_softmax.log(), prior)
            neg_elbo = elbo_kl + elbo_neg_p_log
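
            # Final loss: interpolate between the ELBO objective and the
            # soft-logic consistency loss with weight |logic_lambda|.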
            loss = (1 - abs(self._lambda)) * neg_elbo + abs(self._lambda) * logic_loss
        else:
            # Inference
            if self.prior == 'nli':
                z = nli_labels
            elif self.prior == 'uniform':
                # note: `.to(local_outputs)` here; `y` does not exist yet in this branch
                prior = torch.tensor([1 / self.num_labels] * self.num_labels).to(local_outputs)
                z = prior.unsqueeze(0).unsqueeze(0).repeat(mask.size(0), mask.size(1), 1)
            else:
                z = torch.rand([local_outputs.size(0), local_outputs.size(1), self.num_labels]).to(local_outputs)
            z_softmax = z.softmax(-1)
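
            # Inference alternates between the two networks: discretise z,
            # predict y with P_theta, embed y, then re-estimate z with Q_phi.
            # The number of rounds is fixed at N = 3 below.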
            for i in range(3):  # N = 3
                z = z_softmax.argmax(-1)
                z = F.one_hot(z, num_classes=self.num_labels).to(torch.float)
                y = self.P_theta(global_output, local_outputs_w, z)
                y = y.softmax(-1)
                y_emb = get_label_embeddings(y, self.classifier.out_proj.weight)
                z = self.Q_phi(local_outputs, y_emb)
                z_softmax = z.softmax(-1)

        return (loss, (neg_elbo, logic_loss), y, m_attn, (z_softmax, mask))  # batch first

    def Q_phi(self, X, y):
        '''
        X, y => z
        :param X: b x m x 2h, QA representations concatenated with the claim representation
        :param y: b x h', label embeddings (see get_label_embeddings)
        :return: b x m x 3 (ref, nei, sup)
        '''
        y_expand = y.unsqueeze(1).repeat(1, X.size(1), 1)  # b x m x h'
        z_hidden = self.linear_Q_phi(torch.cat([y_expand, X], dim=-1))  # b x m x h'
        z_hidden = torch.tanh(z_hidden)
        z = self.z_clf(z_hidden)
        return z

    def P_theta(self, X_global, X_local, z):
        '''
        X, z => y*
        :param X_global: b x h
        :param X_local: b x h, attention-weighted sum of the QA representations
        :param z: b x m x 3
        :return: b x 3
        '''
        b = z.size(0)
        # global classification
        _logits = torch.cat([X_local, X_global, z.reshape(b, -1)], dim=-1)
        _logits = self.dropout(_logits)
        _logits = self.linear_P_theta(_logits)
        _logits = torch.tanh(_logits)

        y = self.classifier(_logits)
        return y

    def self_select(self, h_x):
        '''
        Self-attention pooling over the sequence dimension (b x L x h -> b x h).
        :param h_x: b x L x h
        :return: b x h
        '''
        w = self.dropout(self.linear_self_attn(h_x).squeeze(-1)).softmax(-1)
        return torch.einsum('blh,bl->bh', h_x, w)

    def local_attn(self, global_output, local_outputs, mask):
        '''
        Attention from the claim representation over the m QA representations.
        :param global_output: b x h
        :param local_outputs: b x m x h
        :param mask: b x m
        :return: b x h, b x m
        '''
        m = local_outputs.size(1)
        scores = self.linear_m_attn(torch.cat([global_output.unsqueeze(1).repeat(1, m, 1),
                                               local_outputs], dim=-1)).squeeze(-1)  # b x m
        mask = 1 - mask.to(torch.int)
        scores = scores.masked_fill(mask.to(torch.bool), -1e16)
        attn = F.softmax(scores, -1)
        return torch.einsum('bm,bmh->bh', attn, local_outputs), attn
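

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module): run a
    # randomly initialised checker on dummy inputs to illustrate the expected
    # tensor shapes from the forward() docstring. The tiny BertConfig and the
    # sizes b=2, m=4, L=16 are assumptions made only to keep the example fast;
    # it assumes the helpers in checker_utils behave as used above, and must be
    # run as a module (python -m <package>.<this_module>) because of the
    # relative import at the top of the file.
    from transformers import BertConfig

    config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
                        intermediate_size=64, num_labels=3)
    model = BertChecker(config, logic_lambda=0.5, prior='uniform', m=4)
    model.eval()

    b, m, L = 2, 4, 16
    claim_ids = torch.randint(0, config.vocab_size, (b, L))
    claim_mask = torch.ones(b, L, dtype=torch.long)
    claim_types = torch.zeros(b, L, dtype=torch.long)
    qa_ids = torch.randint(0, config.vocab_size, (b, m, L))
    qa_mask = torch.ones(b, m, L, dtype=torch.long)
    qa_types = torch.zeros(b, m, L, dtype=torch.long)
    labels = torch.randint(0, config.num_labels, (b,))

    with torch.no_grad():
        loss, (neg_elbo, logic_loss), y, m_attn, (z_softmax, mask) = model(
            claim_ids, claim_mask, claim_types, qa_ids, qa_mask, qa_types, labels=labels)
    print(loss.item(), y.shape, m_attn.shape, z_softmax.shape)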