|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Create TFRecord files of SequenceExample protos from dataset. |
|
|
|
Constructs 3 datasets: |
|
1. Labeled data for the LSTM classification model, optionally with label gain. |
|
"*_classification.tfrecords" (for both unidirectional and bidirectional |
|
models). |
|
2. Data for the unsupervised LM-LSTM model that predicts the next token. |
|
"*_lm.tfrecords" (generates forward and reverse data). |
|
3. Data for the unsupervised SA-LSTM model that uses Seq2Seq. |
|
"*_sa.tfrecords". |
|
""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import os |
|
import string |
|
|
|
|
|
|
|
import tensorflow as tf |
|
|
|
from data import data_utils |
|
from data import document_generators |
|
|
|
# Module-level shorthands used throughout this script.
flags = tf.app.flags
FLAGS = flags.FLAGS
data = data_utils

# Flags controlling where the vocabulary is read from and where the generated
# TFRecord files are written.
# NOTE(review): `FLAGS.output_char` is read by `make_vocab_ids` but is not
# defined in this section; presumably it is registered by an imported module
# -- confirm.
flags.DEFINE_string(
    'vocab_file',
    '',
    'Path to the vocabulary file. Defaults to FLAGS.output_dir/vocab.txt.',
)
flags.DEFINE_string(
    'output_dir',
    '',
    'Path to save tfrecords.',
)

# Flag controlling how labels are attached to the classification sequences.
flags.DEFINE_boolean(
    'label_gain',
    False,
    'Enable linear label gain. If True, sentiment label '
    'will be included at each timestep with linear '
    'weight increase.',
)
|
|
|
|
|
def build_shuffling_tf_record_writer(fname):
    """Creates a shuffling record writer for a file in FLAGS.output_dir.

    Args:
      fname: file name; joined onto FLAGS.output_dir to form the output path.

    Returns:
      A `data.ShufflingTFRecordWriter` constructed with that path.
    """
    output_path = os.path.join(FLAGS.output_dir, fname)
    return data.ShufflingTFRecordWriter(output_path)
|
|
|
|
|
def build_tf_record_writer(fname):
    """Creates a plain (non-shuffling) record writer in FLAGS.output_dir.

    Args:
      fname: file name; joined onto FLAGS.output_dir to form the output path.

    Returns:
      A `tf.python_io.TFRecordWriter` constructed with that path.
    """
    output_path = os.path.join(FLAGS.output_dir, fname)
    return tf.python_io.TFRecordWriter(output_path)
|
|
|
|
|
def build_input_sequence(doc, vocab_ids):
    """Builds the input token sequence for one document.

    Iterates over the tokens that `document_generators.tokens(doc)` yields,
    keeps only those present in `vocab_ids`, and adds one timestep per kept
    token holding that token's id. A final timestep holding the id of
    `data.EOS_TOKEN` is always appended. Only tokens are set here; labels and
    weights are left untouched.

    Args:
      doc: Document (defined in `document_generators`) from which to build the
        sequence.
      vocab_ids: dict<term, id>. Must contain `data.EOS_TOKEN`.

    Returns:
      A `data.SequenceWrapper` holding the token ids followed by EOS.
    """
    sequence = data.SequenceWrapper()

    # Lazily map in-vocabulary tokens to their ids; out-of-vocabulary tokens
    # are dropped rather than replaced.
    known_ids = (
        vocab_ids[term]
        for term in document_generators.tokens(doc)
        if term in vocab_ids
    )
    for term_id in known_ids:
        sequence.add_timestep().set_token(term_id)

    # Terminate every sequence with the end-of-sequence token.
    eos_step = sequence.add_timestep()
    eos_step.set_token(vocab_ids[data.EOS_TOKEN])

    return sequence
