Spaces:
Running
Running
import torch
import torch.nn as nn

from src.helper import get_cell
class Encoder(nn.Module):
    """Single-step recurrent encoder: an embedding lookup feeding an RNN cell.

    The recurrent layer is built from whatever class ``get_cell(cell_type)``
    returns, called with ``input_size``, ``hidden_size``, ``num_layers`` and
    ``dropout`` keyword arguments.
    NOTE(review): presumably these are the ``torch.nn`` RNN/GRU/LSTM classes;
    confirm against ``src.helper.get_cell``.
    """

    def __init__(self,
                 in_sz: int,
                 embed_sz: int,
                 hidden_sz: int,
                 cell_type: str,
                 n_layers: int,
                 dropout: float,
                 device: str):
        """Build the embedding table and the recurrent layer.

        Args:
            in_sz: Number of rows in the embedding table.
            embed_sz: Size of each embedding vector; also the RNN input size.
            hidden_sz: Size of the recurrent hidden state.
            cell_type: Name handed to ``get_cell``; the exact string "LSTM"
                switches ``forward`` to the (hidden, cell) state convention.
            n_layers: Number of stacked recurrent layers.
            dropout: Dropout value forwarded to the recurrent layer.
            device: Device on which ``initHidden`` allocates its tensor.
        """
        super().__init__()
        self.hidden_sz = hidden_sz
        self.n_layers = n_layers
        self.dropout = dropout
        self.cell_type = cell_type
        self.embedding = nn.Embedding(in_sz, embed_sz)
        self.device = device
        # NOTE(review): if get_cell returns torch.nn recurrent classes, a
        # non-zero dropout with n_layers == 1 has no effect and emits a
        # UserWarning; the value is forwarded unchanged here.
        self.rnn = get_cell(cell_type)(input_size=embed_sz,
                                       hidden_size=hidden_sz,
                                       num_layers=n_layers,
                                       dropout=dropout)

    def forward(self, input, hidden, cell=None):
        """Encode one input step.

        Args:
            input: Index tensor looked up in the embedding table. The result
                is reshaped to ``(1, 1, -1)``, so it is presumably a single
                token index; TODO confirm with callers.
            hidden: Recurrent hidden state passed to the RNN.
            cell: Cell state; required when ``cell_type == "LSTM"``. For any
                other cell type it is not used and is returned unchanged, so
                it may be omitted.

        Returns:
            Tuple ``(output, hidden, cell)`` with the RNN output, the updated
            hidden state and the (updated or passed-through) cell state.

        Raises:
            ValueError: If ``cell_type == "LSTM"`` and ``cell`` is ``None``.
        """
        # Flatten the looked-up embedding into a (1, 1, features) tensor.
        embedded = self.embedding(input).view(1, 1, -1)
        if self.cell_type == "LSTM":
            if cell is None:
                raise ValueError(
                    "cell state is required when cell_type is 'LSTM'")
            # LSTM-style cells carry their state as a (hidden, cell) pair.
            output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        else:
            output, hidden = self.rnn(embedded, hidden)
        return output, hidden, cell

    def initHidden(self) -> torch.Tensor:
        """Return a zero tensor of shape (n_layers, 1, hidden_sz) on self.device."""
        return torch.zeros(self.n_layers, 1, self.hidden_sz,
                           device=self.device)