OpenSLU / model /decoder /interaction /stack_interaction.py
LightChen2333's picture
Upload 34 files
37b9e99
import os
import torch
from torch import nn
from common import utils
from common.utils import ClassifierOutputData, HiddenData
from model.decoder.interaction.base_interaction import BaseInteraction
class StackInteraction(BaseInteraction):
def __init__(self, **config):
super().__init__(**config)
self.intent_embedding = nn.Embedding(
self.config["intent_label_num"], self.config["intent_label_num"]
)
self.differentiable = config.get("differentiable")
self.intent_embedding.weight.data = torch.eye(
self.config["intent_label_num"])
self.intent_embedding.weight.requires_grad = False
def forward(self, intent_output: ClassifierOutputData, encode_hidden: HiddenData):
if not self.differentiable:
_, idx_intent = intent_output.classifier_output.topk(1, dim=-1)
feed_intent = self.intent_embedding(idx_intent.squeeze(2))
else:
feed_intent = intent_output.classifier_output
encode_hidden.update_slot_hidden_state(
torch.cat([encode_hidden.get_slot_hidden_state(), feed_intent], dim=-1))
return encode_hidden
@staticmethod
def from_configured(configure_name_or_file="stack-interaction", **input_config):
return utils.from_configured(configure_name_or_file,
model_class=StackInteraction,
config_prefix="./config/decoder/interaction",
**input_config)