File size: 1,535 Bytes
37b9e99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import torch
from torch import nn

from common import utils
from common.utils import ClassifierOutputData, HiddenData
from model.decoder.interaction.base_interaction import BaseInteraction


class StackInteraction(BaseInteraction):
    """Interaction that stacks intent information onto the slot hidden state.

    ``forward`` concatenates an intent feature along the last dimension of
    the slot hidden state held by the ``HiddenData`` it receives. The intent
    feature is either the raw classifier output (``differentiable`` truthy)
    or a one-hot encoding of the top-1 predicted intent (otherwise).
    """

    def __init__(self, **config):
        """Create the frozen one-hot lookup table for intent indices.

        Args:
            **config: Forwarded unchanged to ``BaseInteraction.__init__``.
                Keys used by this class:

                * ``intent_label_num``: number of intent labels; read from
                  ``self.config`` (presumably populated by the base class
                  from these kwargs -- confirm against ``BaseInteraction``).
                * ``differentiable`` (optional): when truthy, ``forward``
                  feeds the classifier output through as-is; when missing
                  or falsy, the one-hot path is used.

        Raises:
            KeyError: if ``intent_label_num`` is absent from ``self.config``.
        """
        super().__init__(**config)
        num_intents = self.config["intent_label_num"]
        self.differentiable = config.get("differentiable")
        # A square embedding whose weight is the identity matrix maps an
        # intent index to its one-hot vector.
        self.intent_embedding = nn.Embedding(num_intents, num_intents)
        # Fill in place instead of rebinding ``weight.data``.
        nn.init.eye_(self.intent_embedding.weight)
        # The table is a fixed encoding, not a learned parameter.
        self.intent_embedding.weight.requires_grad = False

    def forward(self, intent_output: ClassifierOutputData, encode_hidden: HiddenData):
        """Append the intent feature to the slot hidden state.

        Args:
            intent_output: object exposing ``classifier_output``.
                NOTE(review): assumed to be a tensor whose last dimension
                indexes intent labels and whose leading dimensions match
                the slot hidden state -- confirm against the caller.
            encode_hidden: object exposing ``get_slot_hidden_state()`` and
                ``update_slot_hidden_state()``; it is updated through the
                latter call.

        Returns:
            The same ``encode_hidden`` object that was passed in.
        """
        if self.differentiable:
            feed_intent = intent_output.classifier_output
        else:
            _, idx_intent = intent_output.classifier_output.topk(1, dim=-1)
            # topk(1) leaves a trailing singleton dimension; drop it by
            # position-from-the-end so any input rank works, not only rank 3.
            feed_intent = self.intent_embedding(idx_intent.squeeze(-1))
        stacked_hidden = torch.cat(
            [encode_hidden.get_slot_hidden_state(), feed_intent], dim=-1)
        encode_hidden.update_slot_hidden_state(stacked_hidden)
        return encode_hidden

    @staticmethod
    def from_configured(configure_name_or_file="stack-interaction", **input_config):
        """Build a ``StackInteraction`` through ``utils.from_configured``.

        Args:
            configure_name_or_file: passed as the first argument to
                ``utils.from_configured``; defaults to ``"stack-interaction"``.
            **input_config: extra keyword arguments passed through to
                ``utils.from_configured``.

        Returns:
            Whatever ``utils.from_configured`` returns for
            ``model_class=StackInteraction`` and
            ``config_prefix="./config/decoder/interaction"``.
        """
        return utils.from_configured(configure_name_or_file,
                                     model_class=StackInteraction,
                                     config_prefix="./config/decoder/interaction",
                                     **input_config)