"""
Implementation of "Convolutional Sequence to Sequence Learning"
"""
import torch.nn as nn
from onmt.encoders.encoder import EncoderBase
from onmt.utils.cnn_factory import shape_transform, StackedCNN
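
# sqrt(0.5): ConvS2S scales the sum of a residual block's input and its
# output by this factor to halve the variance of the sum (Gehring et al.).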
SCALE_WEIGHT = 0.5 ** 0.5


class CNNEncoder(EncoderBase):
    """Encoder based on "Convolutional Sequence to Sequence Learning"
    :cite:`DBLP:journals/corr/GehringAGYD17`.
    """

    def __init__(self, num_layers, hidden_size,
                 cnn_kernel_width, dropout, embeddings):
        super(CNNEncoder, self).__init__()

        self.embeddings = embeddings
        input_size = embeddings.embedding_size
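        # Project the embeddings to the CNN hidden size before the conv stack.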
        self.linear = nn.Linear(input_size, hidden_size)
        self.cnn = StackedCNN(num_layers, hidden_size,
                              cnn_kernel_width, dropout)

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.enc_layers,
            opt.enc_rnn_size,
            opt.cnn_kernel_width,
            opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout,
            embeddings)

    def forward(self, input, lengths=None, hidden=None):
        """See :class:`onmt.modules.EncoderBase.forward()`"""
        self._check_args(input, lengths, hidden)

        emb = self.embeddings(input)
        # emb: (s_len, batch, emb_dim)

        # Flatten (batch * s_len) so the projection applies to every
        # position at once, then restore the (batch, s_len, hidden) shape.
        emb = emb.transpose(0, 1).contiguous()
        emb_reshape = emb.view(emb.size(0) * emb.size(1), -1)
        emb_remap = self.linear(emb_reshape)
        emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1)
        # shape_transform adds the trailing width-1 dim the Conv2d stack expects.
        emb_remap = shape_transform(emb_remap)
        out = self.cnn(emb_remap)

        # Return both the projected embeddings and the conv output; the
        # CNN decoder consumes the two together.
        return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \
            out.squeeze(3).transpose(0, 1).contiguous(), lengths

    def update_dropout(self, dropout):
        self.cnn.dropout.p = dropout
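

# --- Usage sketch (not part of the original module) ---
# A minimal, hypothetical example of driving the encoder. In OpenNMT-py the
# real `embeddings` argument is an onmt.modules.Embeddings instance; the toy
# module below only mimics the two things CNNEncoder relies on: an
# `embedding_size` attribute and a (s_len, batch, n_feats) -> (s_len, batch,
# emb_dim) forward pass.
if __name__ == "__main__":
    import torch

    class _ToyEmbeddings(nn.Module):
        """Illustrative stand-in for onmt.modules.Embeddings."""

        def __init__(self, vocab_size=100, emb_dim=32):
            super(_ToyEmbeddings, self).__init__()
            self.embedding_size = emb_dim
            self.lut = nn.Embedding(vocab_size, emb_dim)

        def forward(self, src):
            # src: (s_len, batch, 1) -> (s_len, batch, emb_dim)
            return self.lut(src.squeeze(2))

    encoder = CNNEncoder(num_layers=2, hidden_size=64, cnn_kernel_width=3,
                         dropout=0.1, embeddings=_ToyEmbeddings())
    src = torch.randint(0, 100, (7, 4, 1))  # (s_len, batch, n_feats)
    emb_out, cnn_out, _ = encoder(src)
    print(emb_out.shape, cnn_out.shape)  # the two outputs share a shape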