Spaces:
Runtime error
Runtime error
File size: 6,435 Bytes
37b9e99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
from typing import List, Union

import torch
from torch import Tensor

from common import utils
from common.utils import InputData, OutputData
def argmax_for_seq_len(inputs, seq_lens, padding_value=-100):
    """Take the argmax over the last dim of padded logits, respecting sequence lengths.

    Packs ``inputs`` according to ``seq_lens``, takes the argmax over the class
    dimension, then unpacks back to a padded layout filled with ``padding_value``.

    Args:
        inputs: padded logits batch (padded positions are ignored via packing).
        seq_lens: true length of each sequence in the batch.
        padding_value (int, optional): fill value for padded positions in the
            unpacked result. Defaults to -100.

    Returns:
        Padded tensor of argmax indices with the trailing class dim squeezed away.
    """
    packed = utils.pack_sequence(inputs, seq_lens)
    indices = torch.argmax(packed, dim=-1, keepdim=True)
    unpacked = utils.unpack_sequence(indices, seq_lens, padding_value)
    return unpacked.squeeze(-1)
def decode(output: OutputData,
           target: InputData = None,
           pred_type="slot",
           multi_threshold=0.5,
           ignore_index=-100,
           return_list=True,
           return_sentence_level=True,
           use_multi=False,
           use_crf=False,
           CRF=None) -> Union[List, Tensor]:
    """Decode output logits into predicted label ids.

    Args:
        output (OutputData): output logits data.
        target (InputData, optional): input data with attention mask and seq_lens.
            Required for CRF decoding, multi-intent decoding and token-level
            aggregation. Defaults to None.
        pred_type (str, optional): prediction type in ["slot", "intent",
            "token-level-intent"]. Defaults to "slot".
        multi_threshold (float, optional): multi intent decode threshold. Defaults to 0.5.
        ignore_index (int, optional): align and pad token with ignore index. Defaults to -100.
        return_list (bool, optional): if True return list else return torch Tensor. Defaults to True.
        return_sentence_level (bool, optional): if True decode sentence level intent
            else decode token level intent. Defaults to True.
        use_multi (bool, optional): whether to decode to multi intent. Defaults to False.
        use_crf (bool, optional): whether to use crf. Defaults to False.
        CRF (CRF, optional): CRF function. Defaults to None.

    Returns:
        Union[List, Tensor]: decoded sequence ids.
        (Fix: the original annotation ``List or Tensor`` evaluates to just
        ``List`` at definition time; ``Union[List, Tensor]`` is the intent.)

    Raises:
        NotImplementedError: for multi-slot decoding or an unknown pred_type.
    """
    if pred_type == "slot":
        inputs = output.slot_ids
    else:
        inputs = output.intent_ids
    if pred_type == "slot":
        if not use_multi:
            if use_crf:
                # NOTE(review): CRF implementations (e.g. torchcrf) typically
                # return a plain List[List[int]] here, not a Tensor — see the
                # isinstance guard before the final conversion below.
                res = CRF.decode(inputs, mask=target.attention_mask)
            else:
                res = torch.argmax(inputs, dim=-1)
        else:
            raise NotImplementedError("Multi-slot prediction is not supported.")
    elif pred_type == "intent":
        if not use_multi:
            res = torch.argmax(inputs, dim=-1)
        else:
            # Threshold each intent logit independently; nonzero() yields
            # (sample_index, intent_index) pairs.
            res = (torch.sigmoid(inputs) > multi_threshold).nonzero()
            if return_list:
                res_index = res.detach().cpu().tolist()
                # Group predicted intent ids per sample.
                res_list = [[] for _ in range(len(target.seq_lens))]
                for item in res_index:
                    res_list[item[0]].append(item[1])
                return res_list
            else:
                return res
    elif pred_type == "token-level-intent":
        if not use_multi:
            res = torch.argmax(inputs, dim=-1)
            if not return_sentence_level:
                return res
            if return_list:
                res = res.detach().cpu().tolist()
            attention_mask = target.attention_mask
            # Trim each row to its unpadded prefix (mask == 1), then take the
            # most frequent token-level intent as the sentence-level vote.
            for i in range(attention_mask.shape[0]):
                temp = []
                for j in range(attention_mask.shape[1]):
                    if attention_mask[i][j] == 1:
                        temp.append(res[i][j])
                    else:
                        break
                res[i] = temp
            return [max(it, key=lambda v: it.count(v)) for it in res]
        else:
            seq_lens = target.seq_lens
            if not return_sentence_level:
                # Per-token multi-intent decisions, re-padded with ignore_index.
                token_res = torch.cat([
                    torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold
                    for i in range(len(seq_lens))],
                    dim=0)
                return utils.unpack_sequence(token_res, seq_lens, padding_value=ignore_index)
            # Count, per sample, how many tokens voted for each intent.
            intent_index_sum = torch.cat([
                torch.sum(torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold, dim=0).unsqueeze(0)
                for i in range(len(seq_lens))],
                dim=0)
            # Keep an intent iff a strict majority of tokens voted for it.
            res = (intent_index_sum > torch.div(seq_lens, 2, rounding_mode='floor').unsqueeze(1)).nonzero()
            if return_list:
                res_index = res.detach().cpu().tolist()
                res_list = [[] for _ in range(len(seq_lens))]
                for item in res_index:
                    res_list[item[0]].append(item[1])
                return res_list
            else:
                return res
    else:
        raise NotImplementedError("Prediction mode except ['slot','intent','token-level-intent'] is not supported.")
    # Fix: CRF.decode may already return a Python list, which has no .detach();
    # only convert Tensor results to a list.
    if return_list and isinstance(res, Tensor):
        res = res.detach().cpu().tolist()
    return res
