import os
import os.path as osp
import pdb
import json
from copy import deepcopy
from random import *

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from packaging import version

from .Tool_model import AutomaticWeightedLoss
from .Numeric import AttenNumeric
from .KE_model import KE_model
from .bert import BertModel, BertTokenizer, BertForMaskedLM, BertConfig
from src.utils import torch_accuracy


def debug(batch, kk, begin=None):
    """Truncate a collated batch to samples [begin:kk) for quick debugging.

    `batch[0]` is expected to expose input_ids, attention_mask, chinese_ref,
    kpi_ref, and labels; all five are sliced along the sample dimension.
    """
    sliced = deepcopy(batch[0])
    if begin is None:
        begin = 0
    sliced.input_ids = batch[0].input_ids[begin:kk]
    sliced.attention_mask = batch[0].attention_mask[begin:kk]
    sliced.chinese_ref = batch[0].chinese_ref[begin:kk]
    sliced.kpi_ref = batch[0].kpi_ref[begin:kk]
    sliced.labels = batch[0].labels[begin:kk]
    return sliced
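
# Usage sketch for the helper above (hypothetical batch object; anything
# whose first element carries the five tensor attributes works):
#   small = debug(batch, kk=8)            # keep samples [0:8)
#   window = debug(batch, kk=8, begin=4)  # keep samples [4:8)
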
class HWBert(nn.Module):
    def __init__(self, args):
        super().__init__()
        # Uncertainty-based automatic weighting over the training losses.
        self.loss_awl = AutomaticWeightedLoss(args.awl_num, args)
        self.args = args
        self.config = BertConfig()
        model_name = args.model_name
        ckpt_dir = osp.join(args.data_root, 'transformer')
        if model_name in ['TeleBert', 'TeleBert2', 'TeleBert3']:
            self.encoder = BertForMaskedLM.from_pretrained(osp.join(ckpt_dir, model_name))
            if args.cls_head_init:
                # Initialize the MLM prediction head from a MacBert checkpoint.
                tmp = BertForMaskedLM.from_pretrained(osp.join(ckpt_dir, 'MacBert'))
                self.encoder.cls.predictions = tmp.cls.predictions
        else:
            # Fall back to MacBert when the requested checkpoint is absent.
            if not osp.exists(osp.join(ckpt_dir, model_name)):
                model_name = 'MacBert'
            self.encoder = BertForMaskedLM.from_pretrained(osp.join(ckpt_dir, model_name))
        # Attention-based model for numeric (KPI) values.
        self.numeric_model = AttenNumeric(self.args)
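
    # Construction sketch (hypothetical values; every field below is read
    # somewhere in this class):
    #   args = argparse.Namespace(
    #       awl_num=2, model_name='MacBert', data_root='data',
    #       cls_head_init=False, use_kpi_loss=True)
    #   model = HWBert(args).cuda()
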
    def forward(self, inputs):
        outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.mask_forward(inputs)
        mask_loss = outputs.loss
        loss_dic = {}
        if not self.args.use_kpi_loss:
            kpi_loss = None
        if kpi_loss is not None:
            # Combine the MLM loss with a down-weighted (x0.3) KPI loss.
            loss_sum = self.loss_awl(mask_loss, 0.3 * kpi_loss)
            loss_dic['kpi_loss'] = kpi_loss.item()
        else:
            loss_sum = self.loss_awl(mask_loss)
        loss_dic['mask_loss'] = mask_loss.item()
        return {
            'loss': loss_sum,
            'loss_dic': loss_dic,
            'loss_weight': self.loss_awl.params.tolist(),
            'kpi_loss_weight': kpi_loss_weight,
            'kpi_loss_dict': kpi_loss_dict
        }
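
    # Training-step sketch (hypothetical loop; assumes a dataloader yielding
    # dicts with 'input_ids', 'attention_mask', 'labels' and, optionally,
    # 'kpi_ref'):
    #   out = model(batch)
    #   out['loss'].backward()
    #   optimizer.step(); optimizer.zero_grad()
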
    def mask_prediction(self, inputs, tokenizer_sz, topk=(1,)):
        outputs, _kpi_loss, _kpi_loss_weight, _kpi_loss_dict = self.mask_forward(inputs)
        labels = inputs['labels'].view(-1)
        # Only masked positions carry a real label; everything else is -100.
        masked = labels != -100
        labels_used = labels[masked]
        pred = outputs.logits.view(-1, tokenizer_sz).cpu()
        pred_used = pred[masked]
        # torch_accuracy reports (accuracy, correct count) over the masked tokens.
        _acc, token_right = torch_accuracy(pred_used, labels_used, topk)
        token_num = labels_used.shape[0]
        return token_num, token_right, outputs.loss.item()
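
    # Evaluation sketch (hypothetical loop; `tokenizer` and `eval_loader` are
    # assumptions, and torch_accuracy's second return value is taken to be the
    # correct-prediction count):
    #   total, right = 0, 0
    #   for batch in eval_loader:
    #       n, r, _ = model.mask_prediction(batch, tokenizer_sz=len(tokenizer))
    #       total, right = total + n, right + r
    #   token_acc = right / max(total, 1)
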
    def mask_forward(self, inputs):
        # kpi_ref is optional: batches without numeric KPI annotations skip it.
        kpi_ref = inputs['kpi_ref'] if 'kpi_ref' in inputs else None
        # The modified BertForMaskedLM returns the usual MLM outputs plus the
        # KPI loss, its weight, and a per-part loss dict.
        outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.encoder(
            input_ids=inputs['input_ids'].cuda(),
            attention_mask=inputs['attention_mask'].cuda(),
            labels=inputs['labels'].cuda(),
            kpi_ref=kpi_ref,
            kpi_model=self.numeric_model
        )
        return outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict
    def cls_embedding(self, inputs, tp='cls'):
        # The modified encoder returns a tuple; [0] is the MLM output object.
        hidden_states = self.encoder(
            input_ids=inputs['input_ids'].cuda(),
            attention_mask=inputs['attention_mask'].cuda(),
            output_hidden_states=True)[0].hidden_states
        if tp == 'cls':
            # Final-layer [CLS] vector.
            return hidden_states[-1][:, 0]
        # Boolean mask of real (non-padding) positions, kept on the same
        # device as the hidden states; token id 0 is padding.
        index_real = inputs['input_ids'].cuda().bool()
        res = []
        for i in range(hidden_states[-1].shape[0]):
            # Mean-pool the real tokens; [:-1] drops the last real token (the
            # trailing [SEP]). Multi-layer variants pool the unscaled layer
            # sums; hidden_states[1] is the first encoder layer (index 0
            # holds the embeddings).
            if tp == 'last_avg':
                res.append(hidden_states[-1][i][index_real[i]][:-1].mean(dim=0))
            elif tp == 'last2avg':
                res.append((hidden_states[-1][i][index_real[i]][:-1] + hidden_states[-2][i][index_real[i]][:-1]).mean(dim=0))
            elif tp == 'last3avg':
                res.append((hidden_states[-1][i][index_real[i]][:-1] + hidden_states[-2][i][index_real[i]][:-1] + hidden_states[-3][i][index_real[i]][:-1]).mean(dim=0))
            elif tp == 'first_last_avg':
                res.append((hidden_states[-1][i][index_real[i]][:-1] + hidden_states[1][i][index_real[i]][:-1]).mean(dim=0))
            else:
                raise ValueError(f'unknown pooling type: {tp}')
        return torch.stack(res)
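
    # Pooling sketch (hypothetical call): 'cls' returns the final-layer [CLS]
    # vector; the *avg modes mean-pool the real tokens as above.
    #   emb = model.cls_embedding(batch, tp='first_last_avg')  # (B, hidden)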