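"""Utilities for data preprocessing and scoring on commonsense reasoning test data.

Defines word- and character-level vocabularies, a simple tokenizer, helpers
that cut sentences into LM-sized patches, and accuracy computation that
compares the language-model scores of paired candidate substitutions.
"""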
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

import numpy as np
import tensorflow as tf

FLAGS = tf.flags.FLAGS


class Vocabulary(object):
  """Class that holds a vocabulary for the dataset."""

  def __init__(self, filename):
    self._id_to_word = []
    self._word_to_id = {}
    self._unk = -1
    self._bos = -1
    self._eos = -1

    with tf.gfile.Open(filename) as f:
      idx = 0
      for line in f:
        word_name = line.strip()
        if word_name == '<S>':
          self._bos = idx
        elif word_name == '</S>':
          self._eos = idx
        elif word_name == '<UNK>':
          self._unk = idx
        if word_name == '!!!MAXTERMID':
          continue

        self._id_to_word.append(word_name)
        self._word_to_id[word_name] = idx
        idx += 1

  @property
  def bos(self):
    return self._bos

  @property
  def eos(self):
    return self._eos

  @property
  def unk(self):
    return self._unk

  @property
  def size(self):
    return len(self._id_to_word)

  def word_to_id(self, word):
    if word in self._word_to_id:
      return self._word_to_id[word]
    else:
      if word.lower() in self._word_to_id:
        return self._word_to_id[word.lower()]
    return self.unk

  def id_to_word(self, cur_id):
    if cur_id < self.size:
      return self._id_to_word[int(cur_id)]
    return '<ERROR_out_of_vocab_id>'

  def decode(self, cur_ids):
    return ' '.join([self.id_to_word(cur_id) for cur_id in cur_ids])

  def encode(self, sentence):
    word_ids = [self.word_to_id(cur_word) for cur_word in sentence.split()]
    return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)


class CharsVocabulary(Vocabulary):
  """Vocabulary containing character-level information."""

  def __init__(self, filename, max_word_length):
    super(CharsVocabulary, self).__init__(filename)

    self._max_word_length = max_word_length
    chars_set = set()

    for word in self._id_to_word:
      chars_set |= set(word)

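    # Reserve a few byte values that never occur in any vocabulary word to
    # serve as special begin/end-of-sentence, begin/end-of-word and padding
    # characters.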
    free_ids = []
    for i in range(256):
      if chr(i) in chars_set:
        continue
      free_ids.append(chr(i))

    if len(free_ids) < 5:
      raise ValueError('Not enough free char ids: %d' % len(free_ids))

    self.bos_char = free_ids[0]
    self.eos_char = free_ids[1]
    self.bow_char = free_ids[2]
    self.eow_char = free_ids[3]
    self.pad_char = free_ids[4]

    chars_set |= {self.bos_char, self.eos_char, self.bow_char, self.eow_char,
                  self.pad_char}

    self._char_set = chars_set
    num_words = len(self._id_to_word)

    self._word_char_ids = np.zeros([num_words, max_word_length],
                                   dtype=np.int32)

    self.bos_chars = self._convert_word_to_char_ids(self.bos_char)
    self.eos_chars = self._convert_word_to_char_ids(self.eos_char)

    for i, word in enumerate(self._id_to_word):
      if i == self.bos:
        self._word_char_ids[i] = self.bos_chars
      elif i == self.eos:
        self._word_char_ids[i] = self.eos_chars
      else:
        self._word_char_ids[i] = self._convert_word_to_char_ids(word)

  @property
  def max_word_length(self):
    return self._max_word_length

  def _convert_word_to_char_ids(self, word):
    code = np.zeros([self.max_word_length], dtype=np.int32)
    code[:] = ord(self.pad_char)

    if len(word) > self.max_word_length - 2:
      word = word[:self.max_word_length - 2]
    cur_word = self.bow_char + word + self.eow_char
    for j in range(len(cur_word)):
      code[j] = ord(cur_word[j])
    return code

  def word_to_char_ids(self, word):
    if word in self._word_to_id:
      return self._word_char_ids[self._word_to_id[word]]
    else:
      return self._convert_word_to_char_ids(word)

  def encode_chars(self, sentence):
    chars_ids = [self.word_to_char_ids(cur_word)
                 for cur_word in sentence.split()]
    return np.vstack([self.bos_chars] + chars_ids + [self.eos_chars])

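# Example usage (a minimal sketch; the vocabulary path is hypothetical):
#
#   vocab = CharsVocabulary('/path/to/vocab.txt', max_word_length=50)
#   word_ids = vocab.encode('The cat sat .')        # shape [num_tokens + 2]
#   char_ids = vocab.encode_chars('The cat sat .')  # shape [num_tokens + 2, 50]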

_SPECIAL_CHAR_MAP = {
    '\xe2\x80\x98': '\'',
    '\xe2\x80\x99': '\'',
    '\xe2\x80\x9c': '"',
    '\xe2\x80\x9d': '"',
    '\xe2\x80\x93': '-',
    '\xe2\x80\x94': '-',
    '\xe2\x88\x92': '-',
    '\xce\x84': '\'',
    '\xc2\xb4': '\'',
    '`': '\''
}

_START_SPECIAL_CHARS = ['.', ',', '?', '!', ';', ':', '[', ']', '\'', '+', '/',
                        '\xc2\xa3', '$', '~', '*', '%', '{', '}', '#', '&', '-',
                        '"', '(', ')', '='] + list(_SPECIAL_CHAR_MAP.keys())
_SPECIAL_CHARS = _START_SPECIAL_CHARS + [
    '\'s', '\'m', '\'t', '\'re', '\'d', '\'ve', '\'ll']


def tokenize(sentence):
  """Tokenize a sentence."""
  sentence = str(sentence)
  words = sentence.strip().split()
  tokenized = []

  for word in words:
    if word.lower() in ['mr.', 'ms.']:
      tokenized.append(word)
      continue

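    # Peel leading special characters (punctuation, quotes, brackets, ...) off
    # the word and emit each one as its own token.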
    will_split = True
    while will_split:
      will_split = False
      for char in _START_SPECIAL_CHARS:
        if word.startswith(char):
          tokenized.append(char)
          word = word[len(char):]
          will_split = True

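    # Peel trailing special characters and contractions ('s, 're, 'll, ...)
    # off the word, keeping them in their original order.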
    special_end_tokens = []
    will_split = True
    while will_split:
      will_split = False
      for char in _SPECIAL_CHARS:
        if word.endswith(char):
          special_end_tokens = [char] + special_end_tokens
          word = word[:-len(char)]
          will_split = True

    if word:
      tokenized.append(word)
    tokenized += special_end_tokens

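  # Make sure the tokenized sentence ends with terminal punctuation.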
  if tokenized[-1] not in ['.', '!', '?']:
    tokenized += ['.']
  return tokenized


def parse_commonsense_reasoning_test(test_data_name):
  """Read JSON test data."""
  with tf.gfile.Open(os.path.join(
      FLAGS.data_dir, 'commonsense_test',
      '{}.json'.format(test_data_name)), 'r') as f:
    data = json.load(f)

  question_ids = [d['question_id'] for d in data]
  sentences = [tokenize(d['substitution']) for d in data]
  labels = [d['correctness'] for d in data]

  return question_ids, sentences, labels


PAD = '<padding>'


def cut_to_patches(sentences, batch_size, num_timesteps):
  """Cut sentences into patches of shape (batch_size, num_timesteps).

  Args:
    sentences: a list of sentences, each sentence a list of string tokens.
    batch_size: batch size.
    num_timesteps: number of backprop steps (tokens per patch).

  Returns:
    patches: a 2D list of shape (nrow, ncol); each entry holds batch_size
      token sequences of length num_timesteps + 1, or None if the patch
      contains only padding.
  """
  preprocessed = [['<S>'] + sentence + ['</S>'] for sentence in sentences]
  max_len = max([len(sent) for sent in preprocessed])

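  # Pad the sentences out to a grid of whole batches (nrow * batch_size rows)
  # and whole windows (ncol * num_timesteps columns). The extra +1 column lets
  # consecutive patches overlap by one token, so each patch provides
  # num_timesteps inputs together with their next-token targets.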
  nrow = int(np.ceil(len(preprocessed) * 1.0 / batch_size))
  ncol = int(np.ceil(max_len * 1.0 / num_timesteps))
  height, width = nrow * batch_size, ncol * num_timesteps + 1
  preprocessed = [sent + [PAD] * (width - len(sent)) for sent in preprocessed]
  preprocessed += [[PAD] * width] * (height - len(preprocessed))

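  # Slice the padded grid into patches of batch_size sentences and
  # num_timesteps + 1 tokens; patches that contain only padding (beyond the
  # one-token overlap) are replaced with None.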
  patches = []
  for row in range(nrow):
    patches.append([])
    for col in range(ncol):
      patch = [sent[col * num_timesteps:
                    (col + 1) * num_timesteps + 1]
               for sent in preprocessed[row * batch_size:
                                        (row + 1) * batch_size]]
      if np.all(np.array(patch)[:, 1:] == PAD):
        patch = None
      patches[-1].append(patch)
  return patches


def _substitution_mask(sent1, sent2):
  """Binary mask identifying the substituted part in two sentences.

  Example sentences and their masks:
    First sentence  = "I like the cat 's color"
    Mask            =  0 0    0   1   0  0
    Second sentence = "I like the yellow dog 's color"
    Mask            =  0 0    0   1      1   0  0

  Args:
    sent1: first sentence
    sent2: second sentence

  Returns:
    mask1: mask for first sentence
    mask2: mask for second sentence
  """
  mask1_start, mask2_start = [], []
  while sent1[0] == sent2[0]:
    sent1 = sent1[1:]
    sent2 = sent2[1:]
    mask1_start.append(0.)
    mask2_start.append(0.)

  mask1_end, mask2_end = [], []
  while sent1[-1] == sent2[-1]:
    if (len(sent1) == 1) or (len(sent2) == 1):
      break
    sent1 = sent1[:-1]
    sent2 = sent2[:-1]
    mask1_end = [0.] + mask1_end
    mask2_end = [0.] + mask2_end

  assert sent1 or sent2, 'Two sentences are identical.'
  return (mask1_start + [1.] * len(sent1) + mask1_end,
          mask2_start + [1.] * len(sent2) + mask2_end)


def _convert_to_partial(scoring1, scoring2):
  """Convert full scoring into partial scoring."""
  mask1, mask2 = _substitution_mask(
      scoring1['sentence'], scoring2['sentence'])

  def _partial_score(scoring, mask):
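    # Substituted positions (mask == 1) get probability 1.0, so only the
    # surrounding context words contribute to the joint probability.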
    word_probs = [max(_) for _ in zip(scoring['word_probs'], mask)]
    scoring.update(word_probs=word_probs,
                   joint_prob=np.prod(word_probs))

  _partial_score(scoring1, mask1)
  _partial_score(scoring2, mask2)


def compare_substitutions(question_ids, scorings, mode='full'):
  """Return accuracy by comparing two consecutive scorings."""
  prediction_correctness = []

  for i in range(len(scorings) // 2):
    scoring1, scoring2 = scorings[2 * i: 2 * i + 2]
    if mode == 'partial':
      _convert_to_partial(scoring1, scoring2)

    prediction_correctness.append(
        (scoring2['joint_prob'] > scoring1['joint_prob']) ==
        scoring2['correctness'])

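  # Scorings come in consecutive pairs sharing a question_id; keep one id per
  # pair so the ids align with the pairwise predictions above.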
  question_ids = [qid for i, qid in enumerate(question_ids) if i % 2 == 0]
  assert len(question_ids) == len(prediction_correctness)
  num_questions = len(set(question_ids))

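  # A question may have several substitution pairs; count it as answered
  # correctly only if every one of its pairs is predicted correctly.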
  num_correct_answer = 0
  previous_qid = None
  correctly_answered = False
  for predict, qid in zip(prediction_correctness, question_ids):
    if qid != previous_qid:
      previous_qid = qid
      num_correct_answer += int(correctly_answered)
      correctly_answered = True
    correctly_answered = correctly_answered and predict
  num_correct_answer += int(correctly_answered)

  return num_correct_answer / num_questions