File size: 1,418 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
"""Base class for encoders and generic multi encoders."""
import torch.nn as nn
from onmt.utils.misc import aeq
class EncoderBase(nn.Module):
    """
    Base encoder class. Specifies the interface used by different encoder types
    and required by :class:`onmt.Models.NMTModel`.

    .. mermaid::

       graph BT
          A[Input]
          subgraph RNN
            C[Pos 1]
            D[Pos 2]
            E[Pos N]
          end
          F[Memory_Bank]
          G[Final]
          A-->C
          A-->D
          A-->E
          C-->F
          D-->F
          E-->F
          E-->G
    """

    @classmethod
    def from_opt(cls, opt, embeddings=None):
        """Alternate-constructor hook; not implemented on the base class.

        Args:
            opt: options object (presumably the parsed model options --
                confirm against the concrete encoders).
            embeddings: optional embeddings module (presumably handed to
                the concrete encoder -- confirm against subclasses).

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    def _check_args(self, src, lengths=None, hidden=None):
        """Check that ``lengths`` agrees with the batch size of ``src``.

        Args:
            src: tensor whose dimension 1 is taken as the batch size.
            lengths: optional 1-D tensor; when given, its single dimension
                is compared against the batch size of ``src``.
            hidden: accepted but not used by this check.
        """
        n_batch = src.size(1)
        if lengths is not None:
            # Single-element unpacking doubles as a rank check: it raises
            # ValueError unless ``lengths`` has exactly one dimension.
            n_batch_, = lengths.size()
            # ``aeq`` comes from onmt.utils.misc; it is expected to fail
            # when its arguments are not all equal.
            aeq(n_batch, n_batch_)

    def forward(self, src, lengths=None):
        """
        Args:
            src (LongTensor):
               padded sequences of sparse indices ``(src_len, batch, nfeat)``
            lengths (LongTensor): length of each sequence ``(batch,)``

        Returns:
            (FloatTensor, FloatTensor, LongTensor):

            * final encoder state, used to initialize decoder
            * memory bank for attention, ``(src_len, batch, hidden)``
            * lengths

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError
|