"""Decoding and loss utilities for slot filling and intent detection."""
from typing import List, Union

import torch
from torch import Tensor

from common import utils
from common.utils import OutputData, InputData
def argmax_for_seq_len(inputs, seq_lens, padding_value=-100):
    """Take the argmax over the label dimension of a padded batch.

    Padding positions are dropped before the argmax and restored afterwards,
    so the result stays aligned with the original padded layout.

    Args:
        inputs (Tensor): padded logits; presumably (batch, max_seq_len, num_labels) -- confirm with callers.
        seq_lens: real length of each sequence in the batch.
        padding_value (int, optional): fill value used when re-padding. Defaults to -100.

    Returns:
        Tensor: padded label ids with the trailing singleton dim removed.
    """
    packed = utils.pack_sequence(inputs, seq_lens)
    labels = packed.argmax(dim=-1, keepdim=True)
    unpacked = utils.unpack_sequence(labels, seq_lens, padding_value)
    return unpacked.squeeze(-1)
def decode(output: OutputData,
           target: InputData = None,
           pred_type="slot",
           multi_threshold=0.5,
           ignore_index=-100,
           return_list=True,
           return_sentence_level=True,
           use_multi=False,
           use_crf=False,
           CRF=None) -> Union[List, Tensor]:
    """ decode output logits

    Args:
        output (OutputData): output logits data
        target (InputData, optional): input data with attention mask and seq_lens. Defaults to None.
        pred_type (str, optional): prediction type in ["slot", "intent", "token-level-intent"]. Defaults to "slot".
        multi_threshold (float, optional): multi intent decode threshold. Defaults to 0.5.
        ignore_index (int, optional): align and pad token with ignore index. Defaults to -100.
        return_list (bool, optional): if True return list else return torch Tensor. Defaults to True.
        return_sentence_level (bool, optional): if True decode sentence level intent else decode token level intent. Defaults to True.
        use_multi (bool, optional): whether to decode to multi intent. Defaults to False.
        use_crf (bool, optional): whether to use crf. Defaults to False.
        CRF (CRF, optional): CRF function. Defaults to None.

    Raises:
        NotImplementedError: for multi-slot decoding or an unsupported pred_type.

    Returns:
        Union[List, Tensor]: decoded sequence ids
    """
    # Select which logits to decode: "slot" uses slot_ids, every intent mode
    # uses intent_ids.
    if pred_type == "slot":
        inputs = output.slot_ids
    else:
        inputs = output.intent_ids
    if pred_type == "slot":
        if not use_multi:
            if use_crf:
                # NOTE(review): the shared tail below calls res.detach() when
                # return_list is True, which assumes CRF.decode returns a
                # Tensor -- confirm against the CRF implementation in use.
                res = CRF.decode(inputs, mask=target.attention_mask)
            else:
                res = torch.argmax(inputs, dim=-1)
        else:
            raise NotImplementedError("Multi-slot prediction is not supported.")
    elif pred_type == "intent":
        if not use_multi:
            res = torch.argmax(inputs, dim=-1)
        else:
            # Threshold sigmoid scores; nonzero() yields (sample, intent) pairs.
            res = (torch.sigmoid(inputs) > multi_threshold).nonzero()
            if return_list:
                # Group the selected intent ids per sample.
                res_index = res.detach().cpu().tolist()
                res_list = [[] for _ in range(len(target.seq_lens))]
                for item in res_index:
                    res_list[item[0]].append(item[1])
                return res_list
            else:
                return res
    elif pred_type == "token-level-intent":
        if not use_multi:
            res = torch.argmax(inputs, dim=-1)
            if not return_sentence_level:
                return res
            if return_list:
                res = res.detach().cpu().tolist()
            # Keep only the real (unpadded) token predictions per sentence,
            # then vote: the most frequent token-level intent wins.
            attention_mask = target.attention_mask
            for i in range(attention_mask.shape[0]):
                temp = []
                for j in range(attention_mask.shape[1]):
                    if attention_mask[i][j] == 1:
                        temp.append(res[i][j])
                    else:
                        break
                res[i] = temp
            return [max(it, key=lambda v: it.count(v)) for it in res]
        else:
            seq_lens = target.seq_lens
            if not return_sentence_level:
                # Per-token multi-intent decisions, re-padded with ignore_index.
                token_res = torch.cat([
                    torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold
                    for i in range(len(seq_lens))],
                    dim=0)
                return utils.unpack_sequence(token_res, seq_lens, padding_value=ignore_index)
            # Count, per sample, how many real tokens voted for each intent.
            intent_index_sum = torch.cat([
                torch.sum(torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold, dim=0).unsqueeze(0)
                for i in range(len(seq_lens))],
                dim=0)
            # An intent is predicted when a strict majority of tokens voted for it.
            res = (intent_index_sum > torch.div(seq_lens, 2, rounding_mode='floor').unsqueeze(1)).nonzero()
            if return_list:
                res_index = res.detach().cpu().tolist()
                res_list = [[] for _ in range(len(seq_lens))]
                for item in res_index:
                    res_list[item[0]].append(item[1])
                return res_list
            else:
                return res
    else:
        raise NotImplementedError("Prediction mode except ['slot','intent','token-level-intent'] is not supported.")
    if return_list:
        res = res.detach().cpu().tolist()
    return res
def compute_loss(pred: OutputData,
                 target: InputData,
                 criterion_type="slot",
                 use_crf=False,
                 ignore_index=-100,
                 loss_fn=None,
                 use_multi=False,
                 CRF=None):
    """Compute the training loss for slot or intent prediction.

    Args:
        pred (OutputData): output logits data.
        target (InputData): input golden data.
        criterion_type (str, optional): criterion type in ["slot", "intent",
            "token-level-intent"]; any other value falls through to the
            sentence-level intent loss. Defaults to "slot".
        use_crf (bool, optional): whether to use crf for the slot loss. Defaults to False.
        ignore_index (int, optional): compute loss with ignore index. Defaults to -100.
        loss_fn (optional): loss function applied to (prediction, gold). Defaults to None.
        use_multi (bool, optional): multi-intent (one-hot row) targets for
            token-level intent. Defaults to False.
        CRF (CRF, optional): CRF function. Defaults to None.

    Returns:
        Tensor: loss result
    """
    if criterion_type == "slot":
        if use_crf:
            # The CRF returns a log-likelihood; negate it to get a loss.
            slot_mask = target.get_slot_mask(ignore_index).byte()
            return -1 * CRF(pred.slot_ids, target.slot, slot_mask)
        # Drop padded positions before applying the loss.
        packed_pred = utils.pack_sequence(pred.slot_ids, target.seq_lens)
        packed_gold = utils.pack_sequence(target.slot, target.seq_lens)
        return loss_fn(packed_pred, packed_gold)
    if criterion_type == "token-level-intent":
        # TODO: Two decode function
        # Broadcast the sentence-level intent label to every token position.
        token_gold = target.intent.unsqueeze(1)
        if use_multi:
            token_gold = token_gold.repeat(1, pred.intent_ids.shape[1], 1)
        else:
            token_gold = token_gold.repeat(1, pred.intent_ids.shape[1])
        packed_pred = utils.pack_sequence(pred.intent_ids, target.seq_lens)
        packed_gold = utils.pack_sequence(token_gold, target.seq_lens)
        return loss_fn(packed_pred, packed_gold)
    # Sentence-level intent loss.
    return loss_fn(pred.intent_ids, target.intent)