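"""GL-GIN decoder: an intent-aware LSTM encoder followed by the GL-GIN global-local
graph interaction layer for joint (multiple) intent detection and slot filling."""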
import torch
import torch.nn.functional as F
from torch import nn

from common.utils import HiddenData, OutputData, InputData
from model.decoder import BaseDecoder
from model.decoder.interaction.gl_gin_interaction import LSTMEncoder


class IntentEncoder(nn.Module):
    """LSTM encoder that re-encodes the shared utterance representation into
    intent-specific hidden states."""

    def __init__(self, input_dim, dropout_rate):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.__intent_lstm = LSTMEncoder(
            input_dim,
            input_dim,
            dropout_rate
        )

    def forward(self, g_hiddens, seq_lens):
        # Re-encode the shared hidden states, then apply dropout (active only in training).
        intent_lstm_out = self.__intent_lstm(g_hiddens, seq_lens)
        return F.dropout(intent_lstm_out, p=self.dropout_rate, training=self.training)


class GLGINDecoder(BaseDecoder):
    """GL-GIN decoder: predicts intents first, then feeds the decoded intent indices
    into the global-local graph interaction layer to guide slot filling."""

    def __init__(self, intent_classifier, slot_classifier, interaction=None, **config):
        super().__init__(intent_classifier, slot_classifier, interaction)
        self.config = config
        self.intent_encoder = IntentEncoder(
            self.intent_classifier.config["input_dim"],
            self.config["dropout_rate"]
        )

    def forward(self, hidden: HiddenData, forced_slot=None, forced_intent=None, differentiable=None):
        # Recover sequence lengths from the attention mask.
        seq_lens = hidden.inputs.attention_mask.sum(-1)
        # Re-encode the slot hidden states into intent-specific representations.
        intent_lstm_out = self.intent_encoder(hidden.slot_hidden, seq_lens)
        hidden.update_intent_hidden_state(intent_lstm_out)
        # Predict intents, then decode sentence-level intent indices that act as
        # intent nodes in the interaction graph.
        pred_intent = self.intent_classifier(hidden)
        intent_index = self.intent_classifier.decode(OutputData(pred_intent, None), hidden.inputs,
                                                     return_list=False,
                                                     return_sentence_level=True)
        # Global-local graph interaction between intent nodes and token (slot) nodes.
        slot_hidden = self.interaction(
            hidden,
            pred_intent=pred_intent,
            intent_index=intent_index,
        )
        pred_slot = self.slot_classifier(slot_hidden)
        # The interaction output places intent nodes before token nodes; keep only the token part.
        num_intent = self.intent_classifier.config["intent_label_num"]
        pred_slot = pred_slot.classifier_output[:, num_intent:]
        return OutputData(pred_intent, F.log_softmax(pred_slot, dim=1))