"""Define RNN-based encoders."""
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

from onmt.encoders.encoder import EncoderBase
from onmt.utils.rnn_factory import rnn_factory


class RNNEncoder(EncoderBase):
    """A generic recurrent neural network encoder.

    Args:
        rnn_type (str):
            style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
        bidirectional (bool): use a bidirectional RNN
        num_layers (int): number of stacked layers
        hidden_size (int): hidden size of each layer
        dropout (float): dropout value for :class:`torch.nn.Dropout`
        embeddings (onmt.modules.Embeddings): embedding module to use
        use_bridge (bool): pass the final encoder state through an extra
            linear + ReLU "bridge" layer before returning it
    """

    def __init__(self, rnn_type, bidirectional, num_layers,
                 hidden_size, dropout=0.0, embeddings=None,
                 use_bridge=False):
        super(RNNEncoder, self).__init__()
        assert embeddings is not None

        num_directions = 2 if bidirectional else 1
        assert hidden_size % num_directions == 0
        # Split the hidden size across directions so the concatenated
        # (bi)directional outputs add up to the requested hidden_size.
        hidden_size = hidden_size // num_directions
        self.embeddings = embeddings

        self.rnn, self.no_pack_padded_seq = \
            rnn_factory(rnn_type,
                        input_size=embeddings.embedding_size,
                        hidden_size=hidden_size,
                        num_layers=num_layers,
                        dropout=dropout,
                        bidirectional=bidirectional)

        # Optionally build the bridge that transforms the final encoder
        # state before it is returned (see _initialize_bridge/_bridge).
        self.use_bridge = use_bridge
        if self.use_bridge:
            self._initialize_bridge(rnn_type,
                                    hidden_size,
                                    num_layers)

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.rnn_type,
            opt.brnn,
            opt.enc_layers,
            opt.enc_rnn_size,
            opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout,
            embeddings,
            opt.bridge)

    def forward(self, src, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)

        emb = self.embeddings(src)

        packed_emb = emb
        if lengths is not None and not self.no_pack_padded_seq:
            # Lengths data is wrapped inside a Tensor.
            lengths_list = lengths.view(-1).tolist()
            packed_emb = pack(emb, lengths_list)

        memory_bank, encoder_final = self.rnn(packed_emb)

        if lengths is not None and not self.no_pack_padded_seq:
            memory_bank = unpack(memory_bank)[0]

        if self.use_bridge:
            encoder_final = self._bridge(encoder_final)
        return encoder_final, memory_bank, lengths

    def _initialize_bridge(self, rnn_type,
                           hidden_size,
                           num_layers):
        # LSTM states carry both hidden and cell tensors; other RNN types
        # carry only the hidden tensor.
        number_of_states = 2 if rnn_type == "LSTM" else 1

        # Dimension of each flattened state (num_layers x hidden_size).
        self.total_hidden_dim = hidden_size * num_layers

        # One linear layer per state.
        self.bridge = nn.ModuleList([nn.Linear(self.total_hidden_dim,
                                               self.total_hidden_dim,
                                               bias=True)
                                     for _ in range(number_of_states)])

    def _bridge(self, hidden):
        """Forward hidden state through bridge."""
        def bottle_hidden(linear, states):
            """
            Transform from 3D to 2D, apply linear and return initial size
            """
            size = states.size()
            result = linear(states.view(-1, self.total_hidden_dim))
            return F.relu(result).view(size)

        if isinstance(hidden, tuple):  # LSTM: (h_n, c_n)
            outs = tuple([bottle_hidden(layer, hidden[ix])
                          for ix, layer in enumerate(self.bridge)])
        else:
            outs = bottle_hidden(self.bridge[0], hidden)
        return outs

    def update_dropout(self, dropout):
        self.rnn.dropout = dropout
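

# The block below is not part of the library: it is a minimal, hedged usage
# sketch showing how RNNEncoder is typically driven. It assumes a hypothetical
# stand-in embedding module (_ToyEmbeddings) that exposes the two things
# RNNEncoder needs: an `embedding_size` attribute and a forward pass mapping
# (src_len, batch, 1) token ids to (src_len, batch, embedding_size). In real
# use an onmt.modules.Embeddings instance is passed instead.
if __name__ == "__main__":
    import torch

    class _ToyEmbeddings(nn.Module):
        """Hypothetical stand-in for onmt.modules.Embeddings."""

        def __init__(self, vocab_size=100, embedding_size=16):
            super(_ToyEmbeddings, self).__init__()
            self.embedding_size = embedding_size
            self.lut = nn.Embedding(vocab_size, embedding_size)

        def forward(self, src):
            # src: (src_len, batch, 1) -> (src_len, batch, embedding_size)
            return self.lut(src.squeeze(2))

    encoder = RNNEncoder("LSTM", bidirectional=True, num_layers=2,
                         hidden_size=32, dropout=0.1,
                         embeddings=_ToyEmbeddings(), use_bridge=True)
    src = torch.randint(0, 100, (7, 3, 1))  # (src_len, batch, n_feats)
    lengths = torch.tensor([7, 5, 4])       # sorted for pack_padded_sequence
    encoder_final, memory_bank, _ = encoder(src, lengths)
    # memory_bank: (src_len, batch, hidden_size); encoder_final: (h_n, c_n)
    print(memory_bank.shape)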