Keras
legal
kevin110211's picture
Upload 51 files
5d58b52
import os
import os.path as osp
import pdb
import torch
import torch.nn as nn
import numpy as np
from random import *
import json
from packaging import version
import torch.distributed as dist
from .Tool_model import AutomaticWeightedLoss
from .Numeric import AttenNumeric
from .KE_model import KE_model
# from modeling_transformer import Transformer
from .bert import BertModel, BertTokenizer, BertForMaskedLM, BertConfig
import torch.nn.functional as F
from copy import deepcopy
from src.utils import torch_accuracy
# 4.21.2
def debug(input, kk, begin=None):
    """Return a deep copy of ``input[0]`` with its batch fields cut to ``[begin:kk]``.

    Debugging helper for shrinking a batch. Only the first element of
    ``input`` is used; the original object is left unmodified.

    Args:
        input: Indexable container whose first element exposes the attributes
            ``input_ids``, ``attention_mask``, ``chinese_ref``, ``kpi_ref``
            and ``labels``, each of which supports slicing.
        kk: Exclusive end index of the slice.
        begin: Inclusive start index; ``None`` slices from the beginning.

    Returns:
        The deep-copied first element with the five fields replaced by slices
        taken from the original object.
    """
    source = input[0]
    clipped = deepcopy(source)
    # slice(None, kk) is the same as [:kk], so one window covers both cases.
    window = slice(begin, kk)
    for field in ('input_ids', 'attention_mask', 'chinese_ref', 'kpi_ref', 'labels'):
        setattr(clipped, field, getattr(source, field)[window])
    return clipped
