""" RNN tools """ import torch.nn as nn import onmt.models def rnn_factory(rnn_type, **kwargs): """ rnn factory, Use pytorch version when available. """ no_pack_padded_seq = False if rnn_type == "SRU": # SRU doesn't support PackedSequence. no_pack_padded_seq = True rnn = onmt.models.sru.SRU(**kwargs) else: rnn = getattr(nn, rnn_type)(**kwargs) return rnn, no_pack_padded_seq