Spaces:
Build error
Build error
File size: 4,422 Bytes
e62781a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Functions for putting examples into torch format."""
from collections import Counter
import torch
def vectorize(ex, model, single_answer=False):
"""Torchify a single example."""
args = model.args
word_dict = model.word_dict
feature_dict = model.feature_dict
# Index words
document = torch.LongTensor([word_dict[w] for w in ex['document']])
question = torch.LongTensor([word_dict[w] for w in ex['question']])
# Create extra features vector
if len(feature_dict) > 0:
features = torch.zeros(len(ex['document']), len(feature_dict))
else:
features = None
# f_{exact_match}
if args.use_in_question:
q_words_cased = {w for w in ex['question']}
q_words_uncased = {w.lower() for w in ex['question']}
q_lemma = {w for w in ex['qlemma']} if args.use_lemma else None
for i in range(len(ex['document'])):
if ex['document'][i] in q_words_cased:
features[i][feature_dict['in_question']] = 1.0
if ex['document'][i].lower() in q_words_uncased:
features[i][feature_dict['in_question_uncased']] = 1.0
if q_lemma and ex['lemma'][i] in q_lemma:
features[i][feature_dict['in_question_lemma']] = 1.0
# f_{token} (POS)
if args.use_pos:
for i, w in enumerate(ex['pos']):
f = 'pos=%s' % w
if f in feature_dict:
features[i][feature_dict[f]] = 1.0
# f_{token} (NER)
if args.use_ner:
for i, w in enumerate(ex['ner']):
f = 'ner=%s' % w
if f in feature_dict:
features[i][feature_dict[f]] = 1.0
# f_{token} (TF)
if args.use_tf:
counter = Counter([w.lower() for w in ex['document']])
l = len(ex['document'])
for i, w in enumerate(ex['document']):
features[i][feature_dict['tf']] = counter[w.lower()] * 1.0 / l
# Maybe return without target
if 'answers' not in ex:
return document, features, question, ex['id']
# ...or with target(s) (might still be empty if answers is empty)
if single_answer:
assert(len(ex['answers']) > 0)
start = torch.LongTensor(1).fill_(ex['answers'][0][0])
end = torch.LongTensor(1).fill_(ex['answers'][0][1])
else:
start = [a[0] for a in ex['answers']]
end = [a[1] for a in ex['answers']]
return document, features, question, start, end, ex['id']
def batchify(batch):
    """Gather a batch of individual examples into one batch.

    Each element of `batch` is a tuple produced by vectorize(): either
    (document, features, question, id) or
    (document, features, question, start, end, id).

    Returns:
        Padded document tensor x1, feature tensor x1_f (or None when the
        examples carry no features), mask x1_mask, padded question tensor
        x2 with mask x2_mask, then optionally targets y_s/y_e, and finally
        the list of example ids. Masks are ByteTensors where 1 marks
        padding positions.

    Raises:
        RuntimeError: if examples have an unexpected tuple length.
    """
    NUM_INPUTS = 3
    NUM_TARGETS = 2
    NUM_EXTRA = 1

    ids = [ex[-1] for ex in batch]
    docs = [ex[0] for ex in batch]
    features = [ex[1] for ex in batch]
    questions = [ex[2] for ex in batch]

    # Batch documents and features: zero-pad token ids to the longest
    # document; unmask (0) only the real token positions.
    max_length = max(d.size(0) for d in docs)
    x1 = torch.LongTensor(len(docs), max_length).zero_()
    x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1)
    if features[0] is None:
        x1_f = None
    else:
        x1_f = torch.zeros(len(docs), max_length, features[0].size(1))
    for i, d in enumerate(docs):
        x1[i, :d.size(0)].copy_(d)
        x1_mask[i, :d.size(0)].fill_(0)
        if x1_f is not None:
            x1_f[i, :d.size(0)].copy_(features[i])

    # Batch questions the same way.
    max_length = max(q.size(0) for q in questions)
    x2 = torch.LongTensor(len(questions), max_length).zero_()
    x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1)
    for i, q in enumerate(questions):
        x2[i, :q.size(0)].copy_(q)
        x2_mask[i, :q.size(0)].fill_(0)

    # Maybe return without targets
    if len(batch[0]) == NUM_INPUTS + NUM_EXTRA:
        return x1, x1_f, x1_mask, x2, x2_mask, ids
    elif len(batch[0]) == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS:
        # ...Otherwise add targets: tensor targets (single-answer training)
        # are concatenated; list targets (multi-answer eval) pass through.
        if torch.is_tensor(batch[0][3]):
            y_s = torch.cat([ex[3] for ex in batch])
            y_e = torch.cat([ex[4] for ex in batch])
        else:
            y_s = [ex[3] for ex in batch]
            y_e = [ex[4] for ex in batch]
    else:
        raise RuntimeError('Incorrect number of inputs per example.')

    return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids
|