"""Base class for encoders and generic multi encoders."""

import torch.nn as nn

from onmt.utils.misc import aeq


class EncoderBase(nn.Module):
    """
    Abstract encoder. Defines the interface shared by the different encoder
    types and required by :class:`onmt.Models.NMTModel`.

    Both :meth:`from_opt` and :meth:`forward` raise ``NotImplementedError``
    here; concrete encoders are expected to override them.

    .. mermaid::

       graph BT
          A[Input]
          subgraph RNN
            C[Pos 1]
            D[Pos 2]
            E[Pos N]
          end
          F[Memory_Bank]
          G[Final]
          A-->C
          A-->D
          A-->E
          C-->F
          D-->F
          E-->F
          E-->G
    """

    @classmethod
    def from_opt(cls, opt, embeddings=None):
        """Alternate constructor hook; not implemented on the base class.

        Args:
            opt: options object (its contents are not inspected here).
            embeddings: optional embeddings object; defaults to ``None``.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    def _check_args(self, src, lengths=None, hidden=None):
        """Check that ``src`` and ``lengths`` agree on the batch size.

        The batch size is taken from dimension 1 of ``src``. When ``lengths``
        is supplied, its size must unpack to exactly one dimension, and that
        value is compared against the batch size through ``aeq``
        (NOTE(review): ``aeq`` comes from ``onmt.utils.misc``; presumably it
        asserts that its arguments are equal -- confirm against that helper).

        Args:
            src: tensor-like object exposing ``size(dim)``.
            lengths: optional tensor-like object exposing ``size()``;
                the comparison is skipped when it is ``None``.
            hidden: accepted but not used by this method.
        """
        # Always read the batch dimension first, even when there is nothing
        # to compare it with, so a malformed ``src`` still fails here.
        batch_size = src.size(1)
        if lengths is None:
            return

        # Single-element unpacking: fails unless ``lengths`` is 1-dimensional.
        (lengths_batch,) = lengths.size()
        aeq(batch_size, lengths_batch)

    def forward(self, src, lengths=None):
        """Encode a batch of source sequences (abstract on the base class).

        Args:
            src (LongTensor):
               padded sequences of sparse indices ``(src_len, batch, nfeat)``
            lengths (LongTensor): length of each sequence ``(batch,)``

        Returns:
            tuple: three items that implementations are expected to return:

            * final encoder state, used to initialize decoder (FloatTensor)
            * memory bank for attention,
              ``(src_len, batch, hidden)`` (FloatTensor)
            * lengths

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError