#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Definitions of model layers/NN modules"""

import torch
import torch.nn as nn
import torch.nn.functional as F


# ------------------------------------------------------------------------------
# Modules
# ------------------------------------------------------------------------------


class StackedBRNN(nn.Module):
    """Stacked Bi-directional RNNs.

    Differs from standard PyTorch library in that it has the option to save
    and concat the hidden states between layers. (i.e. the output hidden size
    for each sequence input is num_layers * hidden_size).
    """

    def __init__(self, input_size, hidden_size, num_layers,
                 dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM,
                 concat_layers=False, padding=False):
        super(StackedBRNN, self).__init__()
        self.padding = padding
        self.dropout_output = dropout_output
        self.dropout_rate = dropout_rate
        self.num_layers = num_layers
        self.concat_layers = concat_layers
        self.rnns = nn.ModuleList()
        for i in range(num_layers):
            input_size = input_size if i == 0 else 2 * hidden_size
            self.rnns.append(rnn_type(input_size, hidden_size,
                                      num_layers=1,
                                      bidirectional=True))

    def forward(self, x, x_mask):
        """Encode either padded or non-padded sequences.

        Can choose to either handle or ignore variable length sequences.
        Always handle padding in eval.

        Args:
            x: batch * len * hdim
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            x_encoded: batch * len * hdim_encoded
        """
        if x_mask.data.sum() == 0:
            # No padding necessary.
            output = self._forward_unpadded(x, x_mask)
        elif self.padding or not self.training:
            # Pad if we care or if it's during eval.
            output = self._forward_padded(x, x_mask)
        else:
            # We don't care.
            output = self._forward_unpadded(x, x_mask)

        return output.contiguous()

    def _forward_unpadded(self, x, x_mask):
        """Faster encoding that ignores any padding."""
        # Transpose batch and sequence dims
        x = x.transpose(0, 1)

        # Encode all layers
        outputs = [x]
        for i in range(self.num_layers):
            rnn_input = outputs[-1]

            # Apply dropout to hidden input
            if self.dropout_rate > 0:
                rnn_input = F.dropout(rnn_input,
                                      p=self.dropout_rate,
                                      training=self.training)
            # Forward
            rnn_output = self.rnns[i](rnn_input)[0]
            outputs.append(rnn_output)

        # Concat hidden layers
        if self.concat_layers:
            output = torch.cat(outputs[1:], 2)
        else:
            output = outputs[-1]

        # Transpose back
        output = output.transpose(0, 1)

        # Dropout on output layer
        if self.dropout_output and self.dropout_rate > 0:
            output = F.dropout(output,
                               p=self.dropout_rate,
                               training=self.training)
        return output

    def _forward_padded(self, x, x_mask):
        """Slower (significantly), but more precise, encoding that handles
        padding.
        """
        # Compute sorted sequence lengths
        lengths = x_mask.data.eq(0).long().sum(1).squeeze()
        _, idx_sort = torch.sort(lengths, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)
        lengths = list(lengths[idx_sort])

        # Sort x
        x = x.index_select(0, idx_sort)

        # Transpose batch and sequence dims
        x = x.transpose(0, 1)

        # Pack it up
        rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths)

        # Encode all layers
        outputs = [rnn_input]
        for i in range(self.num_layers):
            rnn_input = outputs[-1]

            # Apply dropout to input
            if self.dropout_rate > 0:
                dropout_input = F.dropout(rnn_input.data,
                                          p=self.dropout_rate,
                                          training=self.training)
                rnn_input = nn.utils.rnn.PackedSequence(dropout_input,
                                                        rnn_input.batch_sizes)
            outputs.append(self.rnns[i](rnn_input)[0])

        # Unpack everything
        for i, o in enumerate(outputs[1:], 1):
            outputs[i] = nn.utils.rnn.pad_packed_sequence(o)[0]

        # Concat hidden layers or take final
        if self.concat_layers:
            output = torch.cat(outputs[1:], 2)
        else:
            output = outputs[-1]

        # Transpose and unsort
        output = output.transpose(0, 1)
        output = output.index_select(0, idx_unsort)

        # Pad up to original batch sequence length
        if output.size(1) != x_mask.size(1):
            padding = torch.zeros(output.size(0),
                                  x_mask.size(1) - output.size(1),
                                  output.size(2)).type(output.data.type())
            output = torch.cat([output, padding], 1)

        # Dropout on output layer
        if self.dropout_output and self.dropout_rate > 0:
            output = F.dropout(output,
                               p=self.dropout_rate,
                               training=self.training)
        return output
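

# Minimal usage sketch for StackedBRNN (illustrative only): the tensor sizes,
# the boolean mask dtype, and this helper's name are assumptions made for the
# example, not values used elsewhere in this file. An all-zero mask means no
# padding, so the faster unpadded path is taken.
def _example_stacked_brnn():
    brnn = StackedBRNN(input_size=64, hidden_size=32, num_layers=2,
                       dropout_rate=0.1, concat_layers=True)
    x = torch.randn(4, 10, 64)                     # batch * len * hdim
    x_mask = torch.zeros(4, 10, dtype=torch.bool)  # all zeros -> no padding
    output = brnn(x, x_mask)                       # batch * len * (num_layers * 2 * hidden_size)
    assert output.size() == (4, 10, 2 * 2 * 32)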


class SeqAttnMatch(nn.Module):
    """Given sequences X and Y, match sequence Y to each element in X.

    * o_i = sum(alpha_j * y_j) for i in X
    * alpha_j = softmax(y_j * x_i)
    """

    def __init__(self, input_size, identity=False):
        super(SeqAttnMatch, self).__init__()
        if not identity:
            self.linear = nn.Linear(input_size, input_size)
        else:
            self.linear = None

    def forward(self, x, y, y_mask):
        """
        Args:
            x: batch * len1 * hdim
            y: batch * len2 * hdim
            y_mask: batch * len2 (1 for padding, 0 for true)
        Output:
            matched_seq: batch * len1 * hdim
        """
        # Project vectors
        if self.linear:
            x_proj = self.linear(x.view(-1, x.size(2))).view(x.size())
            x_proj = F.relu(x_proj)
            y_proj = self.linear(y.view(-1, y.size(2))).view(y.size())
            y_proj = F.relu(y_proj)
        else:
            x_proj = x
            y_proj = y

        # Compute scores
        scores = x_proj.bmm(y_proj.transpose(2, 1))

        # Mask padding
        y_mask = y_mask.unsqueeze(1).expand(scores.size())
        scores.data.masked_fill_(y_mask.data, -float('inf'))

        # Normalize with softmax
        alpha_flat = F.softmax(scores.view(-1, y.size(1)), dim=-1)
        alpha = alpha_flat.view(-1, x.size(1), y.size(1))

        # Take weighted average
        matched_seq = alpha.bmm(y)
        return matched_seq
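

# Minimal usage sketch for SeqAttnMatch (illustrative only): attend over a
# short "question" sequence y for every position of a "document" sequence x.
# The sizes, the boolean mask dtype, and this helper's name are assumptions.
def _example_seq_attn_match():
    attn = SeqAttnMatch(input_size=64)
    x = torch.randn(4, 10, 64)                     # batch * len1 * hdim
    y = torch.randn(4, 7, 64)                      # batch * len2 * hdim
    y_mask = torch.zeros(4, 7, dtype=torch.bool)   # no padding
    matched = attn(x, y, y_mask)                   # batch * len1 * hdim
    assert matched.size() == (4, 10, 64)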


class BilinearSeqAttn(nn.Module):
    """A bilinear attention layer over a sequence X w.r.t y:

    * o_i = softmax(x_i'Wy) for x_i in X.

    Optionally don't normalize output weights.
    """

    def __init__(self, x_size, y_size, identity=False, normalize=True):
        super(BilinearSeqAttn, self).__init__()
        self.normalize = normalize

        # If identity is true, we just use a dot product without transformation.
        if not identity:
            self.linear = nn.Linear(y_size, x_size)
        else:
            self.linear = None

    def forward(self, x, y, x_mask):
        """
        Args:
            x: batch * len * hdim1
            y: batch * hdim2
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            alpha = batch * len
        """
        Wy = self.linear(y) if self.linear is not None else y
        xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2)
        xWy.data.masked_fill_(x_mask.data, -float('inf'))
        if self.normalize:
            if self.training:
                # In training we output log-softmax for NLL
                alpha = F.log_softmax(xWy, dim=-1)
            else:
                # ...Otherwise 0-1 probabilities
                alpha = F.softmax(xWy, dim=-1)
        else:
            alpha = xWy.exp()
        return alpha
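

# Minimal usage sketch for BilinearSeqAttn (illustrative only): score every
# position of x against a single vector y. eval() is called so the output is
# a plain softmax rather than a log-softmax. The sizes, the boolean mask
# dtype, and this helper's name are assumptions.
def _example_bilinear_seq_attn():
    attn = BilinearSeqAttn(x_size=64, y_size=32)
    attn.eval()
    x = torch.randn(4, 10, 64)                     # batch * len * hdim1
    y = torch.randn(4, 32)                         # batch * hdim2
    x_mask = torch.zeros(4, 10, dtype=torch.bool)  # no padding
    alpha = attn(x, y, x_mask)                     # batch * len
    assert alpha.size() == (4, 10)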


class LinearSeqAttn(nn.Module):
    """Self attention over a sequence:

    * o_i = softmax(Wx_i) for x_i in X.
    """

    def __init__(self, input_size):
        super(LinearSeqAttn, self).__init__()
        self.linear = nn.Linear(input_size, 1)

    def forward(self, x, x_mask):
        """
        Args:
            x: batch * len * hdim
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            alpha: batch * len
        """
        x_flat = x.view(-1, x.size(-1))
        scores = self.linear(x_flat).view(x.size(0), x.size(1))
        scores.data.masked_fill_(x_mask.data, -float('inf'))
        alpha = F.softmax(scores, dim=-1)
        return alpha
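

# Minimal usage sketch for LinearSeqAttn (illustrative only): self-attention
# weights over a sequence; each row of alpha sums to 1 over unmasked positions.
# The sizes, the boolean mask dtype, and this helper's name are assumptions.
def _example_linear_seq_attn():
    attn = LinearSeqAttn(input_size=64)
    x = torch.randn(4, 10, 64)                     # batch * len * hdim
    x_mask = torch.zeros(4, 10, dtype=torch.bool)  # no padding
    alpha = attn(x, x_mask)                        # batch * len
    assert alpha.size() == (4, 10)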


# ------------------------------------------------------------------------------
# Functional
# ------------------------------------------------------------------------------


def uniform_weights(x, x_mask):
    """Return uniform weights over non-masked x (a sequence of vectors).

    Args:
        x: batch * len * hdim
        x_mask: batch * len (1 for padding, 0 for true)
    Output:
        alpha: batch * len
    """
    alpha = torch.ones(x.size(0), x.size(1))
    if x.data.is_cuda:
        alpha = alpha.cuda()
    alpha = alpha * x_mask.eq(0).float()
    # Normalize so each row sums to 1; keepdim keeps the shape broadcastable.
    alpha = alpha / alpha.sum(1, keepdim=True)
    return alpha


def weighted_avg(x, weights):
    """Return a weighted average of x (a sequence of vectors).

    Args:
        x: batch * len * hdim
        weights: batch * len, sum(dim = 1) = 1
    Output:
        x_avg: batch * hdim
    """
    return weights.unsqueeze(1).bmm(x).squeeze(1)
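

# Minimal usage sketch combining uniform_weights and weighted_avg
# (illustrative only): average the unmasked positions of x. The sizes, the
# boolean mask dtype, and this helper's name are assumptions.
def _example_uniform_weighted_avg():
    x = torch.randn(4, 10, 64)                     # batch * len * hdim
    x_mask = torch.zeros(4, 10, dtype=torch.bool)  # no padding
    alpha = uniform_weights(x, x_mask)             # batch * len, rows sum to 1
    x_avg = weighted_avg(x, alpha)                 # batch * hdim
    assert x_avg.size() == (4, 64)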