#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Functions for putting examples into torch format."""
from collections import Counter
import torch
def vectorize(ex, model, single_answer=False):
"""Torchify a single example."""
args = model.args
word_dict = model.word_dict
feature_dict = model.feature_dict
# Index words
document = torch.LongTensor([word_dict[w] for w in ex['document']])
question = torch.LongTensor([word_dict[w] for w in ex['question']])
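    # (word_dict is assumed to map out-of-vocabulary tokens to a shared UNK
    # index, so these lookups do not raise on words unseen at training time.)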

    # Create extra features vector
    if len(feature_dict) > 0:
        features = torch.zeros(len(ex['document']), len(feature_dict))
    else:
        features = None
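    # features has shape (doc_len, num_features): one row per document token,
    # one column per hand-crafted feature filled in below.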

    # f_{exact_match}
    if args.use_in_question:
        q_words_cased = set(ex['question'])
        q_words_uncased = {w.lower() for w in ex['question']}
        q_lemma = set(ex['qlemma']) if args.use_lemma else None
        for i in range(len(ex['document'])):
            if ex['document'][i] in q_words_cased:
                features[i][feature_dict['in_question']] = 1.0
            if ex['document'][i].lower() in q_words_uncased:
                features[i][feature_dict['in_question_uncased']] = 1.0
            if q_lemma and ex['lemma'][i] in q_lemma:
                features[i][feature_dict['in_question_lemma']] = 1.0
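    # Together these set three binary indicators per document token: exact
    # match to a question word, case-insensitive match, and lemma match.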

    # f_{token} (POS)
    if args.use_pos:
        for i, w in enumerate(ex['pos']):
            f = 'pos=%s' % w
            if f in feature_dict:
                features[i][feature_dict[f]] = 1.0

    # f_{token} (NER)
    if args.use_ner:
        for i, w in enumerate(ex['ner']):
            f = 'ner=%s' % w
            if f in feature_dict:
                features[i][feature_dict[f]] = 1.0
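    # POS/NER features are one-hot indicators keyed by tag string ('pos=NN',
    # 'ner=PERSON', ...); tags absent from feature_dict are silently skipped.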

    # f_{token} (TF)
    if args.use_tf:
        counter = Counter([w.lower() for w in ex['document']])
        doc_len = len(ex['document'])
        for i, w in enumerate(ex['document']):
            features[i][feature_dict['tf']] = counter[w.lower()] / doc_len
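    # i.e. tf(w) = count(w in document) / len(document), case-insensitive.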

    # Maybe return without target
    if 'answers' not in ex:
        return document, features, question, ex['id']

    # ...or with target(s) (might still be empty if answers is empty)
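    # With single_answer=True the first annotated span becomes a one-element
    # tensor target (the usual training setup); otherwise all valid spans are
    # kept as plain lists, e.g. for multi-reference evaluation.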
    if single_answer:
        assert len(ex['answers']) > 0
        start = torch.LongTensor(1).fill_(ex['answers'][0][0])
        end = torch.LongTensor(1).fill_(ex['answers'][0][1])
    else:
        start = [a[0] for a in ex['answers']]
        end = [a[1] for a in ex['answers']]

    return document, features, question, start, end, ex['id']


def batchify(batch):
    """Gather a batch of individual examples into one batch."""
    NUM_INPUTS = 3
    NUM_TARGETS = 2
    NUM_EXTRA = 1

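    # Each element of batch is a tuple produced by vectorize():
    # (document, features, question[, start, end], id).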
    ids = [ex[-1] for ex in batch]
    docs = [ex[0] for ex in batch]
    features = [ex[1] for ex in batch]
    questions = [ex[2] for ex in batch]

    # Batch documents and features
    max_length = max([d.size(0) for d in docs])
    x1 = torch.LongTensor(len(docs), max_length).zero_()
    x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1)
    if features[0] is None:
        x1_f = None
    else:
        x1_f = torch.zeros(len(docs), max_length, features[0].size(1))
    for i, d in enumerate(docs):
        x1[i, :d.size(0)].copy_(d)
        x1_mask[i, :d.size(0)].fill_(0)
        if x1_f is not None:
            x1_f[i, :d.size(0)].copy_(features[i])
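    # Mask convention: 1 marks padding, 0 marks a real token. (Newer PyTorch
    # releases prefer torch.bool masks over ByteTensor for this purpose.)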

    # Batch questions
    max_length = max([q.size(0) for q in questions])
    x2 = torch.LongTensor(len(questions), max_length).zero_()
    x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1)
    for i, q in enumerate(questions):
        x2[i, :q.size(0)].copy_(q)
        x2_mask[i, :q.size(0)].fill_(0)

    # Maybe return without targets
    if len(batch[0]) == NUM_INPUTS + NUM_EXTRA:
        return x1, x1_f, x1_mask, x2, x2_mask, ids

    # ...or add targets
    elif len(batch[0]) == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS:
        if torch.is_tensor(batch[0][3]):
            y_s = torch.cat([ex[3] for ex in batch])
            y_e = torch.cat([ex[4] for ex in batch])
        else:
            y_s = [ex[3] for ex in batch]
            y_e = [ex[4] for ex in batch]
    else:
        raise RuntimeError('Incorrect number of inputs per example.')

    return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids
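

# Usage sketch (illustrative only; `ReaderDataset` and `examples` are assumed
# names, not part of this module): batchify is shaped to serve as the
# collate_fn of a torch.utils.data.DataLoader over a dataset whose
# __getitem__ calls vectorize:
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(
#         ReaderDataset(examples, model, single_answer=True),
#         batch_size=32,
#         shuffle=True,
#         collate_fn=batchify,
#     )
#     for x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids in loader:
#         ...  # forward/backward pass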