|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Input utils for virtual adversarial text classification.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import os |
|
|
|
|
|
|
|
import tensorflow as tf |
|
|
|
from data import data_utils |
|
|
|
|
|
class VatxtInput(object): |
|
"""Wrapper around NextQueuedSequenceBatch.""" |
|
|
|
def __init__(self, |
|
batch, |
|
state_name=None, |
|
tokens=None, |
|
num_states=0, |
|
eos_id=None): |
|
"""Construct VatxtInput. |
|
|
|
Args: |
|
batch: NextQueuedSequenceBatch. |
|
state_name: str, name of state to fetch and save. |
|
tokens: int Tensor, tokens. Defaults to batch's F_TOKEN_ID sequence. |
|
num_states: int The number of states to store. |
|
eos_id: int Id of end of Sequence. |
|
""" |
|
self._batch = batch |
|
self._state_name = state_name |
|
self._tokens = (tokens if tokens is not None else |
|
batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID]) |
|
self._num_states = num_states |
|
|
|
w = batch.sequences[data_utils.SequenceWrapper.F_WEIGHT] |
|
self._weights = w |
|
|
|
l = batch.sequences[data_utils.SequenceWrapper.F_LABEL] |
|
self._labels = l |
|
|
|
|
|
self._eos_weights = None |
|
if eos_id: |
|
ew = tf.cast(tf.equal(self._tokens, eos_id), tf.float32) |
|
self._eos_weights = ew |
|
|
|
@property |
|
def tokens(self): |
|
return self._tokens |
|
|
|
@property |
|
def weights(self): |
|
return self._weights |
|
|
|
@property |
|
def eos_weights(self): |
|
return self._eos_weights |
|
|
|
@property |
|
def labels(self): |
|
return self._labels |
|
|
|
@property |
|
def length(self): |
|
return self._batch.length |
|
|
|
@property |
|
def state_name(self): |
|
return self._state_name |
|
|
|
@property |
|
def state(self): |
|
|
|
state_names = _get_tuple_state_names(self._num_states, self._state_name) |
|
return tuple([ |
|
tf.contrib.rnn.LSTMStateTuple( |
|
self._batch.state(c_name), self._batch.state(h_name)) |
|
for c_name, h_name in state_names |
|
]) |
|
|
|
def save_state(self, value): |
|
|
|
state_names = _get_tuple_state_names(self._num_states, self._state_name) |
|
save_ops = [] |
|
for (c_state, h_state), (c_name, h_name) in zip(value, state_names): |
|
save_ops.append(self._batch.save_state(c_name, c_state)) |
|
save_ops.append(self._batch.save_state(h_name, h_state)) |
|
return tf.group(*save_ops) |
|
|
|
|
|
def _get_tuple_state_names(num_states, base_name): |
|
"""Returns state names for use with LSTM tuple state.""" |
|
state_names = [('{}_{}_c'.format(i, base_name), '{}_{}_h'.format( |
|
i, base_name)) for i in range(num_states)] |
|
return state_names |
|
|
|
|
|
def _split_bidir_tokens(batch): |
|
tokens = batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID] |
|
|
|
|
|
forward, reverse = [ |
|
tf.squeeze(t, axis=[2]) for t in tf.split(tokens, 2, axis=2) |
|
] |
|
return forward, reverse |
|
|
|
|
|
def _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq): |
|
"""Returns input filenames for configuration. |
|
|
|
Args: |
|
phase: str, 'train', 'test', or 'valid'. |
|
bidir: bool, bidirectional model. |
|
pretrain: bool, pretraining or classification. |
|
use_seq2seq: bool, seq2seq data, only valid if pretrain=True. |
|
|
|
Returns: |
|
Tuple of filenames. |
|
|
|
Raises: |
|
ValueError: if an invalid combination of arguments is provided that does not |
|
map to any data files (e.g. pretrain=False, use_seq2seq=True). |
|
""" |
|
data_spec = (phase, bidir, pretrain, use_seq2seq) |
|
data_specs = { |
|
('train', True, True, False): (data_utils.TRAIN_LM, |
|
data_utils.TRAIN_REV_LM), |
|
('train', True, False, False): (data_utils.TRAIN_BD_CLASS,), |
|
('train', False, True, False): (data_utils.TRAIN_LM,), |
|
('train', False, True, True): (data_utils.TRAIN_SA,), |
|
('train', False, False, False): (data_utils.TRAIN_CLASS,), |
|
('test', True, True, False): (data_utils.TEST_LM, |
|
data_utils.TRAIN_REV_LM), |
|
('test', True, False, False): (data_utils.TEST_BD_CLASS,), |
|
('test', False, True, False): (data_utils.TEST_LM,), |
|
('test', False, True, True): (data_utils.TEST_SA,), |
|
('test', False, False, False): (data_utils.TEST_CLASS,), |
|
('valid', True, False, False): (data_utils.VALID_BD_CLASS,), |
|
('valid', False, False, False): (data_utils.VALID_CLASS,), |
|
} |
|
if data_spec not in data_specs: |
|
raise ValueError( |
|
'Data specification (phase, bidir, pretrain, use_seq2seq) %s not ' |
|
'supported' % str(data_spec)) |
|
|
|
return data_specs[data_spec] |
|
|
|
|
|
def _read_single_sequence_example(file_list, tokens_shape=None): |
|
"""Reads and parses SequenceExamples from TFRecord-encoded file_list.""" |
|
tf.logging.info('Constructing TFRecordReader from files: %s', file_list) |
|
file_queue = tf.train.string_input_producer(file_list) |
|
reader = tf.TFRecordReader() |
|
seq_key, serialized_record = reader.read(file_queue) |
|
ctx, sequence = tf.parse_single_sequence_example( |
|
serialized_record, |
|
sequence_features={ |
|
data_utils.SequenceWrapper.F_TOKEN_ID: |
|
tf.FixedLenSequenceFeature(tokens_shape or [], dtype=tf.int64), |
|
data_utils.SequenceWrapper.F_LABEL: |
|
tf.FixedLenSequenceFeature([], dtype=tf.int64), |
|
data_utils.SequenceWrapper.F_WEIGHT: |
|
tf.FixedLenSequenceFeature([], dtype=tf.float32), |
|
}) |
|
return seq_key, ctx, sequence |
|
|
|
|
|
def _read_and_batch(data_dir, |
|
fname, |
|
state_name, |
|
state_size, |
|
num_layers, |
|
unroll_steps, |
|
batch_size, |
|
bidir_input=False): |
|
"""Inputs for text model. |
|
|
|
Args: |
|
data_dir: str, directory containing TFRecord files of SequenceExample. |
|
fname: str, input file name. |
|
state_name: string, key for saved state of LSTM. |
|
state_size: int, size of LSTM state. |
|
num_layers: int, the number of layers in the LSTM. |
|
unroll_steps: int, number of timesteps to unroll for TBTT. |
|
batch_size: int, batch size. |
|
bidir_input: bool, whether the input is bidirectional. If True, creates 2 |
|
states, state_name and state_name + '_reverse'. |
|
|
|
Returns: |
|
Instance of NextQueuedSequenceBatch |
|
|
|
Raises: |
|
ValueError: if file for input specification is not found. |
|
""" |
|
data_path = os.path.join(data_dir, fname) |
|
if not tf.gfile.Exists(data_path): |
|
raise ValueError('Failed to find file: %s' % data_path) |
|
|
|
tokens_shape = [2] if bidir_input else [] |
|
seq_key, ctx, sequence = _read_single_sequence_example( |
|
[data_path], tokens_shape=tokens_shape) |
|
|
|
state_names = _get_tuple_state_names(num_layers, state_name) |
|
initial_states = {} |
|
for c_state, h_state in state_names: |
|
initial_states[c_state] = tf.zeros(state_size) |
|
initial_states[h_state] = tf.zeros(state_size) |
|
if bidir_input: |
|
rev_state_names = _get_tuple_state_names(num_layers, |
|
'{}_reverse'.format(state_name)) |
|
for rev_c_state, rev_h_state in rev_state_names: |
|
initial_states[rev_c_state] = tf.zeros(state_size) |
|
initial_states[rev_h_state] = tf.zeros(state_size) |
|
batch = tf.contrib.training.batch_sequences_with_states( |
|
input_key=seq_key, |
|
input_sequences=sequence, |
|
input_context=ctx, |
|
input_length=tf.shape(sequence['token_id'])[0], |
|
initial_states=initial_states, |
|
num_unroll=unroll_steps, |
|
batch_size=batch_size, |
|
allow_small_batch=False, |
|
num_threads=4, |
|
capacity=batch_size * 10, |
|
make_keys_unique=True, |
|
make_keys_unique_seed=29392) |
|
return batch |
|
|
|
|
|
def inputs(data_dir=None, |
|
phase='train', |
|
bidir=False, |
|
pretrain=False, |
|
use_seq2seq=False, |
|
state_name='lstm', |
|
state_size=None, |
|
num_layers=0, |
|
batch_size=32, |
|
unroll_steps=100, |
|
eos_id=None): |
|
"""Inputs for text model. |
|
|
|
Args: |
|
data_dir: str, directory containing TFRecord files of SequenceExample. |
|
phase: str, dataset for evaluation {'train', 'valid', 'test'}. |
|
bidir: bool, bidirectional LSTM. |
|
pretrain: bool, whether to read pretraining data or classification data. |
|
use_seq2seq: bool, whether to read seq2seq data or the language model data. |
|
state_name: string, key for saved state of LSTM. |
|
state_size: int, size of LSTM state. |
|
num_layers: int, the number of LSTM layers. |
|
batch_size: int, batch size. |
|
unroll_steps: int, number of timesteps to unroll for TBTT. |
|
eos_id: int, id of end of sequence. used for the kl weights on vat |
|
Returns: |
|
Instance of VatxtInput (x2 if bidir=True and pretrain=True, i.e. forward and |
|
reverse). |
|
""" |
|
with tf.name_scope('inputs'): |
|
filenames = _filenames_for_data_spec(phase, bidir, pretrain, use_seq2seq) |
|
|
|
if bidir and pretrain: |
|
|
|
|
|
forward_fname, reverse_fname = filenames |
|
forward_batch = _read_and_batch(data_dir, forward_fname, state_name, |
|
state_size, num_layers, unroll_steps, |
|
batch_size) |
|
state_name_rev = state_name + '_reverse' |
|
reverse_batch = _read_and_batch(data_dir, reverse_fname, state_name_rev, |
|
state_size, num_layers, unroll_steps, |
|
batch_size) |
|
forward_input = VatxtInput( |
|
forward_batch, |
|
state_name=state_name, |
|
num_states=num_layers, |
|
eos_id=eos_id) |
|
reverse_input = VatxtInput( |
|
reverse_batch, |
|
state_name=state_name_rev, |
|
num_states=num_layers, |
|
eos_id=eos_id) |
|
return forward_input, reverse_input |
|
|
|
elif bidir: |
|
|
|
|
|
fname, = filenames |
|
batch = _read_and_batch( |
|
data_dir, |
|
fname, |
|
state_name, |
|
state_size, |
|
num_layers, |
|
unroll_steps, |
|
batch_size, |
|
bidir_input=True) |
|
forward_tokens, reverse_tokens = _split_bidir_tokens(batch) |
|
forward_input = VatxtInput( |
|
batch, |
|
state_name=state_name, |
|
tokens=forward_tokens, |
|
num_states=num_layers) |
|
reverse_input = VatxtInput( |
|
batch, |
|
state_name=state_name + '_reverse', |
|
tokens=reverse_tokens, |
|
num_states=num_layers) |
|
return forward_input, reverse_input |
|
else: |
|
|
|
fname, = filenames |
|
batch = _read_and_batch( |
|
data_dir, |
|
fname, |
|
state_name, |
|
state_size, |
|
num_layers, |
|
unroll_steps, |
|
batch_size, |
|
bidir_input=False) |
|
return VatxtInput( |
|
batch, state_name=state_name, num_states=num_layers, eos_id=eos_id) |
|
|