Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
# Copyright 2017-present, Facebook, Inc. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
"""Functions for putting examples into torch format.""" | |
from collections import Counter | |
import torch | |
def vectorize(ex, model, single_answer=False):
    """Turn one example into torch inputs (and targets, when present).

    Args:
        ex: mapping holding the tokenized example. 'document', 'question'
            and 'id' are always read; 'qlemma'/'lemma', 'pos', 'ner' and
            'answers' are read only by the branches that need them.
        model: object exposing ``args`` (feature switches), ``word_dict``
            (token -> index lookup) and ``feature_dict`` (feature name ->
            column index).
        single_answer: if true, only the first answer span is kept and it
            is returned as a pair of one-element LongTensors; otherwise the
            start/end offsets of every answer are returned as plain lists.

    Returns:
        ``(document, features, question, id)`` when ``ex`` has no 'answers'
        key, else ``(document, features, question, start, end, id)``.
        ``features`` is a ``len(document) x len(feature_dict)`` float tensor,
        or ``None`` when ``feature_dict`` is empty.
    """
    opts = model.args
    vocab = model.word_dict
    feat_index = model.feature_dict
    doc_tokens = ex['document']
    q_tokens = ex['question']

    # Map tokens to vocabulary indices.
    document = torch.LongTensor([vocab[tok] for tok in doc_tokens])
    question = torch.LongTensor([vocab[tok] for tok in q_tokens])

    # One row per document token, one column per known feature.
    features = (
        torch.zeros(len(doc_tokens), len(feat_index))
        if len(feat_index) > 0
        else None
    )

    def mark_tags(tags, template):
        # Set the column named by each tag, skipping tags with no column.
        for row, tag in enumerate(tags):
            name = template % tag
            if name in feat_index:
                features[row][feat_index[name]] = 1.0

    # Exact-match features: cased, uncased and (optionally) lemma overlap
    # between each document token and the question.
    if opts.use_in_question:
        cased = set(q_tokens)
        uncased = {tok.lower() for tok in q_tokens}
        question_lemmas = set(ex['qlemma']) if opts.use_lemma else None
        for row, tok in enumerate(doc_tokens):
            if tok in cased:
                features[row][feat_index['in_question']] = 1.0
            if tok.lower() in uncased:
                features[row][feat_index['in_question_uncased']] = 1.0
            if question_lemmas and ex['lemma'][row] in question_lemmas:
                features[row][feat_index['in_question_lemma']] = 1.0

    # Token features: part-of-speech tags.
    if opts.use_pos:
        mark_tags(ex['pos'], 'pos=%s')

    # Token features: named-entity tags.
    if opts.use_ner:
        mark_tags(ex['ner'], 'ner=%s')

    # Token features: case-insensitive term frequency, normalized by the
    # document length.
    if opts.use_tf:
        lowered = [tok.lower() for tok in doc_tokens]
        counts = Counter(lowered)
        doc_len = len(doc_tokens)
        for row, tok in enumerate(lowered):
            features[row][feat_index['tf']] = counts[tok] * 1.0 / doc_len

    # No targets available: return inputs only.
    if 'answers' not in ex:
        return document, features, question, ex['id']

    # Targets available (the list form may still be empty).
    answers = ex['answers']
    if single_answer:
        assert len(answers) > 0
        first = answers[0]
        start = torch.LongTensor(1).fill_(first[0])
        end = torch.LongTensor(1).fill_(first[1])
    else:
        start = [span[0] for span in answers]
        end = [span[1] for span in answers]

    return document, features, question, start, end, ex['id']
def batchify(batch):
    """Collate a list of vectorized examples into padded batch tensors.

    Each example is a tuple whose first three items are the document
    indices, the document features (or ``None``) and the question indices,
    and whose last item is the example id. Two optional target items
    (start, end) may sit between the question and the id.

    Returns:
        ``(x1, x1_f, x1_mask, x2, x2_mask, ids)`` for examples without
        targets, or ``(x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids)``
        with targets. Sequences are right-padded with zeros; the byte masks
        hold 1 at padded positions and 0 at real tokens. ``x1_f`` is
        ``None`` when the first example carries no features. Targets are
        concatenated into tensors when they are tensors, and otherwise
        returned as a list with one entry per example.

    Raises:
        RuntimeError: if the examples have an unexpected number of fields.
    """
    n_inputs, n_targets, n_extra = 3, 2, 1

    ids = [example[-1] for example in batch]
    docs = [example[0] for example in batch]
    features = [example[1] for example in batch]
    questions = [example[2] for example in batch]

    def pad(sequences):
        # Right-pad index sequences to a common width and build the mask.
        width = max(seq.size(0) for seq in sequences)
        padded = torch.LongTensor(len(sequences), width).zero_()
        mask = torch.ByteTensor(len(sequences), width).fill_(1)
        for row, seq in enumerate(sequences):
            length = seq.size(0)
            padded[row, :length].copy_(seq)
            mask[row, :length].fill_(0)
        return padded, mask

    # Documents, plus their per-token feature rows when present.
    x1, x1_mask = pad(docs)
    if features[0] is None:
        x1_f = None
    else:
        x1_f = torch.zeros(len(docs), x1.size(1), features[0].size(1))
        for row, (doc, feat) in enumerate(zip(docs, features)):
            x1_f[row, :doc.size(0)].copy_(feat)

    # Questions.
    x2, x2_mask = pad(questions)

    # Inputs only: nothing more to collate.
    n_fields = len(batch[0])
    if n_fields == n_inputs + n_extra:
        return x1, x1_f, x1_mask, x2, x2_mask, ids
    if n_fields != n_inputs + n_extra + n_targets:
        raise RuntimeError('Incorrect number of inputs per example.')

    # Inputs plus targets.
    y_s = [example[3] for example in batch]
    y_e = [example[4] for example in batch]
    if torch.is_tensor(batch[0][3]):
        y_s = torch.cat(y_s)
        y_e = torch.cat(y_e)
    return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids