# Spaces:
# Running
# Running
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.helper import get_cell
class Decoder(nn.Module): | |
def __init__(self, | |
out_sz: int, | |
embed_sz: int, | |
hidden_sz: int, | |
cell_type: str, | |
n_layers: int, | |
dropout: float, | |
device: str): | |
super(Decoder, self).__init__() | |
self.hidden_sz = hidden_sz | |
self.n_layers = n_layers | |
self.dropout = dropout | |
self.cell_type = cell_type | |
self.embedding = nn.Embedding(out_sz, embed_sz) | |
self.device = device | |
self.rnn = get_cell(cell_type)(input_size = embed_sz, | |
hidden_size = hidden_sz, | |
num_layers = n_layers, | |
dropout = dropout) | |
self.out = nn.Linear(hidden_sz, out_sz) | |
self.softmax = nn.LogSoftmax(dim=1) | |
def forward(self, input, hidden, cell): | |
output = self.embedding(input).view(1, 1, -1) | |
output = F.relu(output) | |
if(self.cell_type == "LSTM"): | |
output, (hidden, cell) = self.rnn(output, (hidden, cell)) | |
else: | |
output, hidden = self.rnn(output, hidden) | |
output = self.softmax(self.out(output[0])) | |
return output, hidden, cell | |
def initHidden(self): | |
return torch.zeros(self.n_layers, 1, self.hidden_sz, device=self.device) | |