NCTCMumbai's picture
Upload 2583 files
97b6013 verified
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for constructing the vocabulary, converting examples to integer
format, and building the masks required for batch computation.

Author: aneelakantan (Arvind Neelakantan)
"""
from __future__ import print_function
import copy
import numbers
import numpy as np
import wiki_data
def return_index(a):
  """Return the position of the first element of `a` equal to 1.0.

  Returns None when no element equals 1.0.
  """
  for position, value in enumerate(a):
    if value == 1.0:
      return position
  return None
def _add_to_vocab(word, utility):
  """Insert `word` into the vocabulary tables on `utility`, or bump its count."""
  if word in utility.word_ids:
    utility.word_count[word] += 1
  else:
    utility.words.append(word)
    utility.word_count[word] = 1
    # New words get the next id, i.e. the current vocabulary size.
    utility.word_ids[word] = len(utility.word_ids)
    utility.reverse_word_ids[utility.word_ids[word]] = word


def construct_vocab(data, utility, add_word=False):
  """Snapshot each example's columns and grow the vocabulary from its text.

  For every example, deep copies of `number_columns`, `word_columns`,
  `number_column_names` and `word_column_names` are stored as `original_nc`,
  `original_wc`, `original_nc_names` and `original_wc_names`.

  Unless `add_word` is True or the example has `is_bad_example` set, every
  non-numeric token of the question is registered in the vocabulary held by
  `utility`, then those of the word-column names, then those of the
  number-column names. New words get the next id; known words have their
  count incremented.

  Args:
    data: iterable of example objects; they are mutated in place.
    utility: object holding `words`, `word_ids`, `word_count` and
      `reverse_word_ids`.
    add_word: when True, only the column copies are made and the vocabulary
      is left untouched.
  """
  for example in data:
    example.original_nc = copy.deepcopy(example.number_columns)
    example.original_wc = copy.deepcopy(example.word_columns)
    example.original_nc_names = copy.deepcopy(example.number_column_names)
    example.original_wc_names = copy.deepcopy(example.word_column_names)
    if add_word or example.is_bad_example:
      continue
    # Same order as before: question, word-column names, number-column names.
    token_lists = [example.question]
    token_lists.extend(example.word_column_names)
    token_lists.extend(example.number_column_names)
    for tokens in token_lists:
      for word in tokens:
        if not isinstance(word, numbers.Number):
          _add_to_vocab(word, utility)
def word_lookup(word, utility):
  """Return `word` if it is in the vocabulary, otherwise `utility.unk_token`.

  Uses the `in` operator rather than `dict.has_key`, which was removed in
  Python 3.
  """
  return word if word in utility.word_ids else utility.unk_token
def convert_to_int_2d_and_pad(a, utility):
  """Truncate or pad each token list in `a`, then map every token to its id.

  Each inner list is cut to `utility.FLAGS.max_entry_length` tokens, or
  right-padded with `utility.dummy_token` up to that length. Each token is
  then mapped through `utility.word_ids`. Out-of-vocabulary tokens map to the
  id of `utility.unk_token`, the same resolution `word_lookup` performs.

  The inner lists of `a` are NOT modified. The previous implementation
  appended the padding to the caller's lists whenever they were short enough.
  Those lists are often shared with other example fields through shallow
  copies.

  Args:
    a: iterable of token lists.
    utility: provides `FLAGS.max_entry_length`, `dummy_token`, `unk_token`
      and `word_ids`.

  Returns:
    A list of id lists, each exactly `max_entry_length` long.
  """
  max_len = utility.FLAGS.max_entry_length
  word_ids = utility.word_ids
  ans = []
  for entry in a:
    padded = list(entry[:max_len])
    padded.extend([utility.dummy_token] * (max_len - len(padded)))
    ans.append([
        word_ids[word if word in word_ids else utility.unk_token]
        for word in padded
    ])
  return ans
def convert_to_bool_and_pad(a, utility):
  """Turn a 2-D numpy array into nested lists of bools, right-padded with False.

  Each entry becomes False when it is < 1 and True otherwise. Every row is
  then extended with False up to `utility.FLAGS.max_elements` entries.
  """
  rows = []
  for row in a.tolist():
    flags = [not (value < 1) for value in row]
    flags.extend([False] * (utility.FLAGS.max_elements - len(flags)))
    rows.append(flags)
  return rows
# Module-level registry of table keys. complete_wiki_processing sets
# seen_tables[example.table_key] = 1 for every non-bad example it pads.
seen_tables = {}
def partial_match(question, table, number):
  """Mark table cells that share at least one token with the question.

  Args:
    question: list of question tokens.
    table: list of columns; each column is a list of cells.
    number: when truthy, a cell matches if it equals some question token.
      Otherwise a cell (a sequence of words) matches if it contains some
      question token.

  Returns:
    (answer, match): `answer` mirrors the shape of `table`, with 1.0 for
    matching cells and 0 elsewhere. `match` maps the index of every column
    holding a matching cell to 1.0.
  """
  answer = [[0] * len(column) for column in table]
  match = {}
  for col_idx, column in enumerate(table):
    for cell_idx, cell in enumerate(column):
      for token in question:
        hit = (token == cell) if number else (token in cell)
        if hit:
          answer[col_idx][cell_idx] = 1.0
          match[col_idx] = 1.0
  return answer, match
def exact_match(question, table, number):
  """Find table cells that appear verbatim in the question.

  Args:
    question: list of question tokens.
    table: list of columns; each column is a list of cells.
    number: when truthy, a cell matches if it equals a single question token.
      Otherwise a cell is a token list that matches when it occurs as a
      contiguous span of `question`.

  Returns:
    (answer, match, matched_indices):
      - `answer` mirrors `table`, with 1.0 for matched cells and 0 elsewhere.
      - `match` maps the index of every column with a match to 1.0.
      - `matched_indices` lists a (start, length) pair for every span match.
        It is only filled in the non-number case.
  """
  answer = [[0] * len(column) for column in table]
  match = {}
  matched_indices = []
  for col_idx, column in enumerate(table):
    for cell_idx, cell in enumerate(column):
      if number:
        # Numeric cells: compare with each question token individually.
        for token in question:
          if token == cell:
            match[col_idx] = 1.0
            answer[col_idx][cell_idx] = 1.0
        continue
      # Word cells: slide the cell's token list over the question.
      for start in range(len(question)):
        stop = start + len(cell)
        if stop <= len(question) and cell == question[start:stop]:
          match[col_idx] = 1.0
          answer[col_idx][cell_idx] = 1.0
          matched_indices.append((start, len(cell)))
  return answer, match, matched_indices
def partial_column_match(question, table, number):
  """Flag column names that contain at least one question token.

  Args:
    question: list of question tokens.
    table: list of column names, each a sequence of tokens.
    number: unused; kept so the signature matches the other matchers.

  Returns:
    A list with 1.0 for every matching column name and 0 elsewhere.
  """
  answer = [0] * len(table)
  for col_idx, name in enumerate(table):
    for token in question:
      if token in name:
        answer[col_idx] = 1.0
  return answer
def exact_column_match(question, table, number):
  """Flag column names that occur as a contiguous span of the question.

  Args:
    question: list of question tokens.
    table: list of column names, each a list of tokens.
    number: unused; kept so the signature matches the other matchers.

  Returns:
    (answer, matched_indices): `answer[i]` is 1.0 if column name i matched and
    0 otherwise. `matched_indices` lists a (start, length) pair for every match
    found, in scan order.
  """
  answer = [0] * len(table)
  matched_indices = []
  for col_idx, name in enumerate(table):
    for start in range(len(question)):
      stop = start + len(name)
      if stop <= len(question) and name == question[start:stop]:
        answer[col_idx] = 1.0
        matched_indices.append((start, len(name)))
  return answer, matched_indices
def get_max_entry(a):
  """Return the most frequent entry of `a`, provided it occurs more than once.

  Entries equal to the string "UNK, " are ignored. That string is what
  list_join produces for the one-token cell ["UNK"]. Ties go to the entry
  counted first; `max` returns the first maximal item, which matches the
  stable sort used previously.

  Returns:
    The most frequent entry, or -1.0 when `a` has no countable entries or no
    entry occurs at least twice.
  """
  counts = {}
  for entry in a:
    if entry != "UNK, ":
      # dict.get replaces dict.has_key, which does not exist on Python 3.
      counts[entry] = counts.get(entry, 0) + 1
  if not counts:
    return -1.0
  best, best_count = max(counts.items(), key=lambda item: item[1])
  return best if best_count > 1 else -1.0
def list_join(a):
  """Concatenate `str(w) + ", "` for every element `w` of `a`.

  The trailing ", " separator is kept, so list_join(["x"]) == "x, ".
  """
  return "".join(str(w) + ", " for w in a)
def group_by_max(table, number):
  """Mark, per column, the cells holding that column's most frequent value.

  Args:
    table: list of columns, each a list of cells.
    number: when truthy, cells are compared as they are. Otherwise each cell
      (a token list) is first flattened to a string with list_join.

  Returns:
    Nested list shaped like `table`, with 1.0 where a cell equals the value
    get_max_entry returns for its column and 0.0 elsewhere. Note that
    get_max_entry returns -1.0 when there is no repeated value, so cells equal
    to -1.0 are then marked.
  """
  answer = []
  for column in table:
    values = column if number else [list_join(cell) for cell in column]
    winner = get_max_entry(values)
    answer.append([1.0 if winner == value else 0.0 for value in values])
  return answer
def pick_one(a):
  """Return True if any row of `a` contains the value 1.0, else False."""
  return any(1.0 in row for row in a)
def check_processed_cols(col, utility):
  """Return True if `col` has any value besides the two filler values.

  The filler values are `utility.FLAGS.pad_int` and
  `utility.FLAGS.bad_number_pre_process`.
  """
  kept = [
      y for y in col
      if y != utility.FLAGS.pad_int and
      y != utility.FLAGS.bad_number_pre_process
  ]
  return len(kept) > 0
def complete_wiki_processing(data, utility, train=True):
  """Build integer-encoded, padded features and masks for every good example.

  For each example whose `is_bad_example` flag is false, this function:
    - computes entry-match, group-by-max and column-name-match features from
      the `original_*` copies stored by construct_vocab;
    - extracts up to two numbers or dates from the question;
    - converts the question and column names to ids and pads every per-column
      structure to `utility.FLAGS.max_number_cols` /
      `utility.FLAGS.max_word_cols` columns of `utility.FLAGS.max_elements`
      entries;
    - sets the attributes read by generate_feed_dict.
  Bad examples are only counted in a local that is never returned.

  Args:
    data: iterable of example objects; they are heavily mutated in place.
    utility: provides the vocabulary (`word_ids`), the special tokens and the
      `FLAGS` sizes used for padding.
    train: unused in this function.

  Returns:
    List of the processed good examples whose `len_col` does not exceed
    `utility.FLAGS.max_elements`.
  """
  #convert to integers and padding
  processed_data = []
  num_bad_examples = 0
  for example in data:
    number_found = 0
    if (example.is_bad_example):
      num_bad_examples += 1
    if (not (example.is_bad_example)):
      example.string_question = example.question[:]
      #entry match
      example.processed_number_columns = example.processed_number_columns[:]
      example.processed_word_columns = example.processed_word_columns[:]
      example.word_exact_match, word_match, matched_indices = exact_match(
          example.string_question, example.original_wc, number=False)
      example.number_exact_match, number_match, _ = exact_match(
          example.string_question, example.original_nc, number=True)
      # Fall back to single-token partial matching (word columns only) when
      # neither column type produced an exact entry match. The match dicts
      # are filled exactly when a cell is marked, so both must be empty here.
      if (not (pick_one(example.word_exact_match)) and not (
          pick_one(example.number_exact_match))):
        assert len(word_match) == 0
        assert len(number_match) == 0
        example.word_exact_match, word_match = partial_match(
            example.string_question, example.original_wc, number=False)
      #group by max
      example.word_group_by_max = group_by_max(example.original_wc, False)
      example.number_group_by_max = group_by_max(example.original_nc, True)
      #column name match
      example.word_column_exact_match, wcol_matched_indices = exact_column_match(
          example.string_question, example.original_wc_names, number=False)
      example.number_column_exact_match, ncol_matched_indices = exact_column_match(
          example.string_question, example.original_nc_names, number=False)
      # Same fallback for column names: partial matching when no exact match.
      if (not (1.0 in example.word_column_exact_match) and not (
          1.0 in example.number_column_exact_match)):
        example.word_column_exact_match = partial_column_match(
            example.string_question, example.original_wc_names, number=False)
        example.number_column_exact_match = partial_column_match(
            example.string_question, example.original_nc_names, number=False)
      # Append marker tokens to the question when an entry and/or a column
      # name matched.
      if (len(word_match) > 0 or len(number_match) > 0):
        example.question.append(utility.entry_match_token)
      if (1.0 in example.word_column_exact_match or
          1.0 in example.number_column_exact_match):
        example.question.append(utility.column_match_token)
      example.string_question = example.question[:]
      example.number_lookup_matrix = np.transpose(
          example.number_lookup_matrix)[:]
      example.word_lookup_matrix = np.transpose(example.word_lookup_matrix)[:]
      # Shallow copies: the outer lists are new, but their inner lists are
      # still shared with the source attributes.
      example.columns = example.number_columns[:]
      example.word_columns = example.word_columns[:]
      example.len_total_cols = len(example.word_column_names) + len(
          example.number_column_names)
      example.column_names = example.number_column_names[:]
      example.word_column_names = example.word_column_names[:]
      example.string_column_names = example.number_column_names[:]
      example.string_word_column_names = example.word_column_names[:]
      example.sorted_number_index = []
      example.sorted_word_index = []
      example.column_mask = []
      example.word_column_mask = []
      example.processed_column_mask = []
      example.processed_word_column_mask = []
      example.word_column_entry_mask = []
      example.question_attention_mask = []
      # -1 means "no number found"; checked when building the masks below.
      example.question_number = example.question_number_1 = -1
      example.question_attention_mask = []
      example.ordinal_question = []
      example.ordinal_question_one = []
      new_question = []
      # Number of rows in the table, taken from the first number column, or
      # from the first word column if there are no number columns.
      # NOTE(review): raises IndexError if both column lists are empty.
      if (len(example.number_columns) > 0):
        example.len_col = len(example.number_columns[0])
      else:
        example.len_col = len(example.word_columns[0])
      # Replace question tokens covered by exact word-entry matches with the
      # unk token.
      for (start, length) in matched_indices:
        for j in range(length):
          example.question[start + j] = utility.unk_token
      # Split the question into words (kept in new_question) and numbers or
      # dates (removed). The first number becomes question_number and the
      # second question_number_1; any further numbers are dropped. Each
      # ordinal list flags with 1.0 the word just before its number (or
      # position 0 when the number comes first).
      for word in example.question:
        if (isinstance(word, numbers.Number) or wiki_data.is_date(word)):
          if (not (isinstance(word, numbers.Number)) and
              wiki_data.is_date(word)):
            # Dates: strip "X" and "-" characters (the result is still a str).
            word = word.replace("X", "").replace("-", "")
          number_found += 1
          if (number_found == 1):
            example.question_number = word
            if (len(example.ordinal_question) > 0):
              example.ordinal_question[len(example.ordinal_question) - 1] = 1.0
            else:
              example.ordinal_question.append(1.0)
          elif (number_found == 2):
            example.question_number_1 = word
            if (len(example.ordinal_question_one) > 0):
              example.ordinal_question_one[len(example.ordinal_question_one) -
                                           1] = 1.0
            else:
              example.ordinal_question_one.append(1.0)
        else:
          new_question.append(word)
          example.ordinal_question.append(0.0)
          example.ordinal_question_one.append(0.0)
      example.question = [
          utility.word_ids[word_lookup(w, utility)] for w in new_question
      ]
      example.question_attention_mask = [0.0] * len(example.question)
      #when the first question number occurs before a word
      # (the ordinal lists then have one extra leading entry; trim to length)
      example.ordinal_question = example.ordinal_question[0:len(
          example.question)]
      example.ordinal_question_one = example.ordinal_question_one[0:len(
          example.question)]
      #question-padding
      # Left-pad to question_length: dummy ids for the question, -10000.0 in
      # the attention mask and 0.0 in the ordinal lists.
      example.question = [utility.word_ids[utility.dummy_token]] * (
          utility.FLAGS.question_length - len(example.question)
      ) + example.question
      example.question_attention_mask = [-10000.0] * (
          utility.FLAGS.question_length - len(example.question_attention_mask)
      ) + example.question_attention_mask
      example.ordinal_question = [0.0] * (utility.FLAGS.question_length -
                                          len(example.ordinal_question)
                                         ) + example.ordinal_question
      example.ordinal_question_one = [0.0] * (utility.FLAGS.question_length -
                                              len(example.ordinal_question_one)
                                             ) + example.ordinal_question_one
      if (True):
        #number columns and related-padding
        num_cols = len(example.columns)
        start = 0
        for column in example.number_columns:
          # NOTE(review): processed_column_mask only gets an entry when the
          # column passes check_processed_cols, so it can end up shorter than
          # max_number_cols. Confirm this is intended.
          if (check_processed_cols(example.processed_number_columns[start],
                                   utility)):
            example.processed_column_mask.append(0.0)
          # Row indices ordered by descending processed value, padded.
          sorted_index = sorted(
              range(len(example.processed_number_columns[start])),
              key=lambda k: example.processed_number_columns[start][k],
              reverse=True)
          sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(sorted_index))
          example.sorted_number_index.append(sorted_index)
          example.columns[start] = column + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(column))
          # In-place extend: this inner list is shared with the pre-copy
          # processed_number_columns (shallow copy above).
          example.processed_number_columns[start] += [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements -
              len(example.processed_number_columns[start]))
          start += 1
          example.column_mask.append(0.0)
        # Add fake number columns up to max_number_cols; their masks are
        # -100000000.0.
        for remaining in range(num_cols, utility.FLAGS.max_number_cols):
          example.sorted_number_index.append([utility.FLAGS.pad_int] *
                                             (utility.FLAGS.max_elements))
          example.columns.append([utility.FLAGS.pad_int] *
                                 (utility.FLAGS.max_elements))
          example.processed_number_columns.append([utility.FLAGS.pad_int] *
                                                  (utility.FLAGS.max_elements))
          example.number_exact_match.append([0.0] *
                                            (utility.FLAGS.max_elements))
          example.number_group_by_max.append([0.0] *
                                             (utility.FLAGS.max_elements))
          example.column_mask.append(-100000000.0)
          example.processed_column_mask.append(-100000000.0)
          example.number_column_exact_match.append(0.0)
          example.column_names.append([utility.dummy_token])
        #word column and related-padding
        start = 0
        word_num_cols = len(example.word_columns)
        for column in example.word_columns:
          # NOTE(review): same conditional append as processed_column_mask.
          if (check_processed_cols(example.processed_word_columns[start],
                                   utility)):
            example.processed_word_column_mask.append(0.0)
          sorted_index = sorted(
              range(len(example.processed_word_columns[start])),
              key=lambda k: example.processed_word_columns[start][k],
              reverse=True)
          sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(sorted_index))
          example.sorted_word_index.append(sorted_index)
          column = convert_to_int_2d_and_pad(column, utility)
          # Pad with rows of dummy ids. The padding rows are the same list
          # object repeated (list multiplication), so they must not be mutated.
          example.word_columns[start] = column + [[
              utility.word_ids[utility.dummy_token]
          ] * utility.FLAGS.max_entry_length] * (utility.FLAGS.max_elements -
                                                 len(column))
          example.processed_word_columns[start] += [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements -
              len(example.processed_word_columns[start]))
          # Entry mask: 0 for real rows. Padding rows hold the dummy token id,
          # not a large negative value as in the other masks.
          example.word_column_entry_mask.append([0] * len(column) + [
              utility.word_ids[utility.dummy_token]
          ] * (utility.FLAGS.max_elements - len(column)))
          start += 1
          example.word_column_mask.append(0.0)
        # Add fake word columns up to max_word_cols.
        for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
          example.sorted_word_index.append([utility.FLAGS.pad_int] *
                                           (utility.FLAGS.max_elements))
          example.word_columns.append([[utility.word_ids[utility.dummy_token]] *
                                       utility.FLAGS.max_entry_length] *
                                      (utility.FLAGS.max_elements))
          example.word_column_entry_mask.append(
              [utility.word_ids[utility.dummy_token]] *
              (utility.FLAGS.max_elements))
          example.word_exact_match.append([0.0] * (utility.FLAGS.max_elements))
          example.word_group_by_max.append([0.0] * (utility.FLAGS.max_elements))
          example.processed_word_columns.append([utility.FLAGS.pad_int] *
                                                (utility.FLAGS.max_elements))
          example.word_column_mask.append(-100000000.0)
          example.processed_word_column_mask.append(-100000000.0)
          example.word_column_exact_match.append(0.0)
          example.word_column_names.append([utility.dummy_token] *
                                           utility.FLAGS.max_entry_length)
        seen_tables[example.table_key] = 1
      #convert column and word column names to integers
      example.column_ids = convert_to_int_2d_and_pad(example.column_names,
                                                     utility)
      example.word_column_ids = convert_to_int_2d_and_pad(
          example.word_column_names, utility)
      # Right-pad each real column's match/group-by features to max_elements.
      for i_em in range(len(example.number_exact_match)):
        example.number_exact_match[i_em] = example.number_exact_match[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.number_exact_match[i_em]))
        example.number_group_by_max[i_em] = example.number_group_by_max[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.number_group_by_max[i_em]))
      for i_em in range(len(example.word_exact_match)):
        example.word_exact_match[i_em] = example.word_exact_match[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.word_exact_match[i_em]))
        example.word_group_by_max[i_em] = example.word_group_by_max[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.word_group_by_max[i_em]))
      # Concatenate number columns first, then word columns.
      example.exact_match = example.number_exact_match + example.word_exact_match
      example.group_by_max = example.number_group_by_max + example.word_group_by_max
      example.exact_column_match = example.number_column_exact_match + example.word_column_exact_match
      #answer and related mask, padding
      if (example.is_lookup):
        example.answer = example.calc_answer
        example.number_print_answer = example.number_lookup_matrix.tolist()
        example.word_print_answer = example.word_lookup_matrix.tolist()
        for i_answer in range(len(example.number_print_answer)):
          example.number_print_answer[i_answer] = example.number_print_answer[
              i_answer] + [0.0] * (utility.FLAGS.max_elements -
                                   len(example.number_print_answer[i_answer]))
        for i_answer in range(len(example.word_print_answer)):
          example.word_print_answer[i_answer] = example.word_print_answer[
              i_answer] + [0.0] * (utility.FLAGS.max_elements -
                                   len(example.word_print_answer[i_answer]))
        example.number_lookup_matrix = convert_to_bool_and_pad(
            example.number_lookup_matrix, utility)
        example.word_lookup_matrix = convert_to_bool_and_pad(
            example.word_lookup_matrix, utility)
        for remaining in range(num_cols, utility.FLAGS.max_number_cols):
          example.number_lookup_matrix.append([False] *
                                              utility.FLAGS.max_elements)
          example.number_print_answer.append([0.0] * utility.FLAGS.max_elements)
        for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
          example.word_lookup_matrix.append([False] *
                                            utility.FLAGS.max_elements)
          example.word_print_answer.append([0.0] * utility.FLAGS.max_elements)
        example.print_answer = example.number_print_answer + example.word_print_answer
      else:
        example.answer = example.calc_answer
        # All-zero print answer. The rows alias one list object.
        example.print_answer = [[0.0] * (utility.FLAGS.max_elements)] * (
            utility.FLAGS.max_number_cols + utility.FLAGS.max_word_cols)
      #question_number masks
      if (example.question_number == -1):
        example.question_number_mask = np.zeros([utility.FLAGS.max_elements])
      else:
        example.question_number_mask = np.ones([utility.FLAGS.max_elements])
      if (example.question_number_1 == -1):
        example.question_number_one_mask = -10000.0
      else:
        example.question_number_one_mask = np.float64(0.0)
      # Tables with more rows than max_elements are dropped, although the
      # example has already been mutated and recorded in seen_tables.
      if (example.len_col > utility.FLAGS.max_elements):
        continue
      processed_data.append(example)
  return processed_data
def _register_special_word(utility, token):
  """Append `token` to the vocabulary tables on `utility` and return its id."""
  utility.words.append(token)
  utility.word_ids[token] = len(utility.word_ids)
  utility.reverse_word_ids[utility.word_ids[token]] = token
  return utility.word_ids[token]


def add_special_words(utility):
  """Add the special tokens to the vocabulary and record their ids.

  The tokens are added in this order: the entry-match token, the column-match
  token, the dummy token and the unk token. Their ids are stored on
  `utility` as `entry_match_token_id`, `column_match_token_id` and
  `dummy_token_id`; no id attribute is set for the unk token. The ids of the
  two match tokens are printed.
  """
  utility.entry_match_token_id = _register_special_word(
      utility, utility.entry_match_token)
  print("entry match token: ", utility.word_ids[
      utility.entry_match_token], utility.entry_match_token_id)
  utility.column_match_token_id = _register_special_word(
      utility, utility.column_match_token)
  # The previous version mislabelled this line as "entry match token: ".
  print("column match token: ", utility.word_ids[
      utility.column_match_token], utility.column_match_token_id)
  utility.dummy_token_id = _register_special_word(utility, utility.dummy_token)
  _register_special_word(utility, utility.unk_token)
def perform_word_cutoff(utility):
  """Remove rare words from the vocabulary.

  When `utility.FLAGS.word_cutoff` is positive, every word that has an entry
  in `utility.word_count` below the cutoff is removed from `utility.word_ids`
  and `utility.words`. The unk, dummy, entry-match and column-match tokens are
  never removed. Words without a `word_count` entry are kept.

  NOTE(review): the remaining ids are not renumbered and `reverse_word_ids`
  is not updated, so the ids can become sparse. Confirm this is intended.
  """
  if utility.FLAGS.word_cutoff > 0:
    # Iterate over a snapshot of the keys. Popping from a dict while iterating
    # its live key view raises RuntimeError on Python 3, and dict.has_key no
    # longer exists.
    for word in list(utility.word_ids):
      if (word in utility.word_count and
          utility.word_count[word] < utility.FLAGS.word_cutoff and
          word != utility.unk_token and word != utility.dummy_token and
          word != utility.entry_match_token and
          word != utility.column_match_token):
        utility.word_ids.pop(word)
        utility.words.remove(word)
def word_dropout(question, utility):
  """Randomly replace question token ids with the UNK id.

  Active only when `utility.FLAGS.word_dropout_prob` is positive. Each token
  other than `utility.dummy_token_id` is replaced by the id of
  `utility.unk_token` when `utility.random.random()` exceeds
  `word_dropout_prob`. Assuming `random()` is uniform on [0, 1), the flag
  therefore acts as the probability of keeping a token.

  Returns:
    A new list when active; otherwise `question` itself, unchanged.
  """
  keep_prob = utility.FLAGS.word_dropout_prob
  if not (keep_prob > 0.0):
    return question
  return [
      utility.word_ids[utility.unk_token]
      if (token != utility.dummy_token_id and
          utility.random.random() > keep_prob) else token
      for token in question
  ]
def generate_feed_dict(data, curr, batch_size, gr, train=False, utility=None):
  """Build the feed dictionary for one batch of processed examples.

  Args:
    data: indexable collection of processed examples.
    curr: index in `data` of the first example of the batch.
    batch_size: number of consecutive examples to take.
    gr: object whose `batch_*` attributes serve as the dictionary keys
      (presumably the model's graph placeholders; confirm in the caller).
    train: when True, question ids are passed through word_dropout.
    utility: forwarded to word_dropout; only used when `train` is True.

  Returns:
    dict mapping each `gr.batch_*` key to the batch's values. Scalar
    question-number fields are stacked into (batch_size, 1) numpy arrays.
  """
  batch = [data[curr + offset] for offset in range(batch_size)]

  def per_example(attr):
    # One attribute from every example, in batch order.
    return [getattr(example, attr) for example in batch]

  def as_column(attr):
    # Like per_example, but shaped as a (batch_size, 1) numpy array.
    return np.array(per_example(attr)).reshape((batch_size, 1))

  if train:
    questions = [word_dropout(example.question, utility) for example in batch]
  else:
    questions = per_example("question")

  # (attribute name on gr, value) in the dictionary's insertion order.
  entries = [
      ("batch_question", questions),
      ("batch_question_attention_mask", per_example("question_attention_mask")),
      ("batch_answer", per_example("answer")),
      ("batch_number_column", per_example("columns")),
      ("batch_processed_number_column",
       per_example("processed_number_columns")),
      ("batch_processed_sorted_index_number_column",
       per_example("sorted_number_index")),
      ("batch_processed_sorted_index_word_column",
       per_example("sorted_word_index")),
      ("batch_question_number", as_column("question_number")),
      ("batch_question_number_one", as_column("question_number_1")),
      ("batch_question_number_mask", per_example("question_number_mask")),
      ("batch_question_number_one_mask",
       as_column("question_number_one_mask")),
      ("batch_print_answer", per_example("print_answer")),
      ("batch_exact_match", per_example("exact_match")),
      ("batch_group_by_max", per_example("group_by_max")),
      ("batch_column_exact_match", per_example("exact_column_match")),
      ("batch_ordinal_question", per_example("ordinal_question")),
      ("batch_ordinal_question_one", per_example("ordinal_question_one")),
      ("batch_number_column_mask", per_example("column_mask")),
      ("batch_number_column_names", per_example("column_ids")),
      ("batch_processed_word_column", per_example("processed_word_columns")),
      ("batch_word_column_mask", per_example("word_column_mask")),
      ("batch_word_column_names", per_example("word_column_ids")),
      ("batch_word_column_entry_mask", per_example("word_column_entry_mask")),
  ]
  return {getattr(gr, name): value for name, value in entries}