# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input utils for virtual adversarial text classification."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

import tensorflow as tf

from data import data_utils


class VatxtInput(object):
  """Wrapper around NextQueuedSequenceBatch."""

  def __init__(self,
               batch,
               state_name=None,
               tokens=None,
               num_states=0,
               eos_id=None):
    """Construct VatxtInput.

    Args:
      batch: NextQueuedSequenceBatch.
      state_name: str, name of the state to fetch and save.
      tokens: int Tensor, tokens. Defaults to batch's F_TOKEN_ID sequence.
      num_states: int, the number of LSTM states to store.
      eos_id: int, id of the end-of-sequence token.
    """
    self._batch = batch
    self._state_name = state_name
    self._tokens = (tokens if tokens is not None else
                    batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID])
    self._num_states = num_states
    self._weights = batch.sequences[data_utils.SequenceWrapper.F_WEIGHT]
    self._labels = batch.sequences[data_utils.SequenceWrapper.F_LABEL]

    # Weights marking the end-of-sequence positions.
    self._eos_weights = None
    if eos_id:
      self._eos_weights = tf.cast(tf.equal(self._tokens, eos_id), tf.float32)

  @property
  def tokens(self):
    return self._tokens

  @property
  def weights(self):
    return self._weights

  @property
  def eos_weights(self):
    return self._eos_weights

  @property
  def labels(self):
    return self._labels

  @property
  def length(self):
    return self._batch.length

  @property
  def state_name(self):
    return self._state_name

  @property
  def state(self):
    # LSTM tuple states, one (c, h) pair per layer.
    state_names = _get_tuple_state_names(self._num_states, self._state_name)
    return tuple([
        tf.contrib.rnn.LSTMStateTuple(
            self._batch.state(c_name), self._batch.state(h_name))
        for c_name, h_name in state_names
    ])

  def save_state(self, value):
    # LSTM tuple states, one (c, h) pair per layer.
    state_names = _get_tuple_state_names(self._num_states, self._state_name)
    save_ops = []
    for (c_state, h_state), (c_name, h_name) in zip(value, state_names):
      save_ops.append(self._batch.save_state(c_name, c_state))
      save_ops.append(self._batch.save_state(h_name, h_state))
    return tf.group(*save_ops)
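
# A minimal usage sketch for VatxtInput (illustrative only, not part of this
# module). `batch` is assumed to be a NextQueuedSequenceBatch such as the one
# produced by _read_and_batch() below; `run_lstm` and `sequence_loss` are
# hypothetical stand-ins for the model's LSTM unroll and loss functions:
#
#   inp = VatxtInput(batch, state_name='lstm', num_states=2, eos_id=2)
#   initial_state = inp.state  # Tuple of num_states LSTMStateTuples.
#   logits, final_state = run_lstm(inp.tokens, initial_state)
#   # Persist the final state so the next unrolled segment resumes from it.
#   with tf.control_dependencies([inp.save_state(final_state)]):
#     loss = tf.identity(sequence_loss(logits, inp.labels, inp.weights))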


def _get_tuple_state_names(num_states, base_name):
  """Returns (c, h) state names for use with LSTM tuple state."""
  state_names = [('{}_{}_c'.format(i, base_name),
                  '{}_{}_h'.format(i, base_name)) for i in range(num_states)]
  return state_names


def _split_bidir_tokens(batch):
  """Splits bidirectional tokens into forward and reverse streams."""
  tokens = batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID]
  # Tokens have shape [batch, time, 2]; forward and reverse each have shape
  # [batch, time].
  forward, reverse = [
      tf.squeeze(t, axis=[2]) for t in tf.split(tokens, 2, axis=2)
  ]
  return forward, reverse


def _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq):
  """Returns input filenames for a data configuration.

  Args:
    phase: str, 'train', 'test', or 'valid'.
    bidir: bool, bidirectional model.
    pretrain: bool, pretraining or classification.
    use_seq2seq: bool, seq2seq data; only valid if pretrain=True.

  Returns:
    Tuple of filenames.

  Raises:
    ValueError: if an invalid combination of arguments is provided that does
      not map to any data files (e.g. pretrain=False, use_seq2seq=True).
  """
  data_spec = (phase, bidir, pretrain, use_seq2seq)
  data_specs = {
      ('train', True, True, False): (data_utils.TRAIN_LM,
                                     data_utils.TRAIN_REV_LM),
      ('train', True, False, False): (data_utils.TRAIN_BD_CLASS,),
      ('train', False, True, False): (data_utils.TRAIN_LM,),
      ('train', False, True, True): (data_utils.TRAIN_SA,),
      ('train', False, False, False): (data_utils.TRAIN_CLASS,),
      ('test', True, True, False): (data_utils.TEST_LM,
                                    data_utils.TRAIN_REV_LM),
      ('test', True, False, False): (data_utils.TEST_BD_CLASS,),
      ('test', False, True, False): (data_utils.TEST_LM,),
      ('test', False, True, True): (data_utils.TEST_SA,),
      ('test', False, False, False): (data_utils.TEST_CLASS,),
      ('valid', True, False, False): (data_utils.VALID_BD_CLASS,),
      ('valid', False, False, False): (data_utils.VALID_CLASS,),
  }
  if data_spec not in data_specs:
    raise ValueError(
        'Data specification (phase, bidir, pretrain, use_seq2seq) %s not '
        'supported' % str(data_spec))
  return data_specs[data_spec]


def _read_single_sequence_example(file_list, tokens_shape=None):
  """Reads and parses SequenceExamples from TFRecord-encoded file_list."""
  tf.logging.info('Constructing TFRecordReader from files: %s', file_list)
  file_queue = tf.train.string_input_producer(file_list)
  reader = tf.TFRecordReader()
  seq_key, serialized_record = reader.read(file_queue)
  ctx, sequence = tf.parse_single_sequence_example(
      serialized_record,
      sequence_features={
          data_utils.SequenceWrapper.F_TOKEN_ID:
              tf.FixedLenSequenceFeature(tokens_shape or [], dtype=tf.int64),
          data_utils.SequenceWrapper.F_LABEL:
              tf.FixedLenSequenceFeature([], dtype=tf.int64),
          data_utils.SequenceWrapper.F_WEIGHT:
              tf.FixedLenSequenceFeature([], dtype=tf.float32),
      })
  return seq_key, ctx, sequence
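
# Shape sketch for the parsed tensors above (the file path is hypothetical):
# with the default tokens_shape=[], each entry of `sequence` is a dense
# [time]-shaped Tensor; with tokens_shape=[2] (bidirectional input), token ids
# come back as [time, 2]:
#
#   _, _, sequence = _read_single_sequence_example(
#       ['/tmp/imdb/train_lm.tfrecords'])
#   token_ids = sequence[data_utils.SequenceWrapper.F_TOKEN_ID]  # [time] int64
#   weights = sequence[data_utils.SequenceWrapper.F_WEIGHT]  # [time] float32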


def _read_and_batch(data_dir,
                    fname,
                    state_name,
                    state_size,
                    num_layers,
                    unroll_steps,
                    batch_size,
                    bidir_input=False):
  """Reads input data from fname and batches it with stateful queueing.

  Args:
    data_dir: str, directory containing TFRecord files of SequenceExample.
    fname: str, input file name.
    state_name: string, key for saved state of LSTM.
    state_size: int, size of LSTM state.
    num_layers: int, the number of layers in the LSTM.
    unroll_steps: int, number of timesteps to unroll for truncated
      backpropagation through time (TBPTT).
    batch_size: int, batch size.
    bidir_input: bool, whether the input is bidirectional. If True, creates 2
      states, state_name and state_name + '_reverse'.

  Returns:
    Instance of NextQueuedSequenceBatch.

  Raises:
    ValueError: if the file for the input specification is not found.
  """
  data_path = os.path.join(data_dir, fname)
  if not tf.gfile.Exists(data_path):
    raise ValueError('Failed to find file: %s' % data_path)

  tokens_shape = [2] if bidir_input else []
  seq_key, ctx, sequence = _read_single_sequence_example(
      [data_path], tokens_shape=tokens_shape)

  # Set up stateful queue reader.
  state_names = _get_tuple_state_names(num_layers, state_name)
  initial_states = {}
  for c_state, h_state in state_names:
    initial_states[c_state] = tf.zeros(state_size)
    initial_states[h_state] = tf.zeros(state_size)
  if bidir_input:
    rev_state_names = _get_tuple_state_names(num_layers,
                                             '{}_reverse'.format(state_name))
    for rev_c_state, rev_h_state in rev_state_names:
      initial_states[rev_c_state] = tf.zeros(state_size)
      initial_states[rev_h_state] = tf.zeros(state_size)
  batch = tf.contrib.training.batch_sequences_with_states(
      input_key=seq_key,
      input_sequences=sequence,
      input_context=ctx,
      input_length=tf.shape(
          sequence[data_utils.SequenceWrapper.F_TOKEN_ID])[0],
      initial_states=initial_states,
      num_unroll=unroll_steps,
      batch_size=batch_size,
      allow_small_batch=False,
      num_threads=4,
      capacity=batch_size * 10,
      make_keys_unique=True,
      make_keys_unique_seed=29392)
  return batch


def inputs(data_dir=None,
           phase='train',
           bidir=False,
           pretrain=False,
           use_seq2seq=False,
           state_name='lstm',
           state_size=None,
           num_layers=0,
           batch_size=32,
           unroll_steps=100,
           eos_id=None):
  """Inputs for text model.

  Args:
    data_dir: str, directory containing TFRecord files of SequenceExample.
    phase: str, dataset for evaluation {'train', 'valid', 'test'}.
    bidir: bool, bidirectional LSTM.
    pretrain: bool, whether to read pretraining data or classification data.
    use_seq2seq: bool, whether to read seq2seq data or the language model data.
    state_name: string, key for saved state of LSTM.
    state_size: int, size of LSTM state.
    num_layers: int, the number of LSTM layers.
    batch_size: int, batch size.
    unroll_steps: int, number of timesteps to unroll for truncated
      backpropagation through time (TBPTT).
    eos_id: int, id of the end-of-sequence token; used for the KL weights in
      VAT.

  Returns:
    Instance of VatxtInput (x2 if bidir=True and pretrain=True, i.e. forward
    and reverse).
  """
  with tf.name_scope('inputs'):
    filenames = _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq)

    if bidir and pretrain:
      # Bidirectional pretraining.
      # Requires separate forward and reverse language model data.
      forward_fname, reverse_fname = filenames
      forward_batch = _read_and_batch(data_dir, forward_fname, state_name,
                                      state_size, num_layers, unroll_steps,
                                      batch_size)
      state_name_rev = state_name + '_reverse'
      reverse_batch = _read_and_batch(data_dir, reverse_fname, state_name_rev,
                                      state_size, num_layers, unroll_steps,
                                      batch_size)
      forward_input = VatxtInput(
          forward_batch,
          state_name=state_name,
          num_states=num_layers,
          eos_id=eos_id)
      reverse_input = VatxtInput(
          reverse_batch,
          state_name=state_name_rev,
          num_states=num_layers,
          eos_id=eos_id)
      return forward_input, reverse_input
    elif bidir:
      # Classifier bidirectional LSTM.
      # Shared data source, but separate token/state streams.
      fname, = filenames
      batch = _read_and_batch(
          data_dir,
          fname,
          state_name,
          state_size,
          num_layers,
          unroll_steps,
          batch_size,
          bidir_input=True)
      forward_tokens, reverse_tokens = _split_bidir_tokens(batch)
      forward_input = VatxtInput(
          batch,
          state_name=state_name,
          tokens=forward_tokens,
          num_states=num_layers)
      reverse_input = VatxtInput(
          batch,
          state_name=state_name + '_reverse',
          tokens=reverse_tokens,
          num_states=num_layers)
      return forward_input, reverse_input
    else:
      # Unidirectional LM or classifier.
      fname, = filenames
      batch = _read_and_batch(
          data_dir,
          fname,
          state_name,
          state_size,
          num_layers,
          unroll_steps,
          batch_size,
          bidir_input=False)
      return VatxtInput(
          batch, state_name=state_name, num_states=num_layers, eos_id=eos_id)
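
# End-to-end sketch of the public entry point (illustrative; the data
# directory, state size, and eos id below are assumptions for the example):
#
#   train_input = inputs(
#       data_dir='/tmp/imdb_tfrecords',  # Hypothetical TFRecord directory.
#       phase='train',
#       pretrain=True,  # Reads TRAIN_LM language model data.
#       state_size=1024,
#       num_layers=1,
#       batch_size=64,
#       unroll_steps=100,
#       eos_id=2)  # Assumed end-of-sequence token id.
#   # train_input.tokens has shape [batch_size, unroll_steps]; state and
#   # save_state() carry LSTM state across successive unrolled segments.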