Spaces:
Runtime error
Runtime error
import os | |
import torch | |
from torch import nn | |
from common import utils | |
from common.utils import ClassifierOutputData, HiddenData | |
from model.decoder.interaction.base_interaction import BaseInteraction | |
class StackInteraction(BaseInteraction): | |
def __init__(self, **config): | |
super().__init__(**config) | |
self.intent_embedding = nn.Embedding( | |
self.config["intent_label_num"], self.config["intent_label_num"] | |
) | |
self.differentiable = config.get("differentiable") | |
self.intent_embedding.weight.data = torch.eye( | |
self.config["intent_label_num"]) | |
self.intent_embedding.weight.requires_grad = False | |
def forward(self, intent_output: ClassifierOutputData, encode_hidden: HiddenData): | |
if not self.differentiable: | |
_, idx_intent = intent_output.classifier_output.topk(1, dim=-1) | |
feed_intent = self.intent_embedding(idx_intent.squeeze(2)) | |
else: | |
feed_intent = intent_output.classifier_output | |
encode_hidden.update_slot_hidden_state( | |
torch.cat([encode_hidden.get_slot_hidden_state(), feed_intent], dim=-1)) | |
return encode_hidden | |
def from_configured(configure_name_or_file="stack-interaction", **input_config): | |
return utils.from_configured(configure_name_or_file, | |
model_class=StackInteraction, | |
config_prefix="./config/decoder/interaction", | |
**input_config) | |