import torch
import torch.nn as nn
import torch.nn.functional as F

from src.helper import get_cell


class Decoder(nn.Module):
    """Single-step recurrent decoder.

    Embeds the input token(s), applies ReLU, runs one recurrent step and
    projects the recurrent output to log-probabilities with LogSoftmax.

    The recurrent layer is built by ``get_cell(cell_type)``. When
    ``cell_type == "LSTM"`` it is driven with an ``(hidden, cell)`` state
    tuple; otherwise it is driven with ``hidden`` alone, and ``cell`` is passed
    through untouched.
    """

    def __init__(self, out_sz: int, embed_sz: int, hidden_sz: int,
                 cell_type: str, n_layers: int, dropout: float, device: str):
        """
        Args:
            out_sz: Number of embedding rows and of output classes
                (presumably the target vocabulary size).
            embed_sz: Embedding dimension fed to the recurrent layer.
            hidden_sz: Hidden size of the recurrent layer.
            cell_type: Key passed to ``get_cell``. "LSTM" selects the
                ``(hidden, cell)`` state handling in ``forward``.
            n_layers: Number of stacked recurrent layers.
            dropout: Forwarded unchanged to the recurrent layer's constructor.
            device: Device on which ``initHidden`` allocates the zero state.
        """
        super().__init__()
        self.hidden_sz = hidden_sz
        self.n_layers = n_layers
        self.dropout = dropout
        self.cell_type = cell_type
        # Submodules keep their original attribute names so existing
        # state_dict checkpoints still load.
        self.embedding = nn.Embedding(out_sz, embed_sz)
        self.device = device
        # NOTE(review): assumes get_cell returns a torch.nn RNN/GRU/LSTM-style
        # class. Those warn when dropout > 0 with num_layers == 1 — confirm.
        self.rnn = get_cell(cell_type)(
            input_size=embed_sz,
            hidden_size=hidden_sz,
            num_layers=n_layers,
            dropout=dropout,
        )
        self.out = nn.Linear(hidden_sz, out_sz)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden, cell):
        """Run one decoding step.

        Args:
            input: Token-index tensor. Either a single token (any shape with
                exactly one element, as before) or a 1-D tensor of B token
                ids, which is decoded as a batch of size B.
            hidden: Recurrent hidden state, presumably shaped
                (n_layers, B, hidden_sz) — TODO confirm against get_cell.
            cell: LSTM cell state. It is ignored and returned unchanged when
                ``cell_type != "LSTM"``.

        Returns:
            Tuple ``(log_probs, hidden, cell)``. ``log_probs`` has shape
            (B, out_sz) and holds LogSoftmax over dim 1.
        """
        embedded = self.embedding(input)
        # Reshape to (1, B, embed_sz): one time step, assuming a
        # sequence-first cell (torch default). For a single token this is
        # exactly the previous view(1, 1, -1).
        embedded = embedded.view(1, -1, self.embedding.embedding_dim)
        output = F.relu(embedded)
        if self.cell_type == "LSTM":
            output, (hidden, cell) = self.rnn(output, (hidden, cell))
        else:
            output, hidden = self.rnn(output, hidden)
        # output[0] drops the length-1 time dimension -> (B, hidden_sz).
        output = self.softmax(self.out(output[0]))
        return output, hidden, cell

    def initHidden(self, batch_sz: int = 1):
        """Return a zero state of shape (n_layers, batch_sz, hidden_sz).

        The tensor is allocated on ``self.device``. ``batch_sz`` defaults to 1,
        which matches the previous behavior.
        """
        return torch.zeros(self.n_layers, batch_sz, self.hidden_sz,
                           device=self.device)