transliteration / src /decoder.py
Pankaj Singh Rawat
Initial commit
9e582c5
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.helper import get_cell
class Decoder(nn.Module):
    """Single-step recurrent decoder: embed the current token(s), run one RNN
    step, and return log-probabilities over the output tokens.

    The recurrent module class is obtained from ``get_cell(cell_type)``. When
    ``cell_type == "LSTM"`` the ``(hidden, cell)`` pair is threaded through
    ``forward``; for any other cell type only ``hidden`` is updated and
    ``cell`` is handed back unchanged.

    Args:
        out_sz: Number of output tokens; used as the embedding table size and
            as the width of the output projection.
        embed_sz: Embedding dimension fed into the recurrent module.
        hidden_sz: Hidden-state size of the recurrent module.
        cell_type: Name passed to ``get_cell``; ``"LSTM"`` selects the
            ``(hidden, cell)`` state path in ``forward``.
        n_layers: Number of stacked recurrent layers.
        dropout: Inter-layer dropout probability (stored on ``self.dropout``).
        device: Device on which ``initHidden`` allocates the zero state.
    """

    def __init__(self,
                 out_sz: int,
                 embed_sz: int,
                 hidden_sz: int,
                 cell_type: str,
                 n_layers: int,
                 dropout: float,
                 device: str):
        super().__init__()
        self.hidden_sz = hidden_sz
        self.n_layers = n_layers
        self.dropout = dropout
        self.cell_type = cell_type
        self.embedding = nn.Embedding(out_sz, embed_sz)
        self.device = device
        # Assuming get_cell returns a torch.nn RNN/GRU/LSTM class: those only
        # apply dropout *between* stacked layers and warn when dropout > 0 with
        # a single layer. Pass 0.0 in that case -- identical computation, no
        # spurious UserWarning. self.dropout above keeps the caller's value.
        self.rnn = get_cell(cell_type)(input_size=embed_sz,
                                       hidden_size=hidden_sz,
                                       num_layers=n_layers,
                                       dropout=dropout if n_layers > 1 else 0.0)
        self.out = nn.Linear(hidden_sz, out_sz)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input: torch.Tensor, hidden: torch.Tensor, cell):
        """Run one decoding step.

        Args:
            input: Token indices for the current step, one per sequence
                (e.g. shape ``(1,)`` for a single sequence, ``(batch,)``
                for a batch).
            hidden: Hidden state, shaped like ``initHidden(batch)`` returns,
                i.e. ``(n_layers, batch, hidden_sz)``.
            cell: LSTM cell state (same shape as ``hidden``). For non-LSTM
                cell types it is not used and is returned unchanged.

        Returns:
            ``(log_probs, hidden, cell)`` where ``log_probs`` has shape
            ``(batch, out_sz)``.
        """
        # One time step for every sequence: (1, batch, embed_sz). For a
        # single-token input this is exactly the old view(1, 1, -1).
        embedded = self.embedding(input).view(1, -1, self.embedding.embedding_dim)
        embedded = F.relu(embedded)
        if self.cell_type == "LSTM":
            output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        else:
            output, hidden = self.rnn(embedded, hidden)
        # output[0] drops the length-1 time axis -> (batch, hidden_sz).
        log_probs = self.softmax(self.out(output[0]))
        return log_probs, hidden, cell

    def initHidden(self, batch_size: int = 1):
        """Return an all-zero state of shape ``(n_layers, batch_size, hidden_sz)``
        on ``self.device``; usable for ``hidden`` and, for LSTMs, ``cell``."""
        return torch.zeros(self.n_layers, batch_size, self.hidden_sz,
                           device=self.device)