import os

import torch
from torch import nn

from common import utils
from common.utils import ClassifierOutputData, HiddenData
from model.decoder.interaction.base_interaction import BaseInteraction


class StackInteraction(BaseInteraction):
    """Interaction that appends intent information to the slot hidden state.

    The intent information is concatenated along the last dimension of the
    slot hidden state.  Depending on the ``differentiable`` config entry it
    is either the classifier output itself or a one-hot vector of the
    top-scoring intent.
    """

    def __init__(self, **config):
        """Build the frozen one-hot lookup table.

        Config entries read here:
            intent_label_num: read from ``self.config``; size of the lookup
                table (both number of rows and row width).
            differentiable: optional; when truthy ``forward`` feeds the raw
                classifier output instead of a one-hot vector.  A missing
                key yields ``None``, i.e. the one-hot path.
        """
        super().__init__(**config)
        num_intents = self.config["intent_label_num"]
        self.differentiable = config.get("differentiable")

        # Square embedding whose weights are overwritten with the identity
        # matrix, so looking up index ``i`` returns the i-th one-hot row.
        self.intent_embedding = nn.Embedding(num_intents, num_intents)
        self.intent_embedding.weight.data = torch.eye(num_intents)
        # The table is a constant lookup, not something to be trained.
        self.intent_embedding.weight.requires_grad = False

    def forward(self, intent_output: ClassifierOutputData, encode_hidden: HiddenData):
        """Concatenate intent features onto the slot hidden state.

        Args:
            intent_output: object whose ``classifier_output`` tensor holds
                the intent scores.
            encode_hidden: hidden-state container; its slot hidden state is
                replaced by the concatenated tensor.

        Returns:
            The same ``encode_hidden`` object, updated in place.
        """
        if self.differentiable:
            # Keep the raw classifier output so it stays part of the graph.
            intent_feature = intent_output.classifier_output
        else:
            # Hard decision: take the index of the best-scoring intent and
            # map it to its one-hot row.  ``squeeze(2)`` removes the k=1
            # axis left by ``topk`` -- presumably the classifier output is
            # 3-D here; confirm against callers.
            top_index = intent_output.classifier_output.topk(1, dim=-1)[1]
            intent_feature = self.intent_embedding(top_index.squeeze(2))

        slot_hidden = encode_hidden.get_slot_hidden_state()
        stacked = torch.cat([slot_hidden, intent_feature], dim=-1)
        encode_hidden.update_slot_hidden_state(stacked)
        return encode_hidden

    @staticmethod
    def from_configured(configure_name_or_file="stack-interaction", **input_config):
        """Create a ``StackInteraction`` via ``utils.from_configured``.

        Args:
            configure_name_or_file: configuration name or file forwarded to
                ``utils.from_configured``.
            **input_config: extra keyword arguments forwarded unchanged.

        Returns:
            Whatever ``utils.from_configured`` returns for this class.
        """
        return utils.from_configured(
            configure_name_or_file,
            model_class=StackInteraction,
            config_prefix="./config/decoder/interaction",
            **input_config,
        )