# Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Create TFRecord files of SequenceExample protos from dataset. Constructs 3 datasets: 1. Labeled data for the LSTM classification model, optionally with label gain. "*_classification.tfrecords" (for both unidirectional and bidirectional models). 2. Data for the unsupervised LM-LSTM model that predicts the next token. "*_lm.tfrecords" (generates forward and reverse data). 3. Data for the unsupervised SA-LSTM model that uses Seq2Seq. "*_sa.tfrecords". """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import string # Dependency imports import tensorflow as tf from data import data_utils from data import document_generators data = data_utils flags = tf.app.flags FLAGS = flags.FLAGS # Flags for input data are in document_generators.py flags.DEFINE_string('vocab_file', '', 'Path to the vocabulary file. Defaults ' 'to FLAGS.output_dir/vocab.txt.') flags.DEFINE_string('output_dir', '', 'Path to save tfrecords.') # Config flags.DEFINE_boolean('label_gain', False, 'Enable linear label gain. If True, sentiment label will ' 'be included at each timestep with linear weight ' 'increase.') def build_shuffling_tf_record_writer(fname): return data.ShufflingTFRecordWriter(os.path.join(FLAGS.output_dir, fname)) def build_tf_record_writer(fname): return tf.python_io.TFRecordWriter(os.path.join(FLAGS.output_dir, fname)) def build_input_sequence(doc, vocab_ids): """Builds input sequence from file. Splits lines on whitespace. Treats punctuation as whitespace. For word-level sequences, only keeps terms that are in the vocab. Terms are added as token in the SequenceExample. The EOS_TOKEN is also appended. Label and weight features are set to 0. Args: doc: Document (defined in `document_generators`) from which to build the sequence. vocab_ids: dict. Returns: SequenceExampleWrapper. """ seq = data.SequenceWrapper() for token in document_generators.tokens(doc): if token in vocab_ids: seq.add_timestep().set_token(vocab_ids[token]) # Add EOS token to end seq.add_timestep().set_token(vocab_ids[data.EOS_TOKEN]) return seq def make_vocab_ids(vocab_filename): if FLAGS.output_char: ret = dict([(char, i) for i, char in enumerate(string.printable)]) ret[data.EOS_TOKEN] = len(string.printable) return ret else: with open(vocab_filename, encoding='utf-8') as vocab_f: return dict([(line.strip(), i) for i, line in enumerate(vocab_f)]) def generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all): """Generates training data.""" # Construct training data writers writer_lm = build_shuffling_tf_record_writer(data.TRAIN_LM) writer_seq_ae = build_shuffling_tf_record_writer(data.TRAIN_SA) writer_class = build_shuffling_tf_record_writer(data.TRAIN_CLASS) writer_valid_class = build_tf_record_writer(data.VALID_CLASS) writer_rev_lm = build_shuffling_tf_record_writer(data.TRAIN_REV_LM) writer_bd_class = build_shuffling_tf_record_writer(data.TRAIN_BD_CLASS) writer_bd_valid_class = build_shuffling_tf_record_writer(data.VALID_BD_CLASS) for doc in document_generators.documents( dataset='train', include_unlabeled=True, include_validation=True): input_seq = build_input_sequence(doc, vocab_ids) if len(input_seq) < 2: continue rev_seq = data.build_reverse_sequence(input_seq) lm_seq = data.build_lm_sequence(input_seq) rev_lm_seq = data.build_lm_sequence(rev_seq) seq_ae_seq = data.build_seq_ae_sequence(input_seq) if doc.label is not None: # Used for sentiment classification. label_seq = data.build_labeled_sequence( input_seq, doc.label, label_gain=(FLAGS.label_gain and not doc.is_validation)) bd_label_seq = data.build_labeled_sequence( data.build_bidirectional_seq(input_seq, rev_seq), doc.label, label_gain=(FLAGS.label_gain and not doc.is_validation)) class_writer = writer_valid_class if doc.is_validation else writer_class bd_class_writer = (writer_bd_valid_class if doc.is_validation else writer_bd_class) class_writer.write(label_seq.seq.SerializeToString()) bd_class_writer.write(bd_label_seq.seq.SerializeToString()) # Write lm_seq_ser = lm_seq.seq.SerializeToString() seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString() writer_lm_all.write(lm_seq_ser) writer_seq_ae_all.write(seq_ae_seq_ser) if not doc.is_validation: writer_lm.write(lm_seq_ser) writer_rev_lm.write(rev_lm_seq.seq.SerializeToString()) writer_seq_ae.write(seq_ae_seq_ser) # Close writers writer_lm.close() writer_seq_ae.close() writer_class.close() writer_valid_class.close() writer_rev_lm.close() writer_bd_class.close() writer_bd_valid_class.close() def generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all): """Generates test data.""" # Construct test data writers writer_lm = build_shuffling_tf_record_writer(data.TEST_LM) writer_rev_lm = build_shuffling_tf_record_writer(data.TEST_REV_LM) writer_seq_ae = build_shuffling_tf_record_writer(data.TEST_SA) writer_class = build_tf_record_writer(data.TEST_CLASS) writer_bd_class = build_shuffling_tf_record_writer(data.TEST_BD_CLASS) for doc in document_generators.documents( dataset='test', include_unlabeled=False, include_validation=True): input_seq = build_input_sequence(doc, vocab_ids) if len(input_seq) < 2: continue rev_seq = data.build_reverse_sequence(input_seq) lm_seq = data.build_lm_sequence(input_seq) rev_lm_seq = data.build_lm_sequence(rev_seq) seq_ae_seq = data.build_seq_ae_sequence(input_seq) label_seq = data.build_labeled_sequence(input_seq, doc.label) bd_label_seq = data.build_labeled_sequence( data.build_bidirectional_seq(input_seq, rev_seq), doc.label) # Write writer_class.write(label_seq.seq.SerializeToString()) writer_bd_class.write(bd_label_seq.seq.SerializeToString()) lm_seq_ser = lm_seq.seq.SerializeToString() seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString() writer_lm.write(lm_seq_ser) writer_rev_lm.write(rev_lm_seq.seq.SerializeToString()) writer_seq_ae.write(seq_ae_seq_ser) writer_lm_all.write(lm_seq_ser) writer_seq_ae_all.write(seq_ae_seq_ser) # Close test writers writer_lm.close() writer_rev_lm.close() writer_seq_ae.close() writer_class.close() writer_bd_class.close() def main(_): tf.logging.set_verbosity(tf.logging.INFO) tf.logging.info('Assigning vocabulary ids...') vocab_ids = make_vocab_ids( FLAGS.vocab_file or os.path.join(FLAGS.output_dir, 'vocab.txt')) with build_shuffling_tf_record_writer(data.ALL_LM) as writer_lm_all: with build_shuffling_tf_record_writer(data.ALL_SA) as writer_seq_ae_all: tf.logging.info('Generating training data...') generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all) tf.logging.info('Generating test data...') generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all) if __name__ == '__main__': tf.app.run()