class HWBert(nn.Module):
    """Masked-language-model wrapper that can add a numeric (KPI) loss term.

    The module owns three sub-modules: ``loss_awl`` (an
    ``AutomaticWeightedLoss`` that combines the individual losses),
    ``encoder`` (the project's ``BertForMaskedLM``) and ``numeric_model``
    (an ``AttenNumeric`` handed to the encoder as ``kpi_model``).
    """

    # Indices into the encoder's ``hidden_states`` tuple that are summed for
    # each token-averaged pooling mode accepted by ``cls_embedding``.
    _POOLING_LAYERS = {
        'last_avg': (-1,),
        'last2avg': (-1, -2),
        'last3avg': (-1, -2, -3),
        'first_last_avg': (-1, 1),
    }

    def __init__(self, args):
        """Build the loss combiner, the encoder and the numeric model.

        Args:
            args: Configuration object. The attributes read here are
                ``awl_num``, ``model_name``, ``data_root`` and
                ``cls_head_init``; ``use_kpi_loss`` is read later in
                ``forward``. The whole object is also passed to
                ``AutomaticWeightedLoss`` and ``AttenNumeric``.
        """
        super().__init__()
        self.loss_awl = AutomaticWeightedLoss(args.awl_num, args)
        self.args = args
        self.config = BertConfig()
        model_name = args.model_name
        transformer_dir = osp.join(args.data_root, 'transformer')
        if args.model_name in ['TeleBert', 'TeleBert2', 'TeleBert3']:
            self.encoder = BertForMaskedLM.from_pretrained(osp.join(transformer_dir, model_name))
            # Optionally take the prediction head from the MacBert checkpoint.
            if args.cls_head_init:
                tmp = BertForMaskedLM.from_pretrained(osp.join(transformer_dir, 'MacBert'))
                self.encoder.cls.predictions = tmp.cls.predictions
        else:
            # Fall back to MacBert when the requested checkpoint directory
            # does not exist.
            if not osp.exists(osp.join(transformer_dir, args.model_name)):
                model_name = 'MacBert'
            self.encoder = BertForMaskedLM.from_pretrained(osp.join(transformer_dir, model_name))
        self.numeric_model = AttenNumeric(self.args)

    # ----------------------- main forward pass ------------------------------
    def forward(self, input):
        """Run the encoder and combine the masked-LM loss with the KPI loss.

        Args:
            input: Batch mapping forwarded unchanged to ``mask_forward``.

        Returns:
            dict with the keys ``loss`` (output of ``loss_awl``),
            ``loss_dic`` (``mask_loss`` and, when used, ``kpi_loss`` as Python
            floats), ``loss_weight`` (``loss_awl.params`` as a list),
            ``kpi_loss_weight`` and ``kpi_loss_dict`` (both passed through
            from the encoder).
        """
        outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.mask_forward(input)
        mask_loss = outputs.loss
        loss_dic = {}
        # The KPI term is dropped entirely when it is disabled in the config.
        if not self.args.use_kpi_loss:
            kpi_loss = None
        if kpi_loss is not None:
            # The KPI loss is scaled by a fixed 0.3 before automatic weighting.
            loss_sum = self.loss_awl(mask_loss, 0.3 * kpi_loss)
            loss_dic['kpi_loss'] = kpi_loss.item()
        else:
            loss_sum = self.loss_awl(mask_loss)
        loss_dic['mask_loss'] = mask_loss.item()
        return {
            'loss': loss_sum,
            'loss_dic': loss_dic,
            'loss_weight': self.loss_awl.params.tolist(),
            'kpi_loss_weight': kpi_loss_weight,
            'kpi_loss_dict': kpi_loss_dict
        }

    # ------------------------------------------------------------------------
    # Evaluation helper: checks how well the masked positions are predicted.
    def mask_prediction(self, inputs, tokenizer_sz, topk=(1,)):
        """Score the predictions at the positions whose label is not -100.

        Args:
            inputs: Batch mapping; must contain ``labels`` plus whatever
                ``mask_forward`` reads.
            tokenizer_sz: Size of the last logits dimension, used to flatten
                the logits to ``(-1, tokenizer_sz)``.
            topk: Forwarded to ``torch_accuracy``.

        Returns:
            Tuple ``(token_num, token_right, loss)``: the number of scored
            positions, the second value returned by ``torch_accuracy``, and
            the encoder loss as a Python float.
        """
        outputs, _, _, _ = self.mask_forward(inputs)
        labels = inputs['labels'].view(-1)
        # A boolean mask keeps the positions in ascending order, like the
        # index list it replaces, and stays valid when nothing is selected
        # (an empty index list would become a float tensor and fail to index).
        scored = labels != -100
        labels_used = labels[scored]
        pred = outputs.logits.view(-1, tokenizer_sz)
        pred_used = pred[scored.to(pred.device)].cpu()
        _, token_right = torch_accuracy(pred_used, labels_used, topk)
        token_num = labels_used.shape[0]
        # TODO: add word_num, word_right
        return token_num, token_right, outputs.loss.item()

    def mask_forward(self, inputs):
        """Call the encoder on a batch and return its four outputs.

        Args:
            inputs: Mapping with ``input_ids``, ``attention_mask`` and
                ``labels`` (each moved to CUDA here) and optionally
                ``kpi_ref`` (passed through as is).

        Returns:
            Tuple ``(outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict)``
            exactly as produced by the encoder.
        """
        kpi_ref = inputs['kpi_ref'] if 'kpi_ref' in inputs else None
        # NOTE(review): tensors are moved with a hard-coded .cuda(), so this
        # path requires a CUDA device.
        outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.encoder(
            input_ids=inputs['input_ids'].cuda(),
            attention_mask=inputs['attention_mask'].cuda(),
            labels=inputs['labels'].cuda(),
            kpi_ref=kpi_ref,
            kpi_model=self.numeric_model
        )
        return outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict

    # TODO: consider axial attention: https://github.com/lucidrains/axial-attention
    def cls_embedding(self, inputs, tp='cls'):
        """Produce one embedding vector per sequence from the hidden states.

        Args:
            inputs: Mapping with ``input_ids`` and ``attention_mask``.
            tp: Pooling mode. ``'cls'`` returns the last hidden state at
                position 0. ``'last_avg'``, ``'last2avg'``, ``'last3avg'`` and
                ``'first_last_avg'`` sum the hidden states listed in
                ``_POOLING_LAYERS`` and average the result over the tokens
                whose id is non-zero, excluding the last such token.

        Returns:
            Tensor with one row per sequence in the batch.

        Raises:
            ValueError: If ``tp`` is not one of the supported modes.
        """
        # Reject an unknown mode before paying for an encoder pass.
        if tp != 'cls' and tp not in self._POOLING_LAYERS:
            supported = ['cls'] + list(self._POOLING_LAYERS)
            raise ValueError(
                'unsupported pooling type {!r}; expected one of {}'.format(tp, supported))
        hidden_states = self.encoder(
            input_ids=inputs['input_ids'].cuda(),
            attention_mask=inputs['attention_mask'].cuda(),
            output_hidden_states=True)[0].hidden_states
        if tp == 'cls':
            return hidden_states[-1][:, 0]
        layers = [hidden_states[k] for k in self._POOLING_LAYERS[tp]]
        # True where the token id is non-zero (presumably id 0 is padding —
        # confirm against the tokenizer). Kept on the hidden states' device.
        real_token = inputs['input_ids'].bool().to(layers[0].device)
        res = []
        for i in range(layers[0].shape[0]):
            keep = real_token[i]
            # [:-1] drops the final non-zero token of each sequence
            # (presumably the trailing [SEP] — confirm).
            pooled = layers[0][i][keep][:-1]
            for layer in layers[1:]:
                pooled = pooled + layer[i][keep][:-1]
            res.append(pooled.mean(dim=0))
        return torch.stack(res)