#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Implementation of the RNN based DrQA reader."""

import torch
import torch.nn as nn

from . import layers


# ------------------------------------------------------------------------------
# Network
# ------------------------------------------------------------------------------
class RnnDocReader(nn.Module):
    RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}

    def __init__(self, args, normalize=True):
        super(RnnDocReader, self).__init__()
        # Store config
        self.args = args

        # Word embeddings (+1 for padding)
        self.embedding = nn.Embedding(args.vocab_size,
                                      args.embedding_dim,
                                      padding_idx=0)
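        # Note: the embedding maps index tensors of shape [batch, len] to float
        # tensors of shape [batch, len, embedding_dim]; index 0 is reserved for
        # padding and its embedding stays at zero.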
        # Projection for attention weighted question
        if args.use_qemb:
            self.qemb_match = layers.SeqAttnMatch(args.embedding_dim)

        # Input size to RNN: word emb + question emb + manual features
        doc_input_size = args.embedding_dim + args.num_features
        if args.use_qemb:
            doc_input_size += args.embedding_dim
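        # Example (illustrative values, not defaults): with embedding_dim=300,
        # num_features=4 and use_qemb=True, doc_input_size = 300 + 4 + 300 = 604.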
        # RNN document encoder
        self.doc_rnn = layers.StackedBRNN(
            input_size=doc_input_size,
            hidden_size=args.hidden_size,
            num_layers=args.doc_layers,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            concat_layers=args.concat_rnn_layers,
            rnn_type=self.RNN_TYPES[args.rnn_type],
            padding=args.rnn_padding,
        )

        # RNN question encoder
        self.question_rnn = layers.StackedBRNN(
            input_size=args.embedding_dim,
            hidden_size=args.hidden_size,
            num_layers=args.question_layers,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            concat_layers=args.concat_rnn_layers,
            rnn_type=self.RNN_TYPES[args.rnn_type],
            padding=args.rnn_padding,
        )

        # Output sizes of rnn encoders
        doc_hidden_size = 2 * args.hidden_size
        question_hidden_size = 2 * args.hidden_size
        if args.concat_rnn_layers:
            doc_hidden_size *= args.doc_layers
            question_hidden_size *= args.question_layers
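        # Example (illustrative values): with hidden_size=128 each bidirectional
        # layer outputs 2 * 128 = 256 features; with doc_layers=3 and
        # concat_rnn_layers=True, doc_hidden_size = 3 * 256 = 768.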
        # Question merging
        if args.question_merge not in ['avg', 'self_attn']:
            raise NotImplementedError('question_merge = %s' % args.question_merge)
        if args.question_merge == 'self_attn':
            self.self_attn = layers.LinearSeqAttn(question_hidden_size)

        # Bilinear attention for span start/end
        self.start_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )
        self.end_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )

    def forward(self, x1, x1_f, x1_mask, x2, x2_mask):
        """Inputs:
        x1 = document word indices             [batch * len_d]
        x1_f = document word features indices  [batch * len_d * nfeat]
        x1_mask = document padding mask        [batch * len_d]
        x2 = question word indices             [batch * len_q]
        x2_mask = question padding mask        [batch * len_q]
        """
        # Embed both document and question
        x1_emb = self.embedding(x1)
        x2_emb = self.embedding(x2)

        # Dropout on embeddings
        if self.args.dropout_emb > 0:
            x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb,
                                           training=self.training)
            x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb,
                                           training=self.training)

        # Form document encoding inputs
        drnn_input = [x1_emb]

        # Add attention-weighted question representation
        if self.args.use_qemb:
            x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask)
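            # x2_weighted_emb: one attention-weighted average of question word
            # embeddings per document token, shape [batch, len_d, embedding_dim]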
            drnn_input.append(x2_weighted_emb)

        # Add manual features
        if self.args.num_features > 0:
            drnn_input.append(x1_f)

        # Encode document with RNN
        doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask)
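        # doc_hiddens: one vector per document token, [batch, len_d, doc_hidden_size]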
        # Encode question with RNN + merge hiddens
        question_hiddens = self.question_rnn(x2_emb, x2_mask)
        if self.args.question_merge == 'avg':
            q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask)
        elif self.args.question_merge == 'self_attn':
            q_merge_weights = self.self_attn(question_hiddens, x2_mask)
        question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights)
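        # question_hidden collapses the question to a single vector per example,
        # shape [batch, question_hidden_size]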
        # Predict start and end positions
        start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask)
        end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask)
        return start_scores, end_scores
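

# ------------------------------------------------------------------------------
# Usage sketch: a minimal smoke test illustrating the expected input and output
# shapes. The config values below are illustrative assumptions, not project
# defaults, and the mask dtype may need to match what the `layers` module
# expects (bool vs. ByteTensor). Because this file uses a relative import, run
# it with `python -m` from the package root rather than as a standalone script.
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        vocab_size=1000, embedding_dim=300, num_features=4, use_qemb=True,
        hidden_size=128, doc_layers=3, question_layers=3, rnn_type='lstm',
        dropout_rnn=0.4, dropout_rnn_output=True, concat_rnn_layers=True,
        rnn_padding=False, question_merge='self_attn', dropout_emb=0.4,
    )
    model = RnnDocReader(args)

    batch, len_d, len_q = 2, 50, 10
    x1 = torch.randint(1, args.vocab_size, (batch, len_d))    # document tokens
    x1_f = torch.rand(batch, len_d, args.num_features)        # manual features
    x1_mask = torch.zeros(batch, len_d, dtype=torch.bool)     # False = real token
    x2 = torch.randint(1, args.vocab_size, (batch, len_q))    # question tokens
    x2_mask = torch.zeros(batch, len_q, dtype=torch.bool)

    start_scores, end_scores = model(x1, x1_f, x1_mask, x2, x2_mask)
    print(start_scores.size(), end_scores.size())  # both [batch, len_d]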