from typing import Tuple, List

import torch


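# LSTM cell with recurrent dropout: per-gate dropout masks for the input and
# the previous hidden state are sampled once per batch (see set_dropout_masks)
# and reused at every timestep, rather than being resampled at each step.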
class RecurrentDropoutLSTMCell(torch.jit.ScriptModule):
    __constants__ = ['hidden_size']

    def __init__(self, input_size, hidden_size, dropout=0.):
        super(RecurrentDropoutLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout

        # Input-to-hidden (W_*) and hidden-to-hidden (U_*) weights for the
        # input, forget, cell and output gates.
        self.W_i = torch.nn.Parameter(torch.empty(hidden_size, input_size))
        self.U_i = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))

        self.W_f = torch.nn.Parameter(torch.empty(hidden_size, input_size))
        self.U_f = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))

        self.W_c = torch.nn.Parameter(torch.empty(hidden_size, input_size))
        self.U_c = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))

        self.W_o = torch.nn.Parameter(torch.empty(hidden_size, input_size))
        self.U_o = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))

        # Gate biases laid out as [input, forget, cell, output] blocks,
        # mirroring torch.nn.LSTM's bias_ih/bias_hh convention.
        self.bias_ih = torch.nn.Parameter(torch.empty(4 * hidden_size))
        self.bias_hh = torch.nn.Parameter(torch.empty(4 * hidden_size))

        # Dropout masks are TorchScript attributes so the scripted forward()
        # can read them; they are (re)populated by set_dropout_masks().
        self._input_dropout_mask = torch.jit.Attribute(torch.empty((), requires_grad=False), torch.Tensor)
        self._h_dropout_mask = torch.jit.Attribute(torch.empty((), requires_grad=False), torch.Tensor)

        # Register the hooks on torch.nn.Module directly (bypassing
        # ScriptModule) so the transient dropout masks are stripped from saved
        # state dicts and tolerated when loading.
        super(torch.jit.ScriptModule, self)._register_state_dict_hook(self._hook_remove_dropout_masks_from_state_dict)
        super(torch.jit.ScriptModule, self)._register_load_state_dict_pre_hook(self._hook_add_dropout_masks_to_state_dict)

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.orthogonal_(self.W_i)
        torch.nn.init.orthogonal_(self.U_i)
        torch.nn.init.orthogonal_(self.W_f)
        torch.nn.init.orthogonal_(self.U_f)
        torch.nn.init.orthogonal_(self.W_c)
        torch.nn.init.orthogonal_(self.U_c)
        torch.nn.init.orthogonal_(self.W_o)
        torch.nn.init.orthogonal_(self.U_o)

        self.bias_ih.data.fill_(0.)
        # Initialize the forget-gate bias block to 1 so the cell starts out
        # remembering by default.
        self.bias_ih.data[self.hidden_size:2 * self.hidden_size].fill_(1.)
        self.bias_hh.data.fill_(0.)

    def set_dropout_masks(self, batch_size):
        def constant_mask(v):
            # Broadcastable (4, batch_size, 1) mask filled with a constant.
            return torch.tensor(v).reshape(1, 1, 1).expand(4, batch_size, -1).to(self.W_i.device)

        if self.dropout:
            if self.training:
                # Sample one Bernoulli keep-mask per gate for the inputs and
                # the recurrent state; the same masks are reused at every
                # timestep of the sequence.
                new_tensor = self.W_i.data.new
                self._input_dropout_mask = torch.bernoulli(
                    new_tensor(4, batch_size, self.input_size).fill_(1 - self.dropout))
                self._h_dropout_mask = torch.bernoulli(
                    new_tensor(4, batch_size, self.hidden_size).fill_(1 - self.dropout))
            else:
                # At evaluation time, scale by the keep probability instead of
                # sampling (the training masks are not rescaled).
                mask = constant_mask(1 - self.dropout)
                self._input_dropout_mask = mask
                self._h_dropout_mask = mask
        else:
            mask = constant_mask(1.)
            self._input_dropout_mask = mask
            self._h_dropout_mask = mask

    @classmethod
    def _hook_remove_dropout_masks_from_state_dict(cls, instance, state_dict, prefix, local_metadata):
        # Drop the transient dropout masks so they are never saved.
        del state_dict[prefix + '_input_dropout_mask']
        del state_dict[prefix + '_h_dropout_mask']

    def _hook_add_dropout_masks_to_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Re-insert the current masks before loading so load_state_dict does
        # not report them as missing keys.
        state_dict[prefix + '_input_dropout_mask'] = self._input_dropout_mask
        state_dict[prefix + '_h_dropout_mask'] = self._h_dropout_mask

    @torch.jit.script_method
    def forward(
            self,
            input: torch.Tensor,
            hidden_state: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        h_tm1, c_tm1 = hidden_state

        # Gate pre-activations from the (dropped-out) input; masks are sliced
        # to the current batch size.
        xi_t = torch.nn.functional.linear(input * self._input_dropout_mask[0, :input.shape[0]], self.W_i)
        xf_t = torch.nn.functional.linear(input * self._input_dropout_mask[1, :input.shape[0]], self.W_f)
        xc_t = torch.nn.functional.linear(input * self._input_dropout_mask[2, :input.shape[0]], self.W_c)
        xo_t = torch.nn.functional.linear(input * self._input_dropout_mask[3, :input.shape[0]], self.W_o)

        # Gate pre-activations from the (dropped-out) previous hidden state.
        hi_t = torch.nn.functional.linear(h_tm1 * self._h_dropout_mask[0, :input.shape[0]], self.U_i)
        hf_t = torch.nn.functional.linear(h_tm1 * self._h_dropout_mask[1, :input.shape[0]], self.U_f)
        hc_t = torch.nn.functional.linear(h_tm1 * self._h_dropout_mask[2, :input.shape[0]], self.U_c)
        ho_t = torch.nn.functional.linear(h_tm1 * self._h_dropout_mask[3, :input.shape[0]], self.U_o)

        # Standard LSTM gate equations, with the biases sliced into their
        # [input, forget, cell, output] blocks.
        i_t = torch.sigmoid(xi_t + self.bias_ih[:self.hidden_size] + hi_t + self.bias_hh[:self.hidden_size])
        f_t = torch.sigmoid(xf_t + self.bias_ih[self.hidden_size:2 * self.hidden_size] + hf_t + self.bias_hh[self.hidden_size:2 * self.hidden_size])
        c_t = f_t * c_tm1 + i_t * torch.tanh(xc_t + self.bias_ih[2 * self.hidden_size:3 * self.hidden_size] + hc_t + self.bias_hh[2 * self.hidden_size:3 * self.hidden_size])
        o_t = torch.sigmoid(xo_t + self.bias_ih[3 * self.hidden_size:4 * self.hidden_size] + ho_t + self.bias_hh[3 * self.hidden_size:4 * self.hidden_size])
        h_t = o_t * torch.tanh(c_t)

        return h_t, c_t


class LSTM(torch.jit.ScriptModule):
    def __init__(self, input_size, hidden_size, bidirectional=False, dropout=0., cell_factory=RecurrentDropoutLSTMCell):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.cell_factory = cell_factory
        num_directions = 2 if bidirectional else 1
        self.lstm_cells = []

        # One cell per direction, registered as 'cell' and 'cell_reverse' so
        # the script methods below can refer to them by name.
        for direction in range(num_directions):
            cell = cell_factory(input_size, hidden_size, dropout=dropout)
            self.lstm_cells.append(cell)

            suffix = '_reverse' if direction == 1 else ''
            cell_name = 'cell{}'.format(suffix)
            self.add_module(cell_name, cell)

    def forward(self, input, hidden_state=None):
        is_packed = isinstance(input, torch.nn.utils.rnn.PackedSequence)
        if not is_packed:
            raise NotImplementedError

        # batch_sizes is a tensor; take a plain int for sizing below.
        max_batch_size = int(input.batch_sizes[0])
        # Sample fresh dropout masks once per batch; they stay fixed for the
        # whole sequence.
        for cell in self.lstm_cells:
            cell.set_dropout_masks(max_batch_size)

        if hidden_state is None:
            num_directions = 2 if self.bidirectional else 1
            hx = input.data.new_zeros(num_directions,
                                      max_batch_size, self.hidden_size,
                                      requires_grad=False)
            hidden_state = (hx, hx)

        forward_hidden_state = tuple(v[0] for v in hidden_state)
        if self.bidirectional:
            reverse_hidden_state = tuple(v[1] for v in hidden_state)

            forward_output, (forward_h, forward_c) = self._forward_packed(input.data, input.batch_sizes, forward_hidden_state)
            reverse_output, (reverse_h, reverse_c) = self._reverse_packed(input.data, input.batch_sizes, reverse_hidden_state)
            return (torch.nn.utils.rnn.PackedSequence(
                        torch.cat((forward_output, reverse_output), dim=-1),
                        input.batch_sizes,
                        input.sorted_indices,
                        input.unsorted_indices),
                    (torch.stack((forward_h, reverse_h), dim=0),
                     torch.stack((forward_c, reverse_c), dim=0)))

        output, next_hidden = self._forward_packed(input.data, input.batch_sizes, forward_hidden_state)
        return (torch.nn.utils.rnn.PackedSequence(
                    output,
                    input.batch_sizes,
                    input.sorted_indices,
                    input.unsorted_indices),
                next_hidden)

    @torch.jit.script_method
    def _forward_packed(self, input: torch.Tensor, batch_sizes: torch.Tensor, hidden_state: Tuple[torch.Tensor, torch.Tensor]):
        # Walk the packed sequence timestep by timestep. batch_sizes is
        # non-increasing, so whenever the batch shrinks, the trailing rows of
        # (h, c) belong to sequences that just ended; stash their final states
        # and continue with the smaller batch.
        step_outputs = torch.jit.annotate(List[torch.Tensor], [])
        hs = torch.jit.annotate(List[torch.Tensor], [])
        cs = torch.jit.annotate(List[torch.Tensor], [])
        input_offset = torch.zeros((), dtype=torch.int64)
        num_steps = batch_sizes.shape[0]
        last_batch_size = batch_sizes[0]

        h, c = hidden_state
        for i in range(num_steps):
            batch_size = batch_sizes[i]
            step_input = input.narrow(0, input_offset, batch_size)
            input_offset += batch_size
            dec = last_batch_size - batch_size
            if dec > 0:
                # These sequences have ended: save their final states.
                hs.append(h[last_batch_size - dec:last_batch_size])
                cs.append(c[last_batch_size - dec:last_batch_size])
                h = h[:last_batch_size - dec]
                c = c[:last_batch_size - dec]
            last_batch_size = batch_size
            h, c = self.cell(step_input, (h, c))
            step_outputs.append(h)

        hs.append(h)
        cs.append(c)
        # Final states were collected shortest-sequence-first; reverse so the
        # concatenation is in batch order (index 0 first).
        hs.reverse()
        cs.reverse()

        concat_h = torch.cat(hs)
        concat_c = torch.cat(cs)

        return (torch.cat(step_outputs, dim=0), (concat_h, concat_c))

    @torch.jit.script_method
    def _reverse_packed(self, input: torch.Tensor, batch_sizes: torch.Tensor, hidden_state: Tuple[torch.Tensor, torch.Tensor]):
        # Same walk as _forward_packed, but from the last timestep to the
        # first. input_offset becomes increasingly negative so that narrow()
        # indexes the packed data from its end; whenever the batch grows,
        # rows of the initial hidden state are appended for the sequences
        # that begin (in reverse time) at this step.
        step_outputs = torch.jit.annotate(List[torch.Tensor], [])
        input_offset = torch.zeros((), dtype=torch.int64)
        num_steps = batch_sizes.shape[0]
        last_batch_size = batch_sizes[-1]

        h, c = hidden_state
        input_h, input_c = hidden_state
        h = h[:batch_sizes[-1]]
        c = c[:batch_sizes[-1]]

        i = num_steps - 1
        while i > -1:
            batch_size = batch_sizes[i]
            inc = batch_size - last_batch_size
            if inc > 0:
                h = torch.cat((h, input_h[last_batch_size:batch_size]))
                c = torch.cat((c, input_c[last_batch_size:batch_size]))
            step_input = input.narrow(0, input_offset - batch_size, batch_size)
            input_offset -= batch_size
            last_batch_size = batch_size
            h, c = self.cell_reverse(step_input, (h, c))
            step_outputs.append(h)
            i -= 1

        # Outputs were collected last-timestep-first; reverse to match the
        # packed layout before concatenating.
        step_outputs.reverse()
        return (torch.cat(step_outputs, dim=0), (h, c))
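

if __name__ == '__main__':
    # Minimal usage sketch with illustrative toy sizes; it assumes a PyTorch
    # build where the legacy torch.jit.ScriptModule / script_method API still
    # compiles the methods above. It packs a small sorted batch, runs the
    # bidirectional LSTM, and unpacks the result.
    torch.manual_seed(0)

    lstm = LSTM(input_size=8, hidden_size=16, bidirectional=True, dropout=0.3)
    lstm.train()

    # Three sequences of lengths 5, 3 and 2, sorted longest-first as
    # pack_padded_sequence expects by default.
    lengths = [5, 3, 2]
    padded = torch.randn(5, len(lengths), 8)  # (max_len, batch, input_size)
    packed = torch.nn.utils.rnn.pack_padded_sequence(padded, lengths)

    packed_output, (h_n, c_n) = lstm(packed)
    output, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_output)
    print(output.shape)        # (5, 3, 32): hidden_size * 2 directions
    print(h_n.shape, c_n.shape)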