# OpenSLU/model/decoder/interaction/bi_model_interaction.py
# Hosting-page header captured with this file: LightChen2333 - "Upload 34 files" (commit 37b9e99)
import torch
import torch.nn.functional as F
from torch import nn
from common.utils import HiddenData
from model.decoder.interaction.base_interaction import BaseInteraction
class BiModelInteraction(BaseInteraction):
    """Bi-model interaction: two cross-fed LSTM decoders for slot and intent.

    The slot decoder reads the slot hidden states concatenated with a detached
    copy of the intent hidden states; the intent decoder reads the intent
    hidden states concatenated with a detached copy of the slot hidden states.
    Detaching means each task's loss does not back-propagate into the other
    task's encoder output through this module.

    Config keys read by this class:
        input_dim: feature size of the concatenated (slot, intent) hidden states.
        output_dim: hidden size of both LSTMs.
        dropout_rate: dropout probability applied around the decoders.
    """

    def __init__(self, **config):
        """Build the intent and slot LSTM decoders from ``config``."""
        super().__init__(**config)
        self.intent_lstm = nn.LSTM(input_size=self.config["input_dim"],
                                   hidden_size=self.config["output_dim"],
                                   batch_first=True,
                                   num_layers=1)
        # The slot decoder additionally consumes its own previous output,
        # hence the extra ``output_dim`` input features.
        # NOTE(review): unlike ``intent_lstm`` this LSTM is built without
        # ``batch_first``, yet ``forward`` feeds it (batch, 1, feature) tensors.
        # It therefore treats the mini-batch axis as the time axis (with batch
        # size 1), so a sample's slot output depends on the other samples in
        # the batch. Left unchanged here so that already-trained weights keep
        # producing the same results - confirm whether this is intended.
        self.slot_lstm = nn.LSTM(input_size=self.config["input_dim"] + self.config["output_dim"],
                                 hidden_size=self.config["output_dim"],
                                 num_layers=1)

    def forward(self, encode_hidden: HiddenData, **kwargs):
        """Run both decoders and store their outputs back into ``encode_hidden``.

        Args:
            encode_hidden: container providing the slot / intent hidden states
                (indexed here as batch x length x feature) and
                ``inputs.attention_mask``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The same ``encode_hidden`` object, whose slot hidden state is now
            the slot decoder output (batch x length x output_dim) and whose
            intent hidden state is the intent decoder output taken at index
            ``attention_mask.sum(-1) - 1`` of each sample (batch x output_dim).
        """
        output_dim = self.config["output_dim"]
        dropout_rate = self.config["dropout_rate"]

        slot_hidden = encode_hidden.get_slot_hidden_state()
        intent_hidden = encode_hidden.get_intent_hidden_state()
        seq_lens = encode_hidden.inputs.attention_mask.sum(-1)
        batch = slot_hidden.size(0)

        # ---- slot decoder -------------------------------------------------
        # length x batch x feature_size, so that iterating yields one time step
        # at a time.
        slot_input = torch.cat((slot_hidden, intent_hidden.detach()), dim=-1).transpose(1, 0)
        # ``training=self.training`` is required: F.dropout defaults to
        # training=True and would otherwise keep dropping units in eval mode.
        slot_drop = F.dropout(slot_input, dropout_rate, training=self.training)

        zeros_kwargs = {"device": slot_hidden.device, "dtype": slot_hidden.dtype}
        # Stand-in for the "previous output" at the first time step.
        prev_out = torch.zeros(batch, 1, output_dim, **zeros_kwargs)
        hidden_state = (torch.zeros(1, 1, output_dim, **zeros_kwargs),
                        torch.zeros(1, 1, output_dim, **zeros_kwargs))
        all_out = []
        for step_input in slot_drop.unbind(0):
            # Feed the current step together with the decoder's previous output.
            prev_out, hidden_state = self.slot_lstm(
                torch.cat((step_input.unsqueeze(1), prev_out), dim=-1), hidden_state)
            all_out.append(prev_out)
        slot_output = torch.cat(all_out, dim=1)  # batch x length x output_dim

        # ---- intent decoder -----------------------------------------------
        intent_input = torch.cat((intent_hidden, slot_hidden.detach()), dim=-1)
        intent_drop = F.dropout(intent_input, dropout_rate, training=self.training)
        intent_lstm_output, _ = self.intent_lstm(intent_drop)
        intent_output = F.dropout(intent_lstm_output, dropout_rate, training=self.training)

        # Pick, for every sample, the decoder state at position seq_len - 1.
        # Presumably the attention mask is right-padded so this is the last
        # real token - TODO confirm against the tokenizer/collate code.
        last_index = (seq_lens - 1).long().to(intent_output.device)
        rows = torch.arange(last_index.size(0), device=intent_output.device)
        encode_hidden.update_intent_hidden_state(intent_output[rows, last_index])
        encode_hidden.update_slot_hidden_state(slot_output)
        return encode_hidden
class BiModelWithoutDecoderInteraction(BaseInteraction):
    """Bi-model interaction without LSTM decoders.

    Each task's hidden states are concatenated with a detached copy of the
    other task's hidden states and passed through dropout only.

    Config keys read by this class:
        dropout_rate: dropout probability applied to the concatenated states.
    """

    def forward(self, encode_hidden: HiddenData, **kwargs):
        """Cross-concatenate slot and intent states and store them back.

        Args:
            encode_hidden: container providing the slot / intent hidden states
                (indexed here as batch x length x feature) and
                ``inputs.attention_mask``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The same ``encode_hidden`` object, whose slot hidden state is the
            concatenated (slot, intent) features for every position and whose
            intent hidden state is the concatenated (intent, slot) features
            taken at index ``attention_mask.sum(-1) - 1`` of each sample.
        """
        dropout_rate = self.config["dropout_rate"]

        slot_hidden = encode_hidden.get_slot_hidden_state()
        intent_hidden = encode_hidden.get_intent_hidden_state()
        seq_lens = encode_hidden.inputs.attention_mask.sum(-1)

        # ``training=self.training`` is required: F.dropout defaults to
        # training=True and would otherwise keep dropping units in eval mode.
        # batch x length x feature_size (no transpose in this variant).
        slot_output = F.dropout(torch.cat((slot_hidden, intent_hidden.detach()), dim=-1),
                                dropout_rate, training=self.training)
        intent_output = F.dropout(torch.cat((intent_hidden, slot_hidden.detach()), dim=-1),
                                  dropout_rate, training=self.training)

        # Pick, for every sample, the features at position seq_len - 1.
        # Presumably the attention mask is right-padded so this is the last
        # real token - TODO confirm against the tokenizer/collate code.
        last_index = (seq_lens - 1).long().to(intent_output.device)
        rows = torch.arange(last_index.size(0), device=intent_output.device)
        encode_hidden.update_intent_hidden_state(intent_output[rows, last_index])
        encode_hidden.update_slot_hidden_state(slot_output)
        return encode_hidden