# AVeriTeC-API / drqa / reader / rnn_reader.py
# Provenance: added by zhenyundeng ("add files", commit e62781a).
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Implementation of the RNN based DrQA reader."""
import torch
import torch.nn as nn
from . import layers
# ------------------------------------------------------------------------------
# Network
# ------------------------------------------------------------------------------
class RnnDocReader(nn.Module):
    """RNN-based document reader for extractive QA (DrQA architecture).

    The network embeds document and question tokens, encodes each with a
    stacked bidirectional RNN, merges the question hidden states into a
    single vector (uniform average or learned self-attention), and scores
    every document position as a candidate span start / end via two
    bilinear attentions over the document hiddens.
    """

    # Recurrent cell classes selectable through ``args.rnn_type``.
    RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}

    def __init__(self, args, normalize=True):
        """Build all sub-modules from the hyperparameter namespace.

        Args:
            args: configuration namespace. Attributes read here:
                vocab_size, embedding_dim, num_features, use_qemb,
                hidden_size, doc_layers, question_layers, dropout_rnn,
                dropout_rnn_output, concat_rnn_layers, rnn_type,
                rnn_padding, question_merge (and, in forward(),
                dropout_emb).
            normalize: passed through to the span-scoring attentions;
                presumably selects softmax probabilities vs. raw scores
                (see layers.BilinearSeqAttn).

        Raises:
            NotImplementedError: if ``args.question_merge`` is not one of
                'avg' or 'self_attn'.
        """
        super(RnnDocReader, self).__init__()
        # Keep the config around: forward() reads several flags from it.
        self.args = args

        # Word embeddings; index 0 is reserved for padding.
        self.embedding = nn.Embedding(args.vocab_size,
                                      args.embedding_dim,
                                      padding_idx=0)

        # Optional question-aligned embedding of the document: attends
        # document words over question words in embedding space.
        if args.use_qemb:
            self.qemb_match = layers.SeqAttnMatch(args.embedding_dim)

        # Input size to the document RNN: word emb + manual features,
        # plus the aligned question embedding when enabled.
        doc_input_size = args.embedding_dim + args.num_features
        if args.use_qemb:
            doc_input_size += args.embedding_dim

        # Validate the merge mode up front so a misconfiguration fails
        # fast, before any encoder modules are allocated.
        # BUG FIX: the original referenced the nonexistent
        # ``args.merge_mode``, turning the intended NotImplementedError
        # into an AttributeError.
        if args.question_merge not in ('avg', 'self_attn'):
            raise NotImplementedError(
                'question_merge = %s' % args.question_merge)

        # Stacked bidirectional document encoder.
        self.doc_rnn = layers.StackedBRNN(
            input_size=doc_input_size,
            hidden_size=args.hidden_size,
            num_layers=args.doc_layers,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            concat_layers=args.concat_rnn_layers,
            rnn_type=self.RNN_TYPES[args.rnn_type],
            padding=args.rnn_padding,
        )

        # Stacked bidirectional question encoder.
        self.question_rnn = layers.StackedBRNN(
            input_size=args.embedding_dim,
            hidden_size=args.hidden_size,
            num_layers=args.question_layers,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            concat_layers=args.concat_rnn_layers,
            rnn_type=self.RNN_TYPES[args.rnn_type],
            padding=args.rnn_padding,
        )

        # Encoder output sizes: 2x for bidirectionality; when layer
        # outputs are concatenated, multiply by the layer count too.
        doc_hidden_size = 2 * args.hidden_size
        question_hidden_size = 2 * args.hidden_size
        if args.concat_rnn_layers:
            doc_hidden_size *= args.doc_layers
            question_hidden_size *= args.question_layers

        # Learned self-attention for merging question hiddens
        # (the 'avg' mode needs no parameters).
        if args.question_merge == 'self_attn':
            self.self_attn = layers.LinearSeqAttn(question_hidden_size)

        # Bilinear attentions scoring each document position as a span
        # start / end against the merged question vector.
        self.start_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )
        self.end_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )

    def forward(self, x1, x1_f, x1_mask, x2, x2_mask):
        """Score span starts and ends for a batch of (document, question).

        Inputs:
            x1 = document word indices           [batch * len_d]
            x1_f = document word feature vectors [batch * len_d * nfeat]
            x1_mask = document padding mask      [batch * len_d]
            x2 = question word indices           [batch * len_q]
            x2_mask = question padding mask      [batch * len_q]

        Returns:
            (start_scores, end_scores), each [batch * len_d]; normalized
            or raw depending on the ``normalize`` constructor flag.
        """
        # Embed both document and question tokens.
        x1_emb = self.embedding(x1)
        x2_emb = self.embedding(x2)

        # Dropout on embeddings (no-op in eval mode via self.training).
        if self.args.dropout_emb > 0:
            x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb,
                                           training=self.training)
            x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb,
                                           training=self.training)

        # Assemble the document encoder input along the feature dim.
        drnn_input = [x1_emb]
        if self.args.use_qemb:
            # Question-aligned embedding of each document word.
            x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask)
            drnn_input.append(x2_weighted_emb)
        if self.args.num_features > 0:
            # Manual per-token features (e.g. exact match, POS, TF).
            drnn_input.append(x1_f)

        # Encode the document.
        doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask)

        # Encode the question and merge its hiddens into one vector.
        question_hiddens = self.question_rnn(x2_emb, x2_mask)
        if self.args.question_merge == 'avg':
            q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask)
        elif self.args.question_merge == 'self_attn':
            q_merge_weights = self.self_attn(question_hiddens, x2_mask)
        question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights)

        # Predict start and end position scores over the document.
        start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask)
        end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask)
        return start_scores, end_scores