def compute_loss(pred: OutputData,
                 target: InputData,
                 criterion_type="slot",
                 use_crf=False,
                 ignore_index=-100,
                 loss_fn=None,
                 use_multi=False,
                 CRF=None):
    """Compute the training loss for slot or intent prediction.

    Args:
        pred (OutputData): output logits data.
        target (InputData): input golden data.
        criterion_type (str, optional): criterion type in ["slot", "intent",
            "token-level-intent"]. Defaults to "slot".
        use_crf (bool, optional): whether to use crf (slot criterion only). Defaults to False.
        ignore_index (int, optional): compute loss with ignore index. Defaults to -100.
        loss_fn (callable, optional): loss function. Defaults to None.
        use_multi (bool, optional): whether the intent target is multi-label. Defaults to False.
        CRF (CRF, optional): CRF function. Defaults to None.

    Returns:
        Tensor: loss result.
    """
    if criterion_type == "slot":
        if use_crf:
            # Negative log-likelihood from the CRF layer.
            slot_mask = target.get_slot_mask(ignore_index).byte()
            return -1 * CRF(pred.slot_ids, target.slot, slot_mask)
        packed_pred = utils.pack_sequence(pred.slot_ids, target.seq_lens)
        packed_gold = utils.pack_sequence(target.slot, target.seq_lens)
        return loss_fn(packed_pred, packed_gold)
    if criterion_type == "token-level-intent":
        # TODO: Two decode function
        # Broadcast the sentence-level intent label to every token position.
        gold = target.intent.unsqueeze(1)
        if use_multi:
            gold = gold.repeat(1, pred.intent_ids.shape[1], 1)
        else:
            gold = gold.repeat(1, pred.intent_ids.shape[1])
        packed_pred = utils.pack_sequence(pred.intent_ids, target.seq_lens)
        packed_gold = utils.pack_sequence(gold, target.seq_lens)
        return loss_fn(packed_pred, packed_gold)
    # Plain sentence-level intent loss.
    return loss_fn(pred.intent_ids, target.intent)
|