File size: 6,435 Bytes
37b9e99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from typing import List
import torch

from common import utils
from common.utils import OutputData, InputData
from torch import Tensor

def argmax_for_seq_len(inputs, seq_lens, padding_value=-100):
    """Take the argmax over the last (class) dimension of padded logits.

    The logits are first packed by true sequence length, argmax'd, then
    unpacked back to the padded layout, with ``padding_value`` filling the
    padding positions (presumably per the ``utils`` pack/unpack contract —
    NOTE(review): confirm against ``common.utils``).
    """
    packed = utils.pack_sequence(inputs, seq_lens)
    class_ids = torch.argmax(packed, dim=-1, keepdim=True)
    padded = utils.unpack_sequence(class_ids, seq_lens, padding_value)
    return padded.squeeze(-1)


def _pairs_to_lists(pairs: "Tensor", num_rows: int) -> List:
    """Scatter (row, col) index pairs from ``Tensor.nonzero()`` into one
    list of column indices per row."""
    buckets = [[] for _ in range(num_rows)]
    for row, col in pairs.detach().cpu().tolist():
        buckets[row].append(col)
    return buckets


def decode(output: "OutputData",
           target: "InputData" = None,
           pred_type="slot",
           multi_threshold=0.5,
           ignore_index=-100,
           return_list=True,
           return_sentence_level=True,
           use_multi=False,
           use_crf=False,
           CRF=None) -> "List | Tensor":
    """ decode output logits

    Args:
        output (OutputData): output logits data
        target (InputData, optional): input data with attention mask. Defaults to None.
        pred_type (str, optional): prediction type in ["slot", "intent", "token-level-intent"]. Defaults to "slot".
        multi_threshold (float, optional): multi intent decode threshold. Defaults to 0.5.
        ignore_index (int, optional): align and pad token with ignore index. Defaults to -100.
        return_list (bool, optional): if True return list else return torch Tensor. Defaults to True.
        return_sentence_level (bool, optional): if True decode sentence level intent else decode token level intent. Defaults to True.
        use_multi (bool, optional): whether to decode to multi intent. Defaults to False.
        use_crf (bool, optional): whether to use crf. Defaults to False.
        CRF (CRF, optional): CRF function. Defaults to None.

    Returns:
        List or Tensor: decoded sequence ids

    Raises:
        NotImplementedError: for multi-slot decoding or an unknown pred_type.
    """
    if pred_type == "slot":
        inputs = output.slot_ids
    else:
        inputs = output.intent_ids

    if pred_type == "slot":
        if use_multi:
            raise NotImplementedError("Multi-slot prediction is not supported.")
        if use_crf:
            # NOTE(review): torchcrf-style CRF.decode returns a plain list of
            # lists, not a Tensor; the is_tensor guard at the end handles both.
            res = CRF.decode(inputs, mask=target.attention_mask)
        else:
            res = torch.argmax(inputs, dim=-1)
    elif pred_type == "intent":
        if not use_multi:
            res = torch.argmax(inputs, dim=-1)
        else:
            # (row, col) pairs of every intent whose sigmoid clears the threshold
            res = (torch.sigmoid(inputs) > multi_threshold).nonzero()
            if return_list:
                return _pairs_to_lists(res, len(target.seq_lens))
            return res
    elif pred_type == "token-level-intent":
        if not use_multi:
            res = torch.argmax(inputs, dim=-1)
            if not return_sentence_level:
                return res
            # Sentence-level vote needs Python lists: each row is truncated to
            # its true length, which a rectangular tensor cannot hold.  (The
            # original converted only when return_list was set, so
            # return_list=False crashed on the ragged assignment below.)
            res = res.detach().cpu().tolist()
            attention_mask = target.attention_mask
            for i in range(attention_mask.shape[0]):
                temp = []
                for j in range(attention_mask.shape[1]):
                    if attention_mask[i][j] == 1:
                        temp.append(res[i][j])
                    else:
                        break
                res[i] = temp
            # Majority vote per sentence; max() keeps the first-seen winner on ties.
            return [max(it, key=lambda v: it.count(v)) for it in res]
        else:
            seq_lens = target.seq_lens
            # Per-sentence boolean hits, truncated to the true sequence length.
            token_hits = [
                torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold
                for i in range(len(seq_lens))
            ]

            if not return_sentence_level:
                token_res = torch.cat(token_hits, dim=0)
                return utils.unpack_sequence(token_res, seq_lens, padding_value=ignore_index)

            # An intent is predicted when it fires on more than half the tokens.
            intent_index_sum = torch.cat(
                [hits.sum(dim=0).unsqueeze(0) for hits in token_hits], dim=0)
            res = (intent_index_sum > torch.div(seq_lens, 2, rounding_mode='floor').unsqueeze(1)).nonzero()
            if return_list:
                return _pairs_to_lists(res, len(seq_lens))
            return res
    else:
        raise NotImplementedError("Prediction mode except ['slot','intent','token-level-intent'] is not supported.")
    # CRF.decode may already have produced a list; only tensors need conversion.
    if return_list and torch.is_tensor(res):
        res = res.detach().cpu().tolist()
    return res


def compute_loss(pred: OutputData,
                 target: InputData,
                 criterion_type="slot",
                 use_crf=False,
                 ignore_index=-100,
                 loss_fn=None,
                 use_multi=False,
                 CRF=None):
    """ compute loss

    Args:
        pred (OutputData): output logits data
        target (InputData): input golden data
        criterion_type (str, optional): criterion type in ["slot", "intent", "token-level-intent"]. Defaults to "slot".
        ignore_index (int, optional): compute loss with ignore index. Defaults to -100.
        loss_fn (_type_, optional): loss function. Defaults to None.
        use_crf (bool, optional): whether to use crf. Defaults to False.
        use_multi (bool, optional): whether the intent target is multi-label. Defaults to False.
        CRF (CRF, optional): CRF function. Defaults to None.

    Returns:
        Tensor: loss result
    """
    if criterion_type == "slot":
        if use_crf:
            # CRF returns a log-likelihood; negate it to get a loss.
            slot_mask = target.get_slot_mask(ignore_index).byte()
            return -1 * CRF(pred.slot_ids, target.slot, slot_mask)
        packed_pred = utils.pack_sequence(pred.slot_ids, target.seq_lens)
        packed_gold = utils.pack_sequence(target.slot, target.seq_lens)
        return loss_fn(packed_pred, packed_gold)

    if criterion_type == "token-level-intent":
        # Broadcast the sentence-level intent label to every token position;
        # multi-label targets carry an extra label dimension to preserve.
        seq_len = pred.intent_ids.shape[1]
        repeat_shape = (1, seq_len) if not use_multi else (1, seq_len, 1)
        token_gold = target.intent.unsqueeze(1).repeat(*repeat_shape)
        packed_pred = utils.pack_sequence(pred.intent_ids, target.seq_lens)
        packed_gold = utils.pack_sequence(token_gold, target.seq_lens)
        return loss_fn(packed_pred, packed_gold)

    # Plain sentence-level intent loss.
    return loss_fn(pred.intent_ids, target.intent)