from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence

class LastHundredEmbed(nn.Module):
    """Embedding that keeps only the last 100 features along the final dimension."""

    def forward(self, x):
        return x[:, :, -100:]


class IdentityEmbed(nn.Module):
    """Embedding that returns its input unchanged."""

    def forward(self, x):
        return x


class FullyConnectedEmbed(nn.Module):
    """Single-layer projection: linear transform -> activation -> dropout."""

    def __init__(self, nin, nout, dropout=0.5, activation=nn.ReLU()):
        super(FullyConnectedEmbed, self).__init__()
        self.nin = nin
        self.nout = nout
        self.dropout_p = dropout

        self.transform = nn.Linear(nin, nout)
        self.drop = nn.Dropout(p=self.dropout_p)
        self.activation = activation

    def forward(self, x):
        t = self.transform(x)
        t = self.activation(t)
        t = self.drop(t)
        return t


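# A minimal usage sketch for FullyConnectedEmbed (illustrative only; the input
# shape and sizes below are assumptions, not taken from this module):
#
#   embed = FullyConnectedEmbed(512, 100, dropout=0.5)
#   x = torch.randn(8, 300, 512)    # (batch, length, features), hypothetical
#   z = embed(x)                    # -> (8, 300, 100)
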
class LMEmbed(nn.Module):
    """Combines a token embedding with the output of a pretrained language model.

    The language model `lm` is expected to expose `hidden_size()` and `encode(x)`;
    its encoding is projected to `nout` dimensions and added to the token embedding
    before the final non-linearity.
    """

    def __init__(self, nin, nout, lm, padding_idx=-1, transform=nn.ReLU(),
                 sparse=False):
        super(LMEmbed, self).__init__()

        if padding_idx == -1:
            padding_idx = nin - 1

        self.lm = lm
        self.embed = nn.Embedding(nin, nout, padding_idx=padding_idx, sparse=sparse)
        self.proj = nn.Linear(lm.hidden_size(), nout)
        self.transform = transform
        self.nout = nout

    def forward(self, x):
        packed = type(x) is PackedSequence
        h_lm = self.lm.encode(x)

        if packed:
            # operate on the flat data tensor of the packed batch
            h = self.embed(x.data)
            h_lm = h_lm.data
        else:
            h = self.embed(x)

        # project the LM encoding to the embedding dimension and fuse by addition
        h_lm = self.proj(h_lm)
        h = self.transform(h + h_lm)

        if packed:
            h = PackedSequence(h, x.batch_sizes)

        return h


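# Sketch of the interface LMEmbed assumes of `lm` (method names are taken from
# the calls above; the concrete model and sizes are hypothetical):
#
#   class PretrainedLM(nn.Module):
#       def hidden_size(self):
#           """Dimension of the vectors returned by encode()."""
#       def encode(self, x):
#           """Map a (possibly packed) batch of token indices to hidden states."""
#
#   lm = PretrainedLM(...)                       # hypothetical pretrained model
#   embed = LMEmbed(nin=22, nout=100, lm=lm)     # sizes are assumptions
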
class Linear(nn.Module):
    """Position-wise linear embedding of token indices, optionally fused with a
    pretrained language model via LMEmbed."""

    def __init__(self, nin, nhidden, nout, padding_idx=-1,
                 sparse=False, lm=None):
        super(Linear, self).__init__()

        if padding_idx == -1:
            padding_idx = nin - 1

        if lm is not None:
            self.embed = LMEmbed(nin, nhidden, lm, padding_idx=padding_idx, sparse=sparse)
            self.proj = nn.Linear(self.embed.nout, nout)
            self.lm = True
        else:
            # without a language model, a single embedding table maps tokens
            # directly to the output dimension
            self.embed = nn.Embedding(nin, nout, padding_idx=padding_idx, sparse=sparse)
            self.lm = False

        self.nout = nout

    def forward(self, x):
        if self.lm:
            h = self.embed(x)
            if type(h) is PackedSequence:
                h = h.data
                z = self.proj(h)
                z = PackedSequence(z, x.batch_sizes)
            else:
                # flatten (batch, length, dim) for the projection, then restore
                h = h.view(-1, h.size(2))
                z = self.proj(h)
                z = z.view(x.size(0), x.size(1), -1)
        else:
            if type(x) is PackedSequence:
                z = self.embed(x.data)
                z = PackedSequence(z, x.batch_sizes)
            else:
                z = self.embed(x)

        return z


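# A minimal usage sketch for Linear on a packed batch (the alphabet size of 21
# tokens plus one padding index, and the shapes, are assumptions):
#
#   from torch.nn.utils.rnn import pack_padded_sequence
#   model = Linear(nin=22, nhidden=100, nout=100)
#   x = torch.randint(0, 21, (4, 50))                        # (batch, length)
#   packed = pack_padded_sequence(x, [50, 48, 30, 10], batch_first=True)
#   z = model(packed)                                        # PackedSequence of 100-d vectors
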
class StackedRNN(nn.Module):
    """Bidirectional multi-layer RNN (LSTM or GRU) over embedded sequences, with a
    linear projection of the concatenated forward/backward states to `nout`."""

    def __init__(self, nin, nembed, nunits, nout, nlayers=2, padding_idx=-1, dropout=0,
                 rnn_type='lstm', sparse=False, lm=None):
        super(StackedRNN, self).__init__()

        if padding_idx == -1:
            padding_idx = nin - 1

        if lm is not None:
            self.embed = LMEmbed(nin, nembed, lm, padding_idx=padding_idx, sparse=sparse)
            nembed = self.embed.nout
            self.lm = True
        else:
            self.embed = nn.Embedding(nin, nembed, padding_idx=padding_idx, sparse=sparse)
            self.lm = False

        if rnn_type == 'lstm':
            RNN = nn.LSTM
        elif rnn_type == 'gru':
            RNN = nn.GRU
        else:
            raise ValueError("rnn_type must be 'lstm' or 'gru', got {!r}".format(rnn_type))

        self.dropout = nn.Dropout(p=dropout)
        if nlayers == 1:
            # inter-layer dropout is meaningless (and warned about) for a single layer
            dropout = 0

        self.rnn = RNN(nembed, nunits, nlayers, batch_first=True,
                       bidirectional=True, dropout=dropout)
        self.proj = nn.Linear(2*nunits, nout)
        self.nout = nout

    def forward(self, x):
        if self.lm:
            h = self.embed(x)
        else:
            if type(x) is PackedSequence:
                h = self.embed(x.data)
                h = PackedSequence(h, x.batch_sizes)
            else:
                h = self.embed(x)

        h, _ = self.rnn(h)

        if type(h) is PackedSequence:
            h = h.data
            h = self.dropout(h)
            z = self.proj(h)
            z = PackedSequence(z, x.batch_sizes)
        else:
            # flatten (batch, length, 2*nunits) for the projection, then restore
            h = h.view(-1, h.size(2))
            h = self.dropout(h)
            z = self.proj(h)
            z = z.view(x.size(0), x.size(1), -1)

        return z

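
# A minimal smoke-test sketch; the sizes below are illustrative assumptions
# chosen only to exercise StackedRNN's forward pass.
if __name__ == '__main__':
    model = StackedRNN(nin=22, nembed=64, nunits=128, nout=100, nlayers=2)
    x = torch.randint(0, 21, (4, 50))   # (batch, length) of token indices
    z = model(x)
    print(z.size())                     # expected: torch.Size([4, 50, 100])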