transliteration / src /encoder.py
Pankaj Singh Rawat
Initial commit
9e582c5
ort torch
import torch.nn as nn
from src.helper import get_cell
class Encoder(nn.Module):
    """Recurrent encoder: embeds one input step and feeds it through an RNN layer.

    The recurrent layer class comes from ``get_cell(cell_type)`` (``src.helper``).
    It is constructed with ``input_size``/``hidden_size``/``num_layers``/``dropout``
    keyword arguments, so it is presumably ``nn.RNN``/``nn.GRU``/``nn.LSTM``.
    TODO confirm in ``src.helper``.

    Args:
        in_sz: Number of rows in the embedding table (presumably the source
            vocabulary size).
        embed_sz: Embedding dimension; also used as the recurrent input size.
        hidden_sz: Hidden-state size of the recurrent layer.
        cell_type: Name passed to ``get_cell``. The exact string ``"LSTM"``
            selects the ``(hidden, cell)`` calling convention in ``forward``.
        n_layers: Number of stacked recurrent layers.
        dropout: Dropout probability between stacked recurrent layers. It is
            stored unchanged in ``self.dropout``. It is only forwarded to the
            recurrent layer when ``n_layers > 1`` (see ``_rnn_dropout``).
        device: Device on which ``initHidden`` allocates the zero state.
    """

    def __init__(self,
                 in_sz: int,
                 embed_sz: int,
                 hidden_sz: int,
                 cell_type: str,
                 n_layers: int,
                 dropout: float,
                 device: str):
        super().__init__()
        self.hidden_sz = hidden_sz
        self.n_layers = n_layers
        self.dropout = dropout
        self.cell_type = cell_type
        # Submodules are registered in the same order as before (embedding,
        # then rnn), so state_dict key order is unchanged.
        self.embedding = nn.Embedding(in_sz, embed_sz)
        self.device = device
        self.rnn = get_cell(cell_type)(input_size=embed_sz,
                                       hidden_size=hidden_sz,
                                       num_layers=n_layers,
                                       dropout=self._rnn_dropout(dropout, n_layers))

    @staticmethod
    def _rnn_dropout(dropout: float, n_layers: int) -> float:
        """Return the dropout value to hand to the recurrent layer.

        PyTorch's recurrent modules apply dropout only between stacked layers.
        They warn when ``dropout > 0`` and ``num_layers == 1``. A single-layer
        network therefore gets ``0.0``, which has the same effect without the
        warning. Multi-layer networks get ``dropout`` unchanged.
        """
        return dropout if n_layers > 1 else 0.0

    def forward(self, input, hidden, cell):
        """Run a single encoder step.

        Args:
            input: Token-index tensor to embed. The embedding is reshaped with
                ``view(1, 1, -1)``, so this presumably holds one index
                (e.g. shape ``(1,)``). With k indices the result would be
                ``(1, 1, k * embed_sz)``, which does not match the recurrent
                input size.
            hidden: Previous hidden state (e.g. from ``initHidden``).
            cell: Previous cell state. It is only consumed when
                ``cell_type == "LSTM"``.

        Returns:
            Tuple ``(output, hidden, cell)`` from the recurrent layer. For
            non-LSTM cells, ``cell`` is returned exactly as passed in.
        """
        # One time step, batch of one. This assumes the recurrent layer is not
        # batch_first (the torch default); TODO confirm in get_cell.
        embedded = self.embedding(input).view(1, 1, -1)
        if self.cell_type == "LSTM":
            output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        else:
            output, hidden = self.rnn(embedded, hidden)
        return output, hidden, cell

    def initHidden(self):
        """Return a zero tensor of shape ``(n_layers, 1, hidden_sz)`` on ``self.device``.

        For LSTM cells the caller presumably also uses this for the initial
        cell state, since it has the same shape. Confirm with the caller.
        """
        return torch.zeros(self.n_layers, 1, self.hidden_sz, device=self.device)