from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence

class LastHundredEmbed(nn.Module):
    """Embedding that keeps only the last 100 features along the final dimension."""

    def forward(self, x):
        return x[:, :, -100:]

class IdentityEmbed(nn.Module):
    """Pass-through embedding that returns its input unchanged."""

    def forward(self, x):
        return x

class FullyConnectedEmbed(nn.Module):
    """Single fully connected layer with activation and dropout, projecting
    `nin`-dimensional input features to an `nout`-dimensional embedding."""

    def __init__(self, nin, nout, dropout=0.5, activation=nn.ReLU()):
        super(FullyConnectedEmbed, self).__init__()
        self.nin = nin
        self.nout = nout
        self.dropout_p = dropout

        self.transform = nn.Linear(nin, nout)
        self.drop = nn.Dropout(p=self.dropout_p)
        self.activation = activation

    def forward(self, x):
        t = self.transform(x)
        t = self.activation(t)
        t = self.drop(t)
        return t
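
# Minimal usage sketch for FullyConnectedEmbed (illustrative only; the feature
# width 6165 and the batch size below are assumptions, not values fixed by this file):
#
#   embed = FullyConnectedEmbed(6165, 100, dropout=0.5)
#   feats = torch.randn(8, 6165)   # (batch, nin) pre-computed features
#   emb = embed(feats)             # (batch, nout) == (8, 100)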

class LMEmbed(nn.Module):
    """Combines a learned token embedding with the hidden states of a pretrained
    language model: both are projected to `nout` dimensions, summed, and passed
    through `transform`."""

    def __init__(self, nin, nout, lm, padding_idx=-1, transform=nn.ReLU(),
                 sparse=False):
        super(LMEmbed, self).__init__()

        if padding_idx == -1:
            padding_idx = nin - 1

        self.lm = lm
        self.embed = nn.Embedding(nin, nout, padding_idx=padding_idx, sparse=sparse)
        self.proj = nn.Linear(lm.hidden_size(), nout)
        self.transform = transform
        self.nout = nout

    def forward(self, x):
        packed = type(x) is PackedSequence
        h_lm = self.lm.encode(x)

        # embed and unpack if packed
        if packed:
            h = self.embed(x.data)
            h_lm = h_lm.data
        else:
            h = self.embed(x)

        # project the language model hidden states and combine with the embedding
        h_lm = self.proj(h_lm)
        h = self.transform(h + h_lm)

        # repack if needed
        if packed:
            h = PackedSequence(h, x.batch_sizes)

        return h
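
# LMEmbed assumes `lm` exposes the interface used above: `hidden_size()` returning
# the width of its hidden states, and `encode(x)` accepting the same (optionally
# packed) index tensor as `forward` and returning per-position hidden states of
# that width, packed if the input was packed. A hypothetical stub satisfying that
# contract (illustrative only, not part of the original code):
#
#   class DummyLM(nn.Module):
#       def hidden_size(self):
#           return 512
#       def encode(self, x):
#           data = x.data if type(x) is PackedSequence else x
#           h = torch.randn(data.size() + (512,))
#           return PackedSequence(h, x.batch_sizes) if type(x) is PackedSequence else h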

class Linear(nn.Module):
    """Linear embedding of integer-encoded sequences. With a language model, tokens
    are embedded via LMEmbed and linearly projected to `nout`; otherwise a plain
    embedding lookup maps tokens directly to `nout` dimensions."""

    def __init__(self, nin, nhidden, nout, padding_idx=-1,
                 sparse=False, lm=None):
        super(Linear, self).__init__()

        if padding_idx == -1:
            padding_idx = nin - 1

        if lm is not None:
            self.embed = LMEmbed(nin, nhidden, lm, padding_idx=padding_idx, sparse=sparse)
            self.proj = nn.Linear(self.embed.nout, nout)
            self.lm = True
        else:
            # forward() looks up self.embed in this branch, so the embedding must be
            # assigned to self.embed here (not self.proj)
            self.embed = nn.Embedding(nin, nout, padding_idx=padding_idx, sparse=sparse)
            self.lm = False

        self.nout = nout

    def forward(self, x):
        if self.lm:
            h = self.embed(x)
            if type(h) is PackedSequence:
                h = h.data
                z = self.proj(h)
                z = PackedSequence(z, x.batch_sizes)
            else:
                h = h.view(-1, h.size(2))
                z = self.proj(h)
                z = z.view(x.size(0), x.size(1), -1)
        else:
            if type(x) is PackedSequence:
                z = self.embed(x.data)
                z = PackedSequence(z, x.batch_sizes)
            else:
                z = self.embed(x)
        return z
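
# Minimal usage sketch for Linear without a language model (illustrative only; the
# alphabet size 22, with index 21 reserved for padding, and the shapes below are
# assumptions):
#
#   proj = Linear(22, 512, 100)               # nhidden is only used when lm is given
#   tokens = torch.randint(0, 21, (4, 50))    # (batch, length) integer-encoded sequences
#   z = proj(tokens)                          # (batch, length, nout) == (4, 50, 100)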

class StackedRNN(nn.Module):
    """Stacked bidirectional LSTM/GRU encoder over embedded sequences, followed by a
    linear projection of the concatenated forward/backward states to `nout`."""

    def __init__(self, nin, nembed, nunits, nout, nlayers=2, padding_idx=-1, dropout=0,
                 rnn_type='lstm', sparse=False, lm=None):
        super(StackedRNN, self).__init__()

        if padding_idx == -1:
            padding_idx = nin - 1

        if lm is not None:
            self.embed = LMEmbed(nin, nembed, lm, padding_idx=padding_idx, sparse=sparse)
            nembed = self.embed.nout
            self.lm = True
        else:
            self.embed = nn.Embedding(nin, nembed, padding_idx=padding_idx, sparse=sparse)
            self.lm = False

        if rnn_type == 'lstm':
            RNN = nn.LSTM
        elif rnn_type == 'gru':
            RNN = nn.GRU
        else:
            raise ValueError("rnn_type must be 'lstm' or 'gru', got: " + repr(rnn_type))

        self.dropout = nn.Dropout(p=dropout)
        if nlayers == 1:
            # inter-layer dropout is undefined for a single RNN layer
            dropout = 0

        self.rnn = RNN(nembed, nunits, nlayers, batch_first=True,
                       bidirectional=True, dropout=dropout)
        self.proj = nn.Linear(2*nunits, nout)
        self.nout = nout

    def forward(self, x):
        if self.lm:
            h = self.embed(x)
        else:
            if type(x) is PackedSequence:
                h = self.embed(x.data)
                h = PackedSequence(h, x.batch_sizes)
            else:
                h = self.embed(x)

        h, _ = self.rnn(h)

        if type(h) is PackedSequence:
            h = h.data
            h = self.dropout(h)
            z = self.proj(h)
            z = PackedSequence(z, x.batch_sizes)
        else:
            # the batch_first RNN output is not guaranteed to be contiguous, so make
            # it contiguous before flattening for the projection
            h = h.contiguous().view(-1, h.size(2))
            h = self.dropout(h)
            z = self.proj(h)
            z = z.view(x.size(0), x.size(1), -1)
        return z
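

if __name__ == '__main__':
    # Hedged smoke test, not part of the original module: exercise StackedRNN without
    # a pretrained language model on randomly generated integer-encoded sequences.
    # The alphabet size (22, with index 21 as padding) and the shapes below are
    # assumptions chosen only to hit both the padded and the packed code paths.
    from torch.nn.utils.rnn import pack_padded_sequence

    model = StackedRNN(nin=22, nembed=64, nunits=128, nout=100, nlayers=2, dropout=0.1)

    # padded path: (batch, length) LongTensor in, (batch, length, nout) out
    x = torch.randint(0, 21, (4, 50))
    z = model(x)
    print(z.size())  # torch.Size([4, 50, 100])

    # packed path: the output comes back as a PackedSequence over the same batch sizes
    lengths = [50, 40, 30, 20]
    packed = pack_padded_sequence(x, lengths, batch_first=True)
    zp = model(packed)
    print(type(zp).__name__, zp.data.size())  # PackedSequence torch.Size([140, 100])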