# OpenSLU / model/decoder/decoder_utils.py
# Uploaded by LightChen2333 ("Upload 34 files", commit 37b9e99).
from collections import Counter
from typing import List, Union

import torch
from torch import Tensor

from common import utils
from common.utils import OutputData, InputData
def argmax_for_seq_len(inputs, seq_lens, padding_value=-100):
    """Pick the highest-scoring class at every valid position of each sequence.

    The scores are packed with ``utils.pack_sequence`` using ``seq_lens``
    (presumably dropping padded positions — confirm against ``common.utils``).
    The argmax is taken over the last dimension, and the ids are unpacked back
    into a padded layout whose padding is filled with ``padding_value``.

    Args:
        inputs: per-position scores; the class dimension is the last one.
        seq_lens: sequence lengths forwarded to the pack/unpack helpers.
        padding_value (int, optional): fill value for padded positions.
            Defaults to -100.

    Returns:
        Tensor: unpacked argmax ids with the trailing singleton dim removed.
    """
    flat_scores = utils.pack_sequence(inputs, seq_lens)
    # keepdim=True keeps a trailing axis for unpack_sequence; it is squeezed away below.
    best_ids = flat_scores.argmax(dim=-1, keepdim=True)
    padded_ids = utils.unpack_sequence(best_ids, seq_lens, padding_value)
    return padded_ids.squeeze(-1)
def _group_nonzero_by_row(indices: Tensor, num_rows: int) -> List[List[int]]:
    """Group the index pairs of a 2-D ``nonzero()`` result by row.

    Args:
        indices (Tensor): result of ``Tensor.nonzero()`` on a 2-D tensor; each
            entry is ``[row, col]``.
        num_rows (int): number of rows to allocate; rows without any selected
            column end up as empty lists.

    Returns:
        List[List[int]]: ``grouped[row]`` holds the selected column ids of
        ``row`` in the order they appear in ``indices``.
    """
    grouped = [[] for _ in range(num_rows)]
    for item in indices.detach().cpu().tolist():
        grouped[item[0]].append(item[1])
    return grouped


def _majority_label(labels: List[int]) -> int:
    """Return the most frequent label in ``labels``; ties go to the earliest one.

    Raises:
        ValueError: if ``labels`` is empty.
    """
    if not labels:
        raise ValueError("cannot take a majority vote over an empty token sequence")
    # Counter keeps first-occurrence order, and most_common(1) returns the first
    # entry with the highest count. Ties therefore resolve exactly like
    # ``max(labels, key=labels.count)``, but in linear time.
    return Counter(labels).most_common(1)[0][0]


def decode(output: OutputData,
           target: InputData = None,
           pred_type="slot",
           multi_threshold=0.5,
           ignore_index=-100,
           return_list=True,
           return_sentence_level=True,
           use_multi=False,
           use_crf=False,
           CRF=None) -> Union[List, Tensor]:
    """Decode output logits into predicted label ids.

    Args:
        output (OutputData): model output. ``slot_ids`` is read when
            ``pred_type == "slot"``, ``intent_ids`` otherwise.
        target (InputData, optional): input batch. Needed for CRF slot
            decoding (``attention_mask``) and for "token-level-intent"
            decoding (``attention_mask`` / ``seq_lens``). Optional for
            "intent" decoding. Defaults to None.
        pred_type (str, optional): one of "slot", "intent",
            "token-level-intent". Defaults to "slot".
        multi_threshold (float, optional): sigmoid probability an intent must
            exceed to be predicted in multi-intent mode. Defaults to 0.5.
        ignore_index (int, optional): padding value used when unpacking
            token-level multi-intent predictions. Defaults to -100.
        return_list (bool, optional): convert tensor results to Python lists
            when True. Defaults to True.
        return_sentence_level (bool, optional): for "token-level-intent",
            aggregate token predictions into one prediction per sentence when
            True. Defaults to True.
        use_multi (bool, optional): decode multiple intents per sample.
            Defaults to False.
        use_crf (bool, optional): decode slots with ``CRF.decode``.
            Defaults to False.
        CRF (optional): object exposing ``decode(emissions, mask=...)``. Used
            only when ``use_crf`` is True. Defaults to None.

    Returns:
        List or Tensor: decoded ids.
            - "slot" / single "intent": argmax ids (or whatever ``CRF.decode``
              returns). Tensor results become nested lists when
              ``return_list``.
            - multi "intent": one list of intent ids per sample, or the raw
              ``nonzero()`` index tensor when not ``return_list``.
            - single "token-level-intent": token argmax tensor when not
              ``return_sentence_level``. Otherwise, with ``return_list``, one
              majority-vote label per sentence. NOTE(review): with
              ``return_list=False`` the token-level argmax tensor is returned,
              not a sentence-level result.
            - multi "token-level-intent": unpacked per-token boolean
              predictions when not ``return_sentence_level``. Otherwise,
              intents predicted by more than ``seq_len // 2`` tokens, as
              per-sample lists or a ``nonzero()`` index tensor.

    Raises:
        NotImplementedError: for ``pred_type == "slot"`` with ``use_multi``,
            or for an unknown ``pred_type``.
    """
    inputs = output.slot_ids if pred_type == "slot" else output.intent_ids

    if pred_type == "slot":
        if use_multi:
            raise NotImplementedError("Multi-slot prediction is not supported.")
        if use_crf:
            res = CRF.decode(inputs, mask=target.attention_mask)
        else:
            res = torch.argmax(inputs, dim=-1)

    elif pred_type == "intent":
        if not use_multi:
            res = torch.argmax(inputs, dim=-1)
        else:
            res = (torch.sigmoid(inputs) > multi_threshold).nonzero()
            if not return_list:
                return res
            # target is optional here, so fall back to the batch dimension of the
            # logits when it is not supplied.
            batch_size = len(target.seq_lens) if target is not None else inputs.shape[0]
            return _group_nonzero_by_row(res, batch_size)

    elif pred_type == "token-level-intent":
        if not use_multi:
            res = torch.argmax(inputs, dim=-1)
            if not return_sentence_level:
                return res
            if return_list:
                token_preds = res.detach().cpu().tolist()
                mask_rows = target.attention_mask.tolist()
                sentence_labels = []
                for preds, mask_row in zip(token_preds, mask_rows):
                    # Keep tokens up to (not including) the first position whose
                    # mask value is not 1.
                    kept = []
                    for pred, mask_value in zip(preds, mask_row):
                        if mask_value != 1:
                            break
                        kept.append(pred)
                    sentence_labels.append(_majority_label(kept))
                return sentence_labels
        else:
            seq_lens = target.seq_lens
            if not return_sentence_level:
                token_res = torch.cat([
                    torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold
                    for i in range(len(seq_lens))],
                    dim=0)
                return utils.unpack_sequence(token_res, seq_lens, padding_value=ignore_index)
            # Per sentence, count how many valid tokens vote for each intent.
            intent_votes = torch.stack([
                (torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold).sum(dim=0)
                for i in range(len(seq_lens))])
            # An intent is kept when more than floor(len / 2) tokens predict it.
            # The lengths are moved to the votes' device so that a CPU lengths
            # tensor can be compared with CUDA logits.
            half_lens = torch.div(seq_lens, 2, rounding_mode='floor').to(intent_votes.device)
            res = (intent_votes > half_lens.unsqueeze(1)).nonzero()
            if return_list:
                return _group_nonzero_by_row(res, len(seq_lens))
            return res

    else:
        raise NotImplementedError("Prediction mode except ['slot','intent','token-level-intent'] is not supported.")

    # CRF.decode may already return Python lists (e.g. pytorch-crf), which have
    # no .detach(); only convert actual tensors.
    if return_list and isinstance(res, Tensor):
        res = res.detach().cpu().tolist()
    return res
