sakharamg's picture
Uploading all files
158b61b
"""
RNN tools
"""
import torch.nn as nn
import onmt.models
def rnn_factory(rnn_type, **kwargs):
    """Instantiate a recurrent layer by name, preferring PyTorch's own classes.

    ``"SRU"`` is built with the project's ``onmt.models.sru.SRU``. Every other
    name is looked up as an attribute of ``torch.nn`` (for example ``"LSTM"``
    or ``"GRU"``) and called with ``**kwargs``. An unknown name therefore
    surfaces as the ``AttributeError`` raised by ``getattr``.

    Args:
        rnn_type (str): Name of the recurrent module class to build.
        **kwargs: Constructor arguments, forwarded unchanged to that class.

    Returns:
        tuple: ``(rnn, no_pack_padded_seq)``. ``rnn`` is the constructed
        module. ``no_pack_padded_seq`` is ``True`` only for SRU, signalling
        that inputs must not be wrapped in a ``PackedSequence``.
    """
    if rnn_type != "SRU":
        # Native PyTorch RNN classes accept PackedSequence input.
        return getattr(nn, rnn_type)(**kwargs), False

    # SRU has no PackedSequence support, so callers must feed padded tensors.
    return onmt.models.sru.SRU(**kwargs), True