import torch
import torch.nn.functional as F
from torch import nn

from common.utils import HiddenData
from model.decoder.interaction.base_interaction import BaseInteraction


class BiModelInteraction(BaseInteraction):
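    """Bi-Model style interaction between the intent-detection and slot-filling
    task networks: each task reads a detached copy of the other task's hidden
    state, and the slot side is decoded token by token by an LSTM that feeds
    its previous output back in at every step.
    """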
    def __init__(self, **config):
        super().__init__(**config)
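        # Intent LSTM runs over the concatenated intent/slot encoder states
        # (input_dim) and produces output_dim features per token.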
        self.intent_lstm = nn.LSTM(input_size=self.config["input_dim"], hidden_size=self.config["output_dim"],
                                   batch_first=True,
                                   num_layers=1)
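        # Slot LSTM additionally consumes the previous decoder output at each
        # step, hence the extra output_dim on its input size.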
        self.slot_lstm = nn.LSTM(input_size=self.config["input_dim"] + self.config["output_dim"],
                                 hidden_size=self.config["output_dim"], batch_first=True, num_layers=1)

    def forward(self, encode_hidden: HiddenData, **kwargs):
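        """Exchange information between the intent and slot hidden states.

        Both directions use a detached copy of the other task's hidden state,
        so no gradient flows between the two task networks through the
        interaction; the updated ``encode_hidden`` is returned.
        """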
        slot_hidden = encode_hidden.get_slot_hidden_state()
        intent_hidden_detached = encode_hidden.get_intent_hidden_state().clone().detach()
        seq_lens = encode_hidden.inputs.attention_mask.sum(-1)
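        # Number of real (non-padding) tokens per sequence, used below to pick
        # the last valid intent hidden state of each example.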
        batch = slot_hidden.size(0)
        length = slot_hidden.size(1)
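        # Zero vector standing in for the "previous decoder output" at step 0,
        # plus zero-initialized (h, c) states for the slot decoder.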
        dec_init_out = torch.zeros(batch, 1, self.config["output_dim"], device=slot_hidden.device)
        hidden_state = (torch.zeros(1, batch, self.config["output_dim"], device=slot_hidden.device),
                        torch.zeros(1, batch, self.config["output_dim"], device=slot_hidden.device))
        slot_hidden = torch.cat((slot_hidden, intent_hidden_detached), dim=-1).transpose(0, 1)  # (length, batch, input_dim)
        slot_drop = F.dropout(slot_hidden, self.config["dropout_rate"], training=self.training)
        all_out = []
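        # Token-by-token slot decoding: at step i the decoder sees the encoder
        # features for token i together with its own output from step i - 1.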
        for i in range(length):
            if i == 0:
                out, hidden_state = self.slot_lstm(torch.cat((slot_drop[i].unsqueeze(1), dec_init_out), dim=-1),
                                                   hidden_state)
            else:
                out, hidden_state = self.slot_lstm(torch.cat((slot_drop[i].unsqueeze(1), out), dim=-1), hidden_state)
            all_out.append(out)
        slot_output = torch.cat(all_out, dim=1)  # (batch, length, output_dim)

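        # Intent side: run the intent LSTM over the whole sequence (with the
        # detached slot states appended) and keep only the hidden state at the
        # last non-padded position of each example.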
        intent_hidden = torch.cat((encode_hidden.get_intent_hidden_state(),
                                   encode_hidden.get_slot_hidden_state().clone().detach()),
                                  dim=-1)
        intent_drop = F.dropout(intent_hidden, self.config["dropout_rate"], training=self.training)
        intent_lstm_output, _ = self.intent_lstm(intent_drop)
        intent_output = F.dropout(intent_lstm_output, self.config["dropout_rate"], training=self.training)
        output_list = []
        for index, slen in enumerate(seq_lens):
            output_list.append(intent_output[index, slen - 1, :].unsqueeze(0))

        encode_hidden.update_intent_hidden_state(torch.cat(output_list, dim=0))
        encode_hidden.update_slot_hidden_state(slot_output)

        return encode_hidden


class BiModelWithoutDecoderInteraction(BaseInteraction):
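    """Bi-Model interaction without the slot decoder: the two task hidden
    states are simply cross-concatenated (each using a detached copy of the
    other task's state) and passed through dropout.
    """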
    def forward(self, encode_hidden: HiddenData, **kwargs):
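        """Cross-concatenate the task hidden states without any recurrent decoding."""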
        slot_hidden = encode_hidden.get_slot_hidden_state()
        intent_hidden_detached = encode_hidden.get_intent_hidden_state().clone().detach()
        seq_lens = encode_hidden.inputs.attention_mask.sum(-1)
        slot_hidden = torch.cat((slot_hidden, intent_hidden_detached), dim=-1)  # (batch, length, feature_size)
        slot_output = F.dropout(slot_hidden, self.config["dropout_rate"], training=self.training)

        intent_hidden = torch.cat((encode_hidden.get_intent_hidden_state(),
                                   encode_hidden.get_slot_hidden_state().clone().detach()),
                                  dim=-1)
        intent_output = F.dropout(intent_hidden, self.config["dropout_rate"], training=self.training)
        output_list = []
        for index, slen in enumerate(seq_lens):
            output_list.append(intent_output[index, slen - 1, :].unsqueeze(0))

        encode_hidden.update_intent_hidden_state(torch.cat(output_list, dim=0))
        encode_hidden.update_slot_hidden_state(slot_output)

        return encode_hidden