|
|
|
|
|
def make_vocab_ids(vocab_filename):
    """Builds the term -> id mapping used to encode documents.

    When FLAGS.output_char is set, ids are assigned to the characters of
    `string.printable` in order and `data.EOS_TOKEN` receives the next id;
    `vocab_filename` is not read in that case. Otherwise each line of the
    UTF-8 vocabulary file, stripped of surrounding whitespace, is mapped to
    its zero-based line index.

    Args:
      vocab_filename: path to the vocabulary file, one term per line.

    Returns:
      dict<term, id>.
    """
    if not FLAGS.output_char:
        with open(vocab_filename, encoding='utf-8') as vocab_f:
            return {line.strip(): idx for idx, line in enumerate(vocab_f)}

    char_ids = {char: idx for idx, char in enumerate(string.printable)}
    char_ids[data.EOS_TOKEN] = len(string.printable)
    return char_ids
|
|
|
|
|
def generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
    """Generates training data.

    Iterates over the 'train' documents (unlabeled and validation documents
    included) and writes serialized sequences to the writers created here and
    to the two caller-supplied "all" writers. Documents whose input sequence
    has fewer than two entries are skipped.

    Every LM and sequence-autoencoder sequence goes to the "all" writers;
    only non-validation documents go to the training LM, reverse-LM and
    sequence-autoencoder writers. Labeled documents additionally go to the
    training or validation classification writers (unidirectional and
    bidirectional), depending on `doc.is_validation`.

    The writers created here are closed even if processing a document
    raises. The caller-supplied writers are never closed by this function.

    Args:
      vocab_ids: dict<term, id>, forwarded to `build_input_sequence`.
      writer_lm_all: writer that receives every LM sequence.
      writer_seq_ae_all: writer that receives every sequence-autoencoder
        sequence.
    """
    writer_lm = build_shuffling_tf_record_writer(data.TRAIN_LM)
    writer_seq_ae = build_shuffling_tf_record_writer(data.TRAIN_SA)
    writer_class = build_shuffling_tf_record_writer(data.TRAIN_CLASS)
    writer_valid_class = build_tf_record_writer(data.VALID_CLASS)
    writer_rev_lm = build_shuffling_tf_record_writer(data.TRAIN_REV_LM)
    writer_bd_class = build_shuffling_tf_record_writer(data.TRAIN_BD_CLASS)
    writer_bd_valid_class = build_shuffling_tf_record_writer(
        data.VALID_BD_CLASS)

    # Writers owned by this function, listed in the order they are closed.
    owned_writers = (
        writer_lm,
        writer_seq_ae,
        writer_class,
        writer_valid_class,
        writer_rev_lm,
        writer_bd_class,
        writer_bd_valid_class,
    )

    try:
        for doc in document_generators.documents(
                dataset='train', include_unlabeled=True,
                include_validation=True):
            input_seq = build_input_sequence(doc, vocab_ids)
            if len(input_seq) < 2:
                continue
            rev_seq = data.build_reverse_sequence(input_seq)
            lm_seq = data.build_lm_sequence(input_seq)
            rev_lm_seq = data.build_lm_sequence(rev_seq)
            seq_ae_seq = data.build_seq_ae_sequence(input_seq)

            if doc.label is not None:
                # Label gain is only ever requested for non-validation
                # documents.
                use_label_gain = FLAGS.label_gain and not doc.is_validation
                label_seq = data.build_labeled_sequence(
                    input_seq, doc.label, label_gain=use_label_gain)
                bd_label_seq = data.build_labeled_sequence(
                    data.build_bidirectional_seq(input_seq, rev_seq),
                    doc.label,
                    label_gain=use_label_gain)
                if doc.is_validation:
                    class_writer = writer_valid_class
                    bd_class_writer = writer_bd_valid_class
                else:
                    class_writer = writer_class
                    bd_class_writer = writer_bd_class
                class_writer.write(label_seq.seq.SerializeToString())
                bd_class_writer.write(bd_label_seq.seq.SerializeToString())

            # Serialize once; the same bytes may go to two writers.
            lm_seq_ser = lm_seq.seq.SerializeToString()
            seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
            writer_lm_all.write(lm_seq_ser)
            writer_seq_ae_all.write(seq_ae_seq_ser)
            if not doc.is_validation:
                writer_lm.write(lm_seq_ser)
                writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
                writer_seq_ae.write(seq_ae_seq_ser)
    finally:
        # Close on both the success and the failure path so a mid-run error
        # does not leave the writers open.
        for writer in owned_writers:
            writer.close()
