import os
import os.path as osp
import pdb
import torch
import torch.nn as nn
import numpy as np
import random
import json
from packaging import version
import torch.distributed as dist

from .Tool_model import AutomaticWeightedLoss
from .Numeric import AttenNumeric
from .KE_model import KE_model


from .bert import BertModel, BertTokenizer, BertForMaskedLM, BertConfig
import torch.nn.functional as F

from copy import deepcopy
from src.utils import torch_accuracy
# transformers 4.21.2


def debug(batch, kk, begin=None):
    """Truncate the first element of a collated batch to examples [begin, kk)
    across all of its fields, for quick step-through debugging."""
    sliced = deepcopy(batch[0])
    begin = 0 if begin is None else begin
    for field in ('input_ids', 'attention_mask', 'chinese_ref', 'kpi_ref', 'labels'):
        setattr(sliced, field, getattr(batch[0], field)[begin:kk])
    return sliced
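
# Illustrative usage (the collated `batch` here is hypothetical): shrink a
# batch before stepping through the model.
#   small = debug(batch, 4)            # keep examples [0, 4)
#   window = debug(batch, 8, begin=4)  # keep examples [4, 8)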


class HWBert(nn.Module):
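    """Masked-LM BERT wrapper that adds a KPI (numeric) auxiliary loss.

    Expected ``args`` fields, inferred from their use in this file: awl_num,
    model_name, data_root, cls_head_init, use_kpi_loss (plus whatever
    AutomaticWeightedLoss and AttenNumeric consume).
    """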
    def __init__(self, args):
        super().__init__()
        self.loss_awl = AutomaticWeightedLoss(args.awl_num, args)
        self.args = args
        self.config = BertConfig()
        model_name = args.model_name
        if args.model_name in ['TeleBert', 'TeleBert2', 'TeleBert3']:
            self.encoder = BertForMaskedLM.from_pretrained(osp.join(args.data_root, 'transformer', model_name))
            # initialize the MLM predictions head from MacBert
            if args.cls_head_init:
                tmp = BertForMaskedLM.from_pretrained(osp.join(args.data_root, 'transformer', 'MacBert'))
                self.encoder.cls.predictions = tmp.cls.predictions
        else:
            # fall back to MacBert if the requested checkpoint is missing locally
            if not osp.exists(osp.join(args.data_root, 'transformer', args.model_name)):
                model_name = 'MacBert'
            self.encoder = BertForMaskedLM.from_pretrained(osp.join(args.data_root, 'transformer', model_name))
        self.numeric_model = AttenNumeric(self.args)

    # ----------------------- main forward ----------------------------------
    def forward(self, inputs):
        outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.mask_forward(inputs)
        mask_loss = outputs.loss
        loss_dic = {}
        if not self.args.use_kpi_loss:
            kpi_loss = None
        if kpi_loss is not None:
            # KPI loss is down-weighted by a fixed 0.3 factor before automatic weighting
            loss_sum = self.loss_awl(mask_loss, 0.3 * kpi_loss)
            loss_dic['kpi_loss'] = kpi_loss.item()
        else:
            loss_sum = self.loss_awl(mask_loss)
        loss_dic['mask_loss'] = mask_loss.item()
        return {
            'loss': loss_sum,
            'loss_dic': loss_dic,
            'loss_weight': self.loss_awl.params.tolist(),
            'kpi_loss_weight': kpi_loss_weight,
            'kpi_loss_dict': kpi_loss_dict
        }
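
    # Illustrative training step (trainer-side names here are hypothetical):
    #   out = model(batch)             # batch provides input_ids / attention_mask / labels ...
    #   out['loss'].backward()
    #   optimizer.step(); optimizer.zero_grad()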


    # ----------------------------------------------------------------
    # Evaluation helper: measure how well masked tokens are recovered
    def mask_prediction(self, inputs, tokenizer_sz, topk=(1,)):
        outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.mask_forward(inputs)
        labels = inputs['labels'].view(-1)
        # positions of the masked (supervised) tokens; -100 marks ignored positions
        change_token = (labels != -100).nonzero(as_tuple=True)[0]
        inputs_used = labels[change_token]
        pred = outputs.logits.view(-1, tokenizer_sz)
        pred_used = pred[change_token].cpu()
        # top-k accuracy over the masked positions
        acc, token_right = torch_accuracy(pred_used, inputs_used, topk)
        # TODO: compute a confusion (perplexity) score as well

        token_num = inputs_used.shape[0]
        # TODO: also report word-level counts (word_num, word_right)
        # token_right: list of correct-prediction counts, one per k in topk
        return token_num, token_right, outputs.loss.item()
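
    # For reference, a sketch of what the imported torch_accuracy is assumed
    # to compute (per-k accuracy and correct counts over the masked positions):
    #   maxk = max(topk); _, pred_k = pred_used.topk(maxk, dim=-1)
    #   hit = pred_k.eq(inputs_used.unsqueeze(-1))
    #   token_right = [hit[:, :k].any(dim=-1).sum().item() for k in topk]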

    def mask_forward(self, inputs):
        kpi_ref = None
        if 'kpi_ref' in inputs:
            kpi_ref = inputs['kpi_ref']

        outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict = self.encoder(
            input_ids=inputs['input_ids'].cuda(),
            attention_mask=inputs['attention_mask'].cuda(),
            # token_type_ids=inputs.token_type_ids.cuda(),
            labels=inputs['labels'].cuda(),
            kpi_ref=kpi_ref,
            kpi_model=self.numeric_model
        )
        return outputs, kpi_loss, kpi_loss_weight, kpi_loss_dict

    # TODO: consider axial ("vertical") attention: https://github.com/lucidrains/axial-attention

    def cls_embedding(self, inputs, tp='cls'):
        hidden_states = self.encoder(
            input_ids=inputs['input_ids'].cuda(),
            attention_mask=inputs['attention_mask'].cuda(),
            output_hidden_states=True)[0].hidden_states
        if tp == 'cls':
            # [CLS] vector from the last hidden layer
            return hidden_states[-1][:, 0]
        else:
            # non-pad positions (input_ids != 0); used to mask out padding
            index_real = inputs['input_ids'].clone().detach().bool()
            res = []
            for i in range(hidden_states[-1].shape[0]):
                if tp == 'last_avg':
                    # mean over real tokens (excluding the final [SEP]) of the last layer
                    res.append(hidden_states[-1][i][index_real[i]][:-1].mean(dim=0))
                elif tp == 'last2avg':
                    # sum of the last two layers, then mean over real tokens
                    res.append((hidden_states[-1][i][index_real[i]][:-1] + hidden_states[-2][i][index_real[i]][:-1]).mean(dim=0))
                elif tp == 'last3avg':
                    # sum of the last three layers, then mean over real tokens
                    res.append((hidden_states[-1][i][index_real[i]][:-1] + hidden_states[-2][i][index_real[i]][:-1] + hidden_states[-3][i][index_real[i]][:-1]).mean(dim=0))
                elif tp == 'first_last_avg':
                    # sum of the first and last encoder layers, then mean over real tokens
                    res.append((hidden_states[-1][i][index_real[i]][:-1] + hidden_states[1][i][index_real[i]][:-1]).mean(dim=0))

            return torch.stack(res)
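
# Illustrative sentence-embedding usage (the batch dict here is hypothetical):
#   v_cls = model.cls_embedding(batch, tp='cls')             # [CLS] of the last layer
#   v_fla = model.cls_embedding(batch, tp='first_last_avg')  # first+last layer average
# Note: the summed-layer variants are not divided by the layer count; the
# scale difference is harmless for cosine-similarity comparisons.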