|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Functions for constructing vocabulary, converting the examples to integer format and building the required masks for batch computation Author: aneelakantan (Arvind Neelakantan) |
|
""" |
|
|
|
from __future__ import print_function |
|
|
|
import copy |
|
import numbers |
|
import numpy as np |
|
import wiki_data |
|
|
|
|
|
def return_index(a):
  """Return the index of the first 1.0 in `a`, or None if there is none."""
  for idx, value in enumerate(a):
    if value == 1.0:
      return idx
|
|
|
|
|
def construct_vocab(data, utility, add_word=False):
  """Build the word vocabulary in `utility` and snapshot original columns.

  Every example gets deep copies of its columns/column names stored as
  `original_nc`, `original_wc`, `original_nc_names`, `original_wc_names` so
  later processing can refer back to the unpadded data.  Unless `add_word`
  is set, each non-numeric token of every good example's question and column
  names is registered in the vocabulary (id, reverse map, frequency count).

  Args:
    data: iterable of wiki examples.
    utility: object with `words` (list), `word_ids` (dict), `reverse_word_ids`
      (dict) and `word_count` (dict) attributes.
    add_word: if True, only snapshot the originals; do not grow the vocab.
  """

  def _register(word):
    # Add a new token to the vocabulary, or bump its frequency count.
    # (Replaces Python 2-only dict.has_key with `in`.)
    if word not in utility.word_ids:
      utility.words.append(word)
      utility.word_count[word] = 1
      utility.word_ids[word] = len(utility.word_ids)
      utility.reverse_word_ids[utility.word_ids[word]] = word
    else:
      utility.word_count[word] += 1

  for example in data:
    example.original_nc = copy.deepcopy(example.number_columns)
    example.original_wc = copy.deepcopy(example.word_columns)
    example.original_nc_names = copy.deepcopy(example.number_column_names)
    example.original_wc_names = copy.deepcopy(example.word_column_names)
    if add_word:
      continue
    if not example.is_bad_example:
      # Numbers are handled separately downstream, so only string tokens
      # enter the vocabulary.  Registration order: question words, then
      # word-column names, then number-column names.
      for word in example.question:
        if not isinstance(word, numbers.Number):
          _register(word)
      for col_name in example.word_column_names:
        for word in col_name:
          if not isinstance(word, numbers.Number):
            _register(word)
      for col_name in example.number_column_names:
        for word in col_name:
          if not isinstance(word, numbers.Number):
            _register(word)
|
|
|
|
|
def word_lookup(word, utility):
  """Return `word` if it is in the vocabulary, else the UNK token.

  Fixes Python 2-only `dict.has_key` by using the `in` operator.
  """
  return word if word in utility.word_ids else utility.unk_token
|
|
|
|
|
def convert_to_int_2d_and_pad(a, utility):
  """Convert a list of word-list entries into padded rows of word ids.

  Each entry is truncated to FLAGS.max_entry_length, padded with the dummy
  token, and mapped through the vocabulary (unknown words become UNK).

  NOTE(review): when an entry is shorter than max_entry_length, the padding
  mutates the caller's inner list in place — existing callers rely on the
  original's behavior, so this is preserved.
  """
  max_len = utility.FLAGS.max_entry_length
  ids = []
  for entry in a:
    if len(entry) > max_len:
      entry = entry[0:max_len]  # truncation makes a copy; no mutation here
    entry.extend([utility.dummy_token] * (max_len - len(entry)))
    assert len(entry) == max_len
    ids.append([utility.word_ids[word_lookup(w, utility)] for w in entry])
  return ids
|
|
|
|
|
def convert_to_bool_and_pad(a, utility):
  """Threshold a 2-D numpy array into booleans and pad rows with False.

  Cells >= 1 become True, everything else False; every row is right-padded
  with False to FLAGS.max_elements.  Returns a plain list of lists.
  """
  pad_to = utility.FLAGS.max_elements
  out = []
  for row in a.tolist():
    flags = [cell >= 1 for cell in row]
    out.append(flags + [False] * (pad_to - len(flags)))
  return out
|
|
|
|
|
seen_tables = {} |
|
|
|
|
|
def partial_match(question, table, number):
  """Mark table cells that partially match any question token.

  For number tables the test is equality; for word tables it is substring
  containment (`word in cell`).

  Returns:
    (answer, match): `answer` mirrors `table` with 1.0 at matching cells
    (0 elsewhere); `match` maps each column index with any hit to 1.0.
  """
  answer = [[0 for _ in col] for col in table]
  match = {}
  for i, col in enumerate(table):
    for j, cell in enumerate(col):
      for word in question:
        hit = (word == cell) if number else (word in cell)
        if hit:
          answer[i][j] = 1.0
          match[i] = 1.0
  return answer, match
|
|
|
|
|
def exact_match(question, table, number):
  """Mark table cells that exactly match the question.

  Number cells match when equal to any question token.  Word cells (token
  sequences) match when they appear as a contiguous slice of the question;
  each such occurrence is recorded in `matched_indices` as
  (start_position, entry_length).

  Returns:
    (answer, match, matched_indices): `answer` mirrors `table` with 1.0 at
    matching cells, `match` maps matched column indices to 1.0, and
    `matched_indices` lists question spans covered by word matches.
  """
  answer = [[0 for _ in col] for col in table]
  match = {}
  matched_indices = []
  q_len = len(question)
  for i, col in enumerate(table):
    for j, entry in enumerate(col):
      if number:
        if any(word == entry for word in question):
          match[i] = 1.0
          answer[i][j] = 1.0
      else:
        span = len(entry)
        for k in range(q_len):
          if k + span <= q_len and question[k:k + span] == entry:
            match[i] = 1.0
            answer[i][j] = 1.0
            matched_indices.append((k, span))
  return answer, match, matched_indices
|
|
|
|
|
def partial_column_match(question, table, number):
  """Return a per-column 1.0/0 vector: 1.0 where any question token is a
  member of that column's name (token list).  `number` is unused, kept for
  interface symmetry with exact_column_match."""
  answer = [0] * len(table)
  for i, name in enumerate(table):
    if any(word in name for word in question):
      answer[i] = 1.0
  return answer
|
|
|
|
|
def exact_column_match(question, table, number):
  """Return (answer, matched_indices) for whole-column-name matches.

  A column name (token list) matches when it appears as a contiguous slice
  of the question; each occurrence is recorded as (start, name_length).
  `number` is unused, kept for interface symmetry.
  """
  answer = [0] * len(table)
  matched_indices = []
  q_len = len(question)
  for i, name in enumerate(table):
    span = len(name)
    for k in range(q_len):
      if k + span <= q_len and question[k:k + span] == name:
        answer[i] = 1.0
        matched_indices.append((k, span))
  return answer, matched_indices
|
|
|
|
|
def get_max_entry(a):
  """Return the most frequent entry of `a`, ignoring the "UNK, " filler.

  Fixes Python 2-only `dict.has_key`.  Tie-breaking matches the original
  stable sort: the earliest-seen entry among the most frequent wins.

  Returns:
    The winning entry if its count is greater than one, else -1.0.
  """
  counts = {}
  for entry in a:
    if entry != "UNK, ":
      counts[entry] = counts.get(entry, 0) + 1
  if not counts:
    return -1.0
  # max() returns the first maximum in insertion order, same as the
  # original stable sort by descending count.
  best, best_count = max(counts.items(), key=lambda kv: kv[1])
  return best if best_count > 1 else -1.0
|
|
|
|
|
def list_join(a):
  """Join the entries of `a` into one string, each followed by ", ".

  The trailing ", " is intentional — get_max_entry's "UNK, " sentinel and
  existing callers depend on this exact format.
  """
  return "".join(str(item) + ", " for item in a)
|
|
|
|
|
def group_by_max(table, number):
  """Per column, mark entries equal to the column's most frequent value.

  Word entries are first flattened with list_join so they compare as
  strings.  Columns whose best count is <= 1 get all zeros (get_max_entry
  returns -1.0, which matches nothing).

  Returns:
    A list of per-column 1.0/0.0 vectors mirroring `table`.
  """
  answer = []
  for col in table:
    entries = col if number else [list_join(cell) for cell in col]
    best = get_max_entry(entries)
    answer.append([1.0 if best == entry else 0.0 for entry in entries])
  return answer
|
|
|
|
|
def pick_one(a):
  """Return True if any row of `a` contains a 1.0."""
  return any(1.0 in row for row in a)
|
|
|
|
|
def check_processed_cols(col, utility):
  """Return True if `col` holds at least one real value, i.e. one that is
  neither the pad marker nor the bad-preprocess marker."""
  pad = utility.FLAGS.pad_int
  bad = utility.FLAGS.bad_number_pre_process
  return any(y != pad and y != bad for y in col)
|
|
|
|
|
def complete_wiki_processing(data, utility, train=True):
  """Convert examples into fixed-size integer/mask form ready for batching.

  For each non-bad example: computes exact/partial match features against
  the original (unpadded) columns, rewrites the question into padded word
  ids with number/ordinal features, pads number and word columns out to
  FLAGS.max_number_cols / FLAGS.max_word_cols columns of FLAGS.max_elements
  entries each, and builds the answer / lookup-matrix structures.

  Args:
    data: list of wiki examples (after construct_vocab).
    utility: bag of vocab maps, special tokens and FLAGS.
    train: unused in the body; kept for interface compatibility.

  Returns:
    List of processed examples.  Bad examples and examples whose columns
    exceed FLAGS.max_elements are dropped.
  """
  processed_data = []
  num_bad_examples = 0
  for example in data:
    number_found = 0
    if (example.is_bad_example):
      num_bad_examples += 1
    if (not (example.is_bad_example)):
      example.string_question = example.question[:]
      example.processed_number_columns = example.processed_number_columns[:]
      example.processed_word_columns = example.processed_word_columns[:]
      # Exact matches of question tokens/spans against the original columns.
      example.word_exact_match, word_match, matched_indices = exact_match(
          example.string_question, example.original_wc, number=False)
      example.number_exact_match, number_match, _ = exact_match(
          example.string_question, example.original_nc, number=True)
      # Fall back to substring (partial) matching when nothing matched
      # exactly anywhere.
      if (not (pick_one(example.word_exact_match)) and not (
          pick_one(example.number_exact_match))):
        assert len(word_match) == 0
        assert len(number_match) == 0
        example.word_exact_match, word_match = partial_match(
            example.string_question, example.original_wc, number=False)
      # Mark, per column, the entries equal to that column's modal value.
      example.word_group_by_max = group_by_max(example.original_wc, False)
      example.number_group_by_max = group_by_max(example.original_nc, True)
      # Column-name matching, with the same exact-then-partial fallback.
      example.word_column_exact_match, wcol_matched_indices = exact_column_match(
          example.string_question, example.original_wc_names, number=False)
      example.number_column_exact_match, ncol_matched_indices = exact_column_match(
          example.string_question, example.original_nc_names, number=False)
      if (not (1.0 in example.word_column_exact_match) and not (
          1.0 in example.number_column_exact_match)):
        example.word_column_exact_match = partial_column_match(
            example.string_question, example.original_wc_names, number=False)
        example.number_column_exact_match = partial_column_match(
            example.string_question, example.original_nc_names, number=False)
      # Append special indicator tokens to the question when any entry/column
      # match was found.
      if (len(word_match) > 0 or len(number_match) > 0):
        example.question.append(utility.entry_match_token)
      if (1.0 in example.word_column_exact_match or
          1.0 in example.number_column_exact_match):
        example.question.append(utility.column_match_token)
      example.string_question = example.question[:]
      # Lookup matrices are stored transposed (column-major) from here on.
      example.number_lookup_matrix = np.transpose(
          example.number_lookup_matrix)[:]
      example.word_lookup_matrix = np.transpose(example.word_lookup_matrix)[:]
      example.columns = example.number_columns[:]
      example.word_columns = example.word_columns[:]
      example.len_total_cols = len(example.word_column_names) + len(
          example.number_column_names)
      example.column_names = example.number_column_names[:]
      example.word_column_names = example.word_column_names[:]
      example.string_column_names = example.number_column_names[:]
      example.string_word_column_names = example.word_column_names[:]
      # Fresh containers for the padded/sorted structures built below.
      example.sorted_number_index = []
      example.sorted_word_index = []
      example.column_mask = []
      example.word_column_mask = []
      example.processed_column_mask = []
      example.processed_word_column_mask = []
      example.word_column_entry_mask = []
      example.question_attention_mask = []
      example.question_number = example.question_number_1 = -1
      example.question_attention_mask = []
      example.ordinal_question = []
      example.ordinal_question_one = []
      new_question = []
      if (len(example.number_columns) > 0):
        example.len_col = len(example.number_columns[0])
      else:
        example.len_col = len(example.word_columns[0])
      # Blank out question spans consumed by exact word-entry matches.
      for (start, length) in matched_indices:
        for j in range(length):
          example.question[start + j] = utility.unk_token
      # Extract up to two numbers from the question (dates are normalized by
      # stripping "X"/"-"); ordinal_question[_one] flags the token position
      # just before each extracted number.
      for word in example.question:
        if (isinstance(word, numbers.Number) or wiki_data.is_date(word)):
          if (not (isinstance(word, numbers.Number)) and
              wiki_data.is_date(word)):
            word = word.replace("X", "").replace("-", "")
          number_found += 1
          if (number_found == 1):
            example.question_number = word
            if (len(example.ordinal_question) > 0):
              example.ordinal_question[len(example.ordinal_question) - 1] = 1.0
            else:
              example.ordinal_question.append(1.0)
          elif (number_found == 2):
            example.question_number_1 = word
            if (len(example.ordinal_question_one) > 0):
              example.ordinal_question_one[len(example.ordinal_question_one) -
                                           1] = 1.0
            else:
              example.ordinal_question_one.append(1.0)
        else:
          new_question.append(word)
          example.ordinal_question.append(0.0)
          example.ordinal_question_one.append(0.0)
      # Convert the number-free question to word ids, then left-pad all
      # question-aligned vectors to FLAGS.question_length (mask uses -10000.0
      # so padded positions vanish under softmax).
      example.question = [
          utility.word_ids[word_lookup(w, utility)] for w in new_question
      ]
      example.question_attention_mask = [0.0] * len(example.question)
      example.ordinal_question = example.ordinal_question[0:len(
          example.question)]
      example.ordinal_question_one = example.ordinal_question_one[0:len(
          example.question)]
      example.question = [utility.word_ids[utility.dummy_token]] * (
          utility.FLAGS.question_length - len(example.question)
      ) + example.question
      example.question_attention_mask = [-10000.0] * (
          utility.FLAGS.question_length - len(example.question_attention_mask)
      ) + example.question_attention_mask
      example.ordinal_question = [0.0] * (utility.FLAGS.question_length -
                                          len(example.ordinal_question)
                                         ) + example.ordinal_question
      example.ordinal_question_one = [0.0] * (utility.FLAGS.question_length -
                                              len(example.ordinal_question_one)
                                             ) + example.ordinal_question_one
      if (True):
        # --- Pad number columns to max_number_cols x max_elements. ---
        num_cols = len(example.columns)
        start = 0
        for column in example.number_columns:
          if (check_processed_cols(example.processed_number_columns[start],
                                   utility)):
            example.processed_column_mask.append(0.0)
          # Descending sort order of each processed column, padded with
          # pad_int (used by sort operations in the model).
          sorted_index = sorted(
              range(len(example.processed_number_columns[start])),
              key=lambda k: example.processed_number_columns[start][k],
              reverse=True)
          sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(sorted_index))
          example.sorted_number_index.append(sorted_index)
          example.columns[start] = column + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(column))
          example.processed_number_columns[start] += [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements -
              len(example.processed_number_columns[start]))
          start += 1
          example.column_mask.append(0.0)
        # Dummy number columns up to max_number_cols; masks use a large
        # negative so they are ignored by softmax-style selections.
        for remaining in range(num_cols, utility.FLAGS.max_number_cols):
          example.sorted_number_index.append([utility.FLAGS.pad_int] *
                                             (utility.FLAGS.max_elements))
          example.columns.append([utility.FLAGS.pad_int] *
                                 (utility.FLAGS.max_elements))
          example.processed_number_columns.append([utility.FLAGS.pad_int] *
                                                  (utility.FLAGS.max_elements))
          example.number_exact_match.append([0.0] *
                                            (utility.FLAGS.max_elements))
          example.number_group_by_max.append([0.0] *
                                             (utility.FLAGS.max_elements))
          example.column_mask.append(-100000000.0)
          example.processed_column_mask.append(-100000000.0)
          example.number_column_exact_match.append(0.0)
          example.column_names.append([utility.dummy_token])
        # --- Pad word columns to max_word_cols x max_elements. ---
        start = 0
        word_num_cols = len(example.word_columns)
        for column in example.word_columns:
          if (check_processed_cols(example.processed_word_columns[start],
                                   utility)):
            example.processed_word_column_mask.append(0.0)
          sorted_index = sorted(
              range(len(example.processed_word_columns[start])),
              key=lambda k: example.processed_word_columns[start][k],
              reverse=True)
          sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(sorted_index))
          example.sorted_word_index.append(sorted_index)
          # Word entries become fixed-length rows of word ids.
          column = convert_to_int_2d_and_pad(column, utility)
          example.word_columns[start] = column + [[
              utility.word_ids[utility.dummy_token]
          ] * utility.FLAGS.max_entry_length] * (utility.FLAGS.max_elements -
                                                 len(column))
          example.processed_word_columns[start] += [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements -
              len(example.processed_word_columns[start]))
          example.word_column_entry_mask.append([0] * len(column) + [
              utility.word_ids[utility.dummy_token]
          ] * (utility.FLAGS.max_elements - len(column)))
          start += 1
          example.word_column_mask.append(0.0)
        for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
          example.sorted_word_index.append([utility.FLAGS.pad_int] *
                                           (utility.FLAGS.max_elements))
          example.word_columns.append([[utility.word_ids[utility.dummy_token]] *
                                       utility.FLAGS.max_entry_length] *
                                      (utility.FLAGS.max_elements))
          example.word_column_entry_mask.append(
              [utility.word_ids[utility.dummy_token]] *
              (utility.FLAGS.max_elements))
          example.word_exact_match.append([0.0] * (utility.FLAGS.max_elements))
          example.word_group_by_max.append([0.0] * (utility.FLAGS.max_elements))
          example.processed_word_columns.append([utility.FLAGS.pad_int] *
                                                (utility.FLAGS.max_elements))
          example.word_column_mask.append(-100000000.0)
          example.processed_word_column_mask.append(-100000000.0)
          example.word_column_exact_match.append(0.0)
          example.word_column_names.append([utility.dummy_token] *
                                           utility.FLAGS.max_entry_length)
        seen_tables[example.table_key] = 1
      # Column names to padded id rows; pad match/group features to
      # max_elements per column, then concatenate number+word blocks.
      example.column_ids = convert_to_int_2d_and_pad(example.column_names,
                                                     utility)
      example.word_column_ids = convert_to_int_2d_and_pad(
          example.word_column_names, utility)
      for i_em in range(len(example.number_exact_match)):
        example.number_exact_match[i_em] = example.number_exact_match[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.number_exact_match[i_em]))
        example.number_group_by_max[i_em] = example.number_group_by_max[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.number_group_by_max[i_em]))
      for i_em in range(len(example.word_exact_match)):
        example.word_exact_match[i_em] = example.word_exact_match[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.word_exact_match[i_em]))
        example.word_group_by_max[i_em] = example.word_group_by_max[
            i_em] + [0.0] * (utility.FLAGS.max_elements -
                             len(example.word_group_by_max[i_em]))
      example.exact_match = example.number_exact_match + example.word_exact_match
      example.group_by_max = example.number_group_by_max + example.word_group_by_max
      example.exact_column_match = example.number_column_exact_match + example.word_column_exact_match
      # Lookup questions: the answer is a cell-selection matrix; otherwise
      # the answer is a scalar and print_answer is all zeros.
      if (example.is_lookup):
        example.answer = example.calc_answer
        example.number_print_answer = example.number_lookup_matrix.tolist()
        example.word_print_answer = example.word_lookup_matrix.tolist()
        for i_answer in range(len(example.number_print_answer)):
          example.number_print_answer[i_answer] = example.number_print_answer[
              i_answer] + [0.0] * (utility.FLAGS.max_elements -
                                   len(example.number_print_answer[i_answer]))
        for i_answer in range(len(example.word_print_answer)):
          example.word_print_answer[i_answer] = example.word_print_answer[
              i_answer] + [0.0] * (utility.FLAGS.max_elements -
                                   len(example.word_print_answer[i_answer]))
        example.number_lookup_matrix = convert_to_bool_and_pad(
            example.number_lookup_matrix, utility)
        example.word_lookup_matrix = convert_to_bool_and_pad(
            example.word_lookup_matrix, utility)
        for remaining in range(num_cols, utility.FLAGS.max_number_cols):
          example.number_lookup_matrix.append([False] *
                                              utility.FLAGS.max_elements)
          example.number_print_answer.append([0.0] * utility.FLAGS.max_elements)
        for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
          example.word_lookup_matrix.append([False] *
                                            utility.FLAGS.max_elements)
          example.word_print_answer.append([0.0] * utility.FLAGS.max_elements)
        example.print_answer = example.number_print_answer + example.word_print_answer
      else:
        example.answer = example.calc_answer
        example.print_answer = [[0.0] * (utility.FLAGS.max_elements)] * (
            utility.FLAGS.max_number_cols + utility.FLAGS.max_word_cols)
      # Masks for the (up to two) numbers extracted from the question.
      if (example.question_number == -1):
        example.question_number_mask = np.zeros([utility.FLAGS.max_elements])
      else:
        example.question_number_mask = np.ones([utility.FLAGS.max_elements])
      if (example.question_number_1 == -1):
        example.question_number_one_mask = -10000.0
      else:
        example.question_number_one_mask = np.float64(0.0)
      # Drop examples whose columns are longer than the model can hold.
      if (example.len_col > utility.FLAGS.max_elements):
        continue
      processed_data.append(example)
  return processed_data
|
|
|
|
|
def add_special_words(utility):
  """Register the special tokens (entry match, column match, dummy, UNK).

  Each token is appended to `utility.words`, assigned the next id in
  `utility.word_ids`, and mirrored into `utility.reverse_word_ids`; the
  resulting ids are stored on `utility` as *_token_id attributes.

  Fixes a copy-paste bug: the second print previously mislabeled the
  column-match token as "entry match token".
  """

  def _register(token):
    # Append the token to the vocabulary and return its new id.
    utility.words.append(token)
    utility.word_ids[token] = len(utility.word_ids)
    utility.reverse_word_ids[utility.word_ids[token]] = token
    return utility.word_ids[token]

  utility.entry_match_token_id = _register(utility.entry_match_token)
  print("entry match token: ", utility.word_ids[utility.entry_match_token],
        utility.entry_match_token_id)
  utility.column_match_token_id = _register(utility.column_match_token)
  print("column match token: ", utility.word_ids[utility.column_match_token],
        utility.column_match_token_id)
  utility.dummy_token_id = _register(utility.dummy_token)
  _register(utility.unk_token)
|
|
|
|
|
def perform_word_cutoff(utility):
  """Drop words rarer than FLAGS.word_cutoff from the vocabulary.

  Special tokens (UNK, dummy, entry-match, column-match) are always kept,
  as are words with no count entry.  Ids of surviving words are NOT
  re-indexed, matching the original behavior.

  Fixes two Python 3 hazards in the original: `dict.has_key`, and popping
  from `word_ids` while iterating its live `.keys()` view (RuntimeError).
  """
  cutoff = utility.FLAGS.word_cutoff
  if cutoff <= 0:
    return
  protected = {utility.unk_token, utility.dummy_token,
               utility.entry_match_token, utility.column_match_token}
  # Snapshot the keys so we can mutate the dict while iterating.
  for word in list(utility.word_ids.keys()):
    if (word not in protected and
        utility.word_count.get(word, cutoff) < cutoff):
      utility.word_ids.pop(word)
      utility.words.remove(word)
|
|
|
|
|
def word_dropout(question, utility):
  """Randomly replace non-dummy question tokens with the UNK id.

  When FLAGS.word_dropout_prob is 0 the question is returned unchanged.
  NOTE(review): a token is replaced when random() > word_dropout_prob,
  i.e. with probability (1 - word_dropout_prob) — preserved as-is from the
  original; confirm the intended semantics of the flag.
  """
  drop_prob = utility.FLAGS.word_dropout_prob
  if drop_prob <= 0.0:
    return question
  unk_id = utility.word_ids[utility.unk_token]
  result = []
  for token in question:
    replace = (token != utility.dummy_token_id and
               utility.random.random() > drop_prob)
    result.append(unk_id if replace else token)
  return result
|
|
|
|
|
def generate_feed_dict(data, curr, batch_size, gr, train=False, utility=None):
  """Assemble the TF feed_dict for the batch data[curr:curr + batch_size].

  Most placeholders receive a plain per-example list of the corresponding
  example attribute; the question-number values are reshaped into
  (batch_size, 1) numpy arrays.  During training, word dropout is applied
  to the questions.

  Args:
    data: list of processed examples.
    curr: start index of the batch.
    batch_size: number of examples to feed.
    gr: graph object exposing the batch_* placeholders.
    train: apply word_dropout to questions when True.
    utility: needed by word_dropout when train is True.

  Returns:
    dict mapping placeholders to batch values.
  """
  batch = [data[curr + j] for j in range(batch_size)]
  feed_dict = {}
  if train:
    feed_dict[gr.batch_question] = [
        word_dropout(ex.question, utility) for ex in batch
    ]
  else:
    feed_dict[gr.batch_question] = [ex.question for ex in batch]
  # Placeholders fed directly with one attribute value per example.
  per_example = [
      (gr.batch_question_attention_mask, "question_attention_mask"),
      (gr.batch_answer, "answer"),
      (gr.batch_number_column, "columns"),
      (gr.batch_processed_number_column, "processed_number_columns"),
      (gr.batch_processed_sorted_index_number_column, "sorted_number_index"),
      (gr.batch_processed_sorted_index_word_column, "sorted_word_index"),
      (gr.batch_question_number_mask, "question_number_mask"),
      (gr.batch_print_answer, "print_answer"),
      (gr.batch_exact_match, "exact_match"),
      (gr.batch_group_by_max, "group_by_max"),
      (gr.batch_column_exact_match, "exact_column_match"),
      (gr.batch_ordinal_question, "ordinal_question"),
      (gr.batch_ordinal_question_one, "ordinal_question_one"),
      (gr.batch_number_column_mask, "column_mask"),
      (gr.batch_number_column_names, "column_ids"),
      (gr.batch_processed_word_column, "processed_word_columns"),
      (gr.batch_word_column_mask, "word_column_mask"),
      (gr.batch_word_column_names, "word_column_ids"),
      (gr.batch_word_column_entry_mask, "word_column_entry_mask"),
  ]
  for placeholder, attr in per_example:
    feed_dict[placeholder] = [getattr(ex, attr) for ex in batch]
  # Scalar-per-example placeholders reshaped to (batch_size, 1).
  reshaped = [
      (gr.batch_question_number, "question_number"),
      (gr.batch_question_number_one, "question_number_1"),
      (gr.batch_question_number_one_mask, "question_number_one_mask"),
  ]
  for placeholder, attr in reshaped:
    feed_dict[placeholder] = np.array(
        [getattr(ex, attr) for ex in batch]).reshape((batch_size, 1))
  return feed_dict
|
|