def compute_loss(pred: OutputData,
                 target: InputData,
                 criterion_type="slot",
                 use_crf=False,
                 ignore_index=-100,
                 loss_fn=None,
                 use_multi=False,
                 CRF=None):
    """Compute the training loss for one prediction head.

    Args:
        pred (OutputData): model output holding ``slot_ids`` / ``intent_ids``.
        target (InputData): gold batch holding ``slot`` / ``intent`` labels and
            ``seq_lens``.
        criterion_type (str, optional): "slot", "token-level-intent" or
            "intent". Any other value is handled like "intent".
            Defaults to "slot".
        use_crf (bool, optional): for "slot", score with ``CRF`` instead of
            ``loss_fn``. Defaults to False.
        ignore_index (int, optional): passed to ``target.get_slot_mask`` to
            build the CRF mask. Defaults to -100.
        loss_fn (callable, optional): loss applied to (prediction, target).
            Defaults to None.
        use_multi (bool, optional): for "token-level-intent", whether the
            intent labels carry an extra trailing dimension (multi-intent),
            which changes how they are repeated over tokens. Defaults to False.
        CRF (callable, optional): called as ``CRF(emissions, tags, mask)``;
            its result is negated to form the loss. Defaults to None.

    Returns:
        Tensor: loss result.
    """
    if criterion_type == "slot" and use_crf:
        emissions = pred.slot_ids
        gold_tags = target.slot
        slot_mask = target.get_slot_mask(ignore_index).byte()
        # Negate the CRF score (presumably a log-likelihood) so that it can be minimized.
        return -1 * CRF(emissions, gold_tags, slot_mask)

    if criterion_type == "slot":
        packed_pred = utils.pack_sequence(pred.slot_ids, target.seq_lens)
        packed_gold = utils.pack_sequence(target.slot, target.seq_lens)
        return loss_fn(packed_pred, packed_gold)

    if criterion_type == "token-level-intent":
        # TODO: Two decode function
        # Broadcast each sentence-level intent label to every token position.
        token_gold = target.intent.unsqueeze(1)
        num_tokens = pred.intent_ids.shape[1]
        repeat_dims = (1, num_tokens, 1) if use_multi else (1, num_tokens)
        token_gold = token_gold.repeat(*repeat_dims)
        packed_pred = utils.pack_sequence(pred.intent_ids, target.seq_lens)
        packed_gold = utils.pack_sequence(token_gold, target.seq_lens)
        return loss_fn(packed_pred, packed_gold)

    # Sentence-level intent (and any unrecognized criterion_type).
    return loss_fn(pred.intent_ids, target.intent)