from typing import Tuple, List

import torch


class RecurrentDropoutLSTMCell(torch.jit.ScriptModule):
    __constants__ = ['hidden_size']

    def __init__(self, input_size, hidden_size, dropout=0.):
        super(RecurrentDropoutLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout

        # One input-to-hidden (W_*) and one hidden-to-hidden (U_*) weight matrix
        # per gate: input (i), forget (f), cell candidate (c), and output (o).
        self.W_i = torch.nn.Parameter(torch.empty(hidden_size, input_size))
        self.U_i = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.W_f = torch.nn.Parameter(torch.empty(hidden_size, input_size))
        self.U_f = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.W_c = torch.nn.Parameter(torch.empty(hidden_size, input_size))
        self.U_c = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.W_o = torch.nn.Parameter(torch.empty(hidden_size, input_size))
        self.U_o = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))

        # Biases are stored nn.LSTM-style, as four hidden_size chunks in the
        # order [i, f, c, o].
        self.bias_ih = torch.nn.Parameter(torch.empty(4 * hidden_size))
        self.bias_hh = torch.nn.Parameter(torch.empty(4 * hidden_size))

        self._input_dropout_mask = torch.jit.Attribute(torch.empty((), requires_grad=False), torch.Tensor)
        self._h_dropout_mask = torch.jit.Attribute(torch.empty((), requires_grad=False), torch.Tensor)

        # The call to super() is needed because torch.jit.ScriptModule deletes the
        # _register_state_dict_hook and _register_load_state_dict_pre_hook methods.
        # TODO: In Torch 1.3, discontinue use of torch.jit.Attribute so that
        # the dropout masks don't end up in the state dict in the first place.
        super(torch.jit.ScriptModule, self)._register_state_dict_hook(
            self._hook_remove_dropout_masks_from_state_dict)
        super(torch.jit.ScriptModule, self)._register_load_state_dict_pre_hook(
            self._hook_add_dropout_masks_to_state_dict)

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.orthogonal_(self.W_i)
        torch.nn.init.orthogonal_(self.U_i)
        torch.nn.init.orthogonal_(self.W_f)
        torch.nn.init.orthogonal_(self.U_f)
        torch.nn.init.orthogonal_(self.W_c)
        torch.nn.init.orthogonal_(self.U_c)
        torch.nn.init.orthogonal_(self.W_o)
        torch.nn.init.orthogonal_(self.U_o)

        self.bias_ih.data.fill_(0.)
        # The forget gate bias is set to 1.
        self.bias_ih.data[self.hidden_size:2 * self.hidden_size].fill_(1.)
        self.bias_hh.data.fill_(0.)

    # TODO: Should the dropout masks be stored in the hidden state instead?
    def set_dropout_masks(self, batch_size):
        # Recurrent ("variational") dropout: one mask is sampled per sequence
        # in the batch and reused at every timestep, rather than resampled
        # at each step.
        def constant_mask(v):
            return torch.tensor(v).reshape(1, 1, 1).expand(4, batch_size, -1).to(self.W_i.device)

        if self.dropout:
            if self.training:
                new_tensor = self.W_i.data.new
                self._input_dropout_mask = torch.bernoulli(
                    new_tensor(4, batch_size, self.input_size).fill_(1 - self.dropout))
                self._h_dropout_mask = torch.bernoulli(
                    new_tensor(4, batch_size, self.hidden_size).fill_(1 - self.dropout))
            else:
                # At evaluation time, scale by the keep probability instead
                # (non-inverted dropout).
                mask = constant_mask(1 - self.dropout)
                self._input_dropout_mask = mask
                self._h_dropout_mask = mask
        else:
            mask = constant_mask(1.)
            self._input_dropout_mask = mask
            self._h_dropout_mask = mask

    @classmethod
    def _hook_remove_dropout_masks_from_state_dict(cls, instance, state_dict, prefix, local_metadata):
        del state_dict[prefix + '_input_dropout_mask']
        del state_dict[prefix + '_h_dropout_mask']

    def _hook_add_dropout_masks_to_state_dict(self, state_dict, prefix, local_metadata, strict,
                                              missing_keys, unexpected_keys, error_msgs):
        state_dict[prefix + '_input_dropout_mask'] = self._input_dropout_mask
        state_dict[prefix + '_h_dropout_mask'] = self._h_dropout_mask

    @torch.jit.script_method
    def forward(
            self,
            input: torch.Tensor,
            hidden_state: Tuple[torch.Tensor, torch.Tensor]):
        h_tm1, c_tm1 = hidden_state

        # Mask index 0/1/2/3 selects the per-gate dropout mask; masks are
        # sliced to the current batch size, which may shrink over the steps
        # of a packed batch.
        xi_t = torch.nn.functional.linear(input * self._input_dropout_mask[0, :input.shape[0]], self.W_i)
        xf_t = torch.nn.functional.linear(input * self._input_dropout_mask[1, :input.shape[0]], self.W_f)
        xc_t = torch.nn.functional.linear(input * self._input_dropout_mask[2, :input.shape[0]], self.W_c)
        xo_t = torch.nn.functional.linear(input * self._input_dropout_mask[3, :input.shape[0]], self.W_o)
        hi_t = torch.nn.functional.linear(h_tm1 * self._h_dropout_mask[0, :input.shape[0]], self.U_i)
        hf_t = torch.nn.functional.linear(h_tm1 * self._h_dropout_mask[1, :input.shape[0]], self.U_f)
        hc_t = torch.nn.functional.linear(h_tm1 * self._h_dropout_mask[2, :input.shape[0]], self.U_c)
        ho_t = torch.nn.functional.linear(h_tm1 * self._h_dropout_mask[3, :input.shape[0]], self.U_o)

        # Standard LSTM gate equations; the bias vectors are sliced per gate.
        i_t = torch.sigmoid(xi_t + self.bias_ih[:self.hidden_size] + hi_t + self.bias_hh[:self.hidden_size])
        f_t = torch.sigmoid(xf_t + self.bias_ih[self.hidden_size:2 * self.hidden_size] + hf_t + self.bias_hh[self.hidden_size:2 * self.hidden_size])
        c_t = f_t * c_tm1 + i_t * torch.tanh(xc_t + self.bias_ih[2 * self.hidden_size:3 * self.hidden_size] + hc_t + self.bias_hh[2 * self.hidden_size:3 * self.hidden_size])
        o_t = torch.sigmoid(xo_t + self.bias_ih[3 * self.hidden_size:4 * self.hidden_size] + ho_t + self.bias_hh[3 * self.hidden_size:4 * self.hidden_size])
        h_t = o_t * torch.tanh(c_t)
        return h_t, c_t
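
# Usage sketch (not part of the original file; sizes are illustrative
# assumptions): stepping the cell manually over a fixed-size batch.
# set_dropout_masks() must be called before the first forward pass so that
# the per-sequence masks exist and match the batch size.
#
#   cell = RecurrentDropoutLSTMCell(input_size=8, hidden_size=16, dropout=0.3)
#   cell.train()
#   cell.set_dropout_masks(batch_size=4)
#   x = torch.randn(4, 8)
#   h = torch.zeros(4, 16)
#   c = torch.zeros(4, 16)
#   for _ in range(10):  # the same dropout masks are reused at every step
#       h, c = cell(x, (h, c))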


class LSTM(torch.jit.ScriptModule):
    def __init__(self, input_size, hidden_size, bidirectional=False, dropout=0.,
                 cell_factory=RecurrentDropoutLSTMCell):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.cell_factory = cell_factory

        num_directions = 2 if bidirectional else 1
        self.lstm_cells = []
        for direction in range(num_directions):
            cell = cell_factory(input_size, hidden_size, dropout=dropout)
            self.lstm_cells.append(cell)
            suffix = '_reverse' if direction == 1 else ''
            cell_name = 'cell{}'.format(suffix)
            self.add_module(cell_name, cell)

    def forward(self, input, hidden_state=None):
        # TODO: Support multiple layers
        # TODO: Support batch_first
        is_packed = isinstance(input, torch.nn.utils.rnn.PackedSequence)
        if not is_packed:
            raise NotImplementedError

        max_batch_size = input.batch_sizes[0]
        for cell in self.lstm_cells:
            cell.set_dropout_masks(max_batch_size)

        if hidden_state is None:
            num_directions = 2 if self.bidirectional else 1
            hx = input.data.new_zeros(num_directions,
                                      max_batch_size, self.hidden_size,
                                      requires_grad=False)
            hidden_state = (hx, hx)

        forward_hidden_state = tuple(v[0] for v in hidden_state)
        if self.bidirectional:
            reverse_hidden_state = tuple(v[1] for v in hidden_state)
            forward_output, (forward_h, forward_c) = self._forward_packed(
                input.data, input.batch_sizes, forward_hidden_state)
            reverse_output, (reverse_h, reverse_c) = self._reverse_packed(
                input.data, input.batch_sizes, reverse_hidden_state)
            return (torch.nn.utils.rnn.PackedSequence(
                        torch.cat((forward_output, reverse_output), dim=-1),
                        input.batch_sizes,
                        input.sorted_indices,
                        input.unsorted_indices),
                    (torch.stack((forward_h, reverse_h), dim=0),
                     torch.stack((forward_c, reverse_c), dim=0)))

        output, next_hidden = self._forward_packed(
            input.data, input.batch_sizes, forward_hidden_state)
        return (torch.nn.utils.rnn.PackedSequence(
                    output,
                    input.batch_sizes,
                    input.sorted_indices,
                    input.unsorted_indices),
                next_hidden)

    @torch.jit.script_method
    def _forward_packed(self, input: torch.Tensor, batch_sizes: torch.Tensor,
                        hidden_state: Tuple[torch.Tensor, torch.Tensor]):
        # Derived from
        # https://github.com/pytorch/pytorch/blob/6a4ca9abec1c18184635881c08628737c8ed2497/aten/src/ATen/native/RNN.cpp#L589
        step_outputs = []
        hs = []
        cs = []
        input_offset = torch.zeros((), dtype=torch.int64)  # scalar zero
        num_steps = batch_sizes.shape[0]
        last_batch_size = batch_sizes[0]

        # batch_sizes is a sequence of decreasing lengths, which are offsets
        # into a 1D list of inputs. At every step we slice out batch_size elements,
        # and possibly account for the decrease in the batch size since the last step,
        # which requires us to slice the hidden state (since some sequences
        # are completed now). The sliced parts are also saved, because we will need
        # to return a tensor of final hidden states.
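        # For example (illustrative values, not from this file):
        # batch_sizes = [3, 2, 2, 1] encodes three sequences of lengths
        # 4, 3, and 1 flattened into 3 + 2 + 2 + 1 = 8 input rows. Step 0
        # consumes rows 0:3; at step 1 the batch shrinks to 2, so the final
        # hidden state of the shortest sequence is sliced off and saved.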
        h, c = hidden_state
        for i in range(num_steps):
            batch_size = batch_sizes[i]
            step_input = input.narrow(0, input_offset, batch_size)
            input_offset += batch_size
            dec = last_batch_size - batch_size
            if dec > 0:
                # These sequences just ended; save their final states.
                hs.append(h[last_batch_size - dec:last_batch_size])
                cs.append(c[last_batch_size - dec:last_batch_size])
                h = h[:last_batch_size - dec]
                c = c[:last_batch_size - dec]
            last_batch_size = batch_size
            h, c = self.cell(step_input, (h, c))
            step_outputs.append(h)

        hs.append(h)
        cs.append(c)
        # Final states were collected shortest-sequence-first; reverse so the
        # rows line up with the original (length-sorted) batch order.
        hs.reverse()
        cs.reverse()
        concat_h = torch.cat(hs)
        concat_c = torch.cat(cs)
        return (torch.cat(step_outputs, dim=0), (concat_h, concat_c))

    @torch.jit.script_method
    def _reverse_packed(self, input: torch.Tensor, batch_sizes: torch.Tensor,
                        hidden_state: Tuple[torch.Tensor, torch.Tensor]):
        # Derived from
        # https://github.com/pytorch/pytorch/blob/6a4ca9abec1c18184635881c08628737c8ed2497/aten/src/ATen/native/RNN.cpp#L650
        step_outputs = []
        input_offset = torch.zeros((), dtype=torch.int64)  # scalar zero
        num_steps = batch_sizes.shape[0]
        last_batch_size = batch_sizes[-1]

        # Here the situation is similar to that above, except we start out with
        # the smallest batch size (and a small set of hidden states we actually use),
        # and progressively expand the hidden states as we move backwards over the
        # 1D list of inputs.
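        # Continuing the example above (illustrative values): with
        # batch_sizes = [3, 2, 2, 1] the backward pass starts at step 3 with
        # one active sequence and grows back to three active sequences at
        # step 0, re-adding rows of the initial hidden state as it goes.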
        h, c = hidden_state
        input_h, input_c = hidden_state
        h = h[:batch_sizes[-1]]
        c = c[:batch_sizes[-1]]

        # for i in range(num_steps - 1, -1, -1): # Not supported in torchscript 1.1, so we do a workaround:
        i = num_steps - 1
        while i > -1:
            batch_size = batch_sizes[i]
            inc = batch_size - last_batch_size
            if inc > 0:
                # More sequences are active at this (earlier) step; grow the
                # state with the corresponding rows of the initial hidden state.
                h = torch.cat((h, input_h[last_batch_size:batch_size]))
                c = torch.cat((c, input_c[last_batch_size:batch_size]))
            # input_offset starts at zero and goes negative; narrow() counts a
            # negative start from the end of the dimension, so the first step
            # takes the last batch_size rows.
            step_input = input.narrow(0, input_offset - batch_size, batch_size)
            input_offset -= batch_size
            last_batch_size = batch_size
            h, c = self.cell_reverse(step_input, (h, c))
            step_outputs.append(h)
            i -= 1

        step_outputs.reverse()
        return (torch.cat(step_outputs, dim=0), (h, c))
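
# Usage sketch (not part of the original file; lengths and sizes are
# illustrative assumptions): running the bidirectional LSTM over a packed
# batch. Only PackedSequence input is supported (see the NotImplementedError
# in LSTM.forward above).
#
#   lstm = LSTM(input_size=8, hidden_size=16, bidirectional=True, dropout=0.3)
#   seqs = [torch.randn(5, 8), torch.randn(3, 8)]  # sorted by decreasing length
#   packed = torch.nn.utils.rnn.pack_sequence(seqs)
#   output, (h_n, c_n) = lstm(packed)
#   # output.data has feature size 2 * hidden_size (forward and reverse
#   # outputs concatenated); h_n and c_n have shape (2, batch, hidden_size),
#   # one slice per direction.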