|
|
|
|
|
def generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
    """Generates test data.

    Iterates over the 'test' documents (unlabeled documents excluded) and,
    for each one whose input sequence has at least two entries, writes the
    labeled, bidirectional labeled, LM, reverse-LM and sequence-autoencoder
    sequences to the test writers created here. The LM and
    sequence-autoencoder sequences are also written to the caller-supplied
    "all" writers.

    The writers created here are closed even if processing a document
    raises. The caller-supplied writers are never closed by this function.

    Args:
      vocab_ids: dict<term, id>, forwarded to `build_input_sequence`.
      writer_lm_all: writer that receives every LM sequence.
      writer_seq_ae_all: writer that receives every sequence-autoencoder
        sequence.
    """
    writer_lm = build_shuffling_tf_record_writer(data.TEST_LM)
    writer_rev_lm = build_shuffling_tf_record_writer(data.TEST_REV_LM)
    writer_seq_ae = build_shuffling_tf_record_writer(data.TEST_SA)
    writer_class = build_tf_record_writer(data.TEST_CLASS)
    writer_bd_class = build_shuffling_tf_record_writer(data.TEST_BD_CLASS)

    # Writers owned by this function, listed in the order they are closed.
    owned_writers = (
        writer_lm,
        writer_rev_lm,
        writer_seq_ae,
        writer_class,
        writer_bd_class,
    )

    try:
        for doc in document_generators.documents(
                dataset='test', include_unlabeled=False,
                include_validation=True):
            input_seq = build_input_sequence(doc, vocab_ids)
            if len(input_seq) < 2:
                continue
            rev_seq = data.build_reverse_sequence(input_seq)
            lm_seq = data.build_lm_sequence(input_seq)
            rev_lm_seq = data.build_lm_sequence(rev_seq)
            seq_ae_seq = data.build_seq_ae_sequence(input_seq)
            label_seq = data.build_labeled_sequence(input_seq, doc.label)
            bd_label_seq = data.build_labeled_sequence(
                data.build_bidirectional_seq(input_seq, rev_seq), doc.label)

            writer_class.write(label_seq.seq.SerializeToString())
            writer_bd_class.write(bd_label_seq.seq.SerializeToString())

            # Serialize once; the same bytes go to two writers.
            lm_seq_ser = lm_seq.seq.SerializeToString()
            seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
            writer_lm.write(lm_seq_ser)
            writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
            writer_seq_ae.write(seq_ae_seq_ser)
            writer_lm_all.write(lm_seq_ser)
            writer_seq_ae_all.write(seq_ae_seq_ser)
    finally:
        # Close on both the success and the failure path so a mid-run error
        # does not leave the writers open.
        for writer in owned_writers:
            writer.close()
|
|
|
|
|
def main(_):
    """Loads the vocabulary, then writes the training and test datasets.

    The vocabulary is read from FLAGS.vocab_file when set, otherwise from
    `vocab.txt` inside FLAGS.output_dir. The two "all" writers are opened
    here and shared by both generation passes, so they receive sequences
    from the training data and the test data alike.

    Args:
      _: unused positional argument supplied by `tf.app.run`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info('Assigning vocabulary ids...')
    vocab_path = (
        FLAGS.vocab_file or os.path.join(FLAGS.output_dir, 'vocab.txt'))
    vocab_ids = make_vocab_ids(vocab_path)

    # Both writers stay open across the training and the test pass and are
    # closed when the `with` block exits.
    with build_shuffling_tf_record_writer(data.ALL_LM) as writer_lm_all, \
            build_shuffling_tf_record_writer(data.ALL_SA) as writer_seq_ae_all:
        tf.logging.info('Generating training data...')
        generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all)

        tf.logging.info('Generating test data...')
        generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all)


# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    tf.app.run()
|
|