Spaces:
Build error
Build error
File size: 5,168 Bytes
e62781a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Implementation of the RNN based DrQA reader."""
import torch
import torch.nn as nn
from . import layers
# ------------------------------------------------------------------------------
# Network
# ------------------------------------------------------------------------------
class RnnDocReader(nn.Module):
    """RNN based document reader that scores answer span start/end positions.

    The document and the question are embedded with a shared embedding
    table, encoded with separate ``layers.StackedBRNN`` encoders, the
    question encoding is merged into a single vector, and two
    ``layers.BilinearSeqAttn`` modules score every document position as a
    span start and as a span end.
    """

    # Maps the ``args.rnn_type`` string to the torch recurrent module class
    # handed to ``layers.StackedBRNN``.
    RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}

    def __init__(self, args, normalize=True):
        """Build the embedding table, encoders and attention layers.

        Args:
            args: configuration object. The attributes read here are
                ``vocab_size``, ``embedding_dim``, ``use_qemb``,
                ``num_features``, ``hidden_size``, ``doc_layers``,
                ``question_layers``, ``dropout_rnn``,
                ``dropout_rnn_output``, ``concat_rnn_layers``,
                ``rnn_type``, ``rnn_padding`` and ``question_merge``
                (``forward`` additionally reads ``dropout_emb``).
            normalize: forwarded unchanged to both
                ``layers.BilinearSeqAttn`` modules.

        Raises:
            NotImplementedError: if ``args.question_merge`` is not
                ``'avg'`` or ``'self_attn'``.
            KeyError: if ``args.rnn_type`` is not a key of ``RNN_TYPES``.
        """
        super().__init__()

        # Validate the merge mode before building any layer so that a bad
        # configuration fails fast. The message reports the attribute that
        # was actually checked; the 1-tuple keeps %-formatting safe for
        # any value type.
        if args.question_merge not in ['avg', 'self_attn']:
            raise NotImplementedError(
                'merge_mode = %s' % (args.question_merge,)
            )

        # Store config
        self.args = args

        # Word embeddings; index 0 is the padding index.
        self.embedding = nn.Embedding(args.vocab_size,
                                      args.embedding_dim,
                                      padding_idx=0)

        # Projection for attention weighted question
        if args.use_qemb:
            self.qemb_match = layers.SeqAttnMatch(args.embedding_dim)

        # Input size to RNN: word emb + question emb + manual features
        doc_input_size = args.embedding_dim + args.num_features
        if args.use_qemb:
            doc_input_size += args.embedding_dim

        # RNN document encoder
        self.doc_rnn = layers.StackedBRNN(
            input_size=doc_input_size,
            hidden_size=args.hidden_size,
            num_layers=args.doc_layers,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            concat_layers=args.concat_rnn_layers,
            rnn_type=self.RNN_TYPES[args.rnn_type],
            padding=args.rnn_padding,
        )

        # RNN question encoder
        self.question_rnn = layers.StackedBRNN(
            input_size=args.embedding_dim,
            hidden_size=args.hidden_size,
            num_layers=args.question_layers,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            concat_layers=args.concat_rnn_layers,
            rnn_type=self.RNN_TYPES[args.rnn_type],
            padding=args.rnn_padding,
        )

        # Output sizes of rnn encoders: twice the hidden size, multiplied by
        # the layer count when the per-layer outputs are concatenated.
        doc_hidden_size = 2 * args.hidden_size
        question_hidden_size = 2 * args.hidden_size
        if args.concat_rnn_layers:
            doc_hidden_size *= args.doc_layers
            question_hidden_size *= args.question_layers

        # Question merging: only the self-attention mode needs a module.
        if args.question_merge == 'self_attn':
            self.self_attn = layers.LinearSeqAttn(question_hidden_size)

        # Bilinear attention for span start/end
        self.start_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )
        self.end_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )

    def forward(self, x1, x1_f, x1_mask, x2, x2_mask):
        """Score every document position as an answer span start and end.

        Inputs:
        x1 = document word indices             [batch * len_d]
        x1_f = document word features indices  [batch * len_d * nfeat]
        x1_mask = document padding mask        [batch * len_d]
        x2 = question word indices             [batch * len_q]
        x2_mask = question padding mask        [batch * len_q]

        Returns:
            A ``(start_scores, end_scores)`` tuple: the outputs of the
            start and end ``layers.BilinearSeqAttn`` modules (presumably
            one score per document position -- confirm against ``layers``).

        Raises:
            NotImplementedError: if ``self.args.question_merge`` is not a
                supported mode (only possible if the config was changed
                after construction).
        """
        # Embed both document and question
        x1_emb = self.embedding(x1)
        x2_emb = self.embedding(x2)

        # Dropout on embeddings
        if self.args.dropout_emb > 0:
            x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb,
                                           training=self.training)
            x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb,
                                           training=self.training)

        # Form document encoding inputs
        drnn_input = [x1_emb]

        # Add attention-weighted question representation
        if self.args.use_qemb:
            x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask)
            drnn_input.append(x2_weighted_emb)

        # Add manual features
        if self.args.num_features > 0:
            drnn_input.append(x1_f)

        # Encode document with RNN; inputs are joined along the last
        # (feature) dimension.
        doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask)

        # Encode question with RNN + merge hiddens
        question_hiddens = self.question_rnn(x2_emb, x2_mask)
        if self.args.question_merge == 'avg':
            q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask)
        elif self.args.question_merge == 'self_attn':
            q_merge_weights = self.self_attn(question_hiddens, x2_mask)
        else:
            # Without this branch an unsupported mode would surface as an
            # unbound-local error on q_merge_weights below.
            raise NotImplementedError(
                'merge_mode = %s' % (self.args.question_merge,)
            )
        question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights)

        # Predict start and end positions
        start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask)
        end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask)
        return start_scores, end_scores
|