|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Virtual adversarial text models.""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import csv |
|
import os |
|
|
|
|
|
|
|
import tensorflow as tf |
|
|
|
import adversarial_losses as adv_lib |
|
import inputs as inputs_lib |
|
import layers as layers_lib |
|
|
|
flags = tf.app.flags
FLAGS = flags.FLAGS

# NOTE: FLAGS.adv_training_method, FLAGS.adv_reg_coeff and FLAGS.task are read
# further down in this file but are not defined here; another module has to
# register them before the graphs are built.

# --- Classifier ---
flags.DEFINE_integer(
    'num_classes', 2, 'Number of classes for classification')

# --- Data ---
flags.DEFINE_string(
    'data_dir', '/tmp/IMDB', 'Directory path to preprocessed text dataset.')
flags.DEFINE_string(
    'vocab_freq_path', None,
    'Path to pre-calculated vocab frequency data. If '
    'None, use FLAGS.data_dir/vocab_freq.txt.')
flags.DEFINE_integer(
    'batch_size', 64, 'Size of the batch.')
flags.DEFINE_integer(
    'num_timesteps', 100, 'Number of timesteps for BPTT')

# --- Model architecture ---
flags.DEFINE_bool(
    'bidir_lstm', False, 'Whether to build a bidirectional LSTM.')
flags.DEFINE_bool(
    'single_label', True,
    'Whether the sequence has a single '
    'label, for optimization.')
flags.DEFINE_integer(
    'rnn_num_layers', 1, 'Number of LSTM layers.')
flags.DEFINE_integer(
    'rnn_cell_size', 512, 'Number of hidden units in the LSTM.')
flags.DEFINE_integer(
    'cl_num_layers', 1, 'Number of hidden layers of classification model.')
flags.DEFINE_integer(
    'cl_hidden_size', 30, 'Number of hidden units in classification layer.')
flags.DEFINE_integer(
    'num_candidate_samples', -1,
    'Num samples used in the sampled output layer.')
flags.DEFINE_bool(
    'use_seq2seq_autoencoder', False,
    'If True, seq2seq auto-encoder is used to pretrain. '
    'If False, standard language model is used.')

# --- Vocabulary and embedding ---
flags.DEFINE_integer(
    'embedding_dims', 256, 'Dimensions of embedded vector.')
# The help text below is reproduced verbatim from the original definition.
flags.DEFINE_integer(
    'vocab_size', 86934,
    'The size of the vocaburary. This value '
    'should be exactly same as the number of the '
    'vocabulary used in dataset. Because the last '
    'indexed vocabulary of the dataset preprocessed by '
    'my preprocessed code, is always <eos> and here we '
    'specify the <eos> with the the index.')
flags.DEFINE_bool(
    'normalize_embeddings', True,
    'Normalize word embeddings by vocab frequency')

# --- Optimization ---
flags.DEFINE_float(
    'learning_rate', 0.001, 'Learning rate while fine-tuning.')
flags.DEFINE_float(
    'learning_rate_decay_factor', 1.0, 'Learning rate decay factor')
flags.DEFINE_boolean(
    'sync_replicas', False, 'sync_replica or not')
flags.DEFINE_integer(
    'replicas_to_aggregate', 1, 'The number of replicas to aggregate')

# --- Regularization ---
flags.DEFINE_float(
    'max_grad_norm', 1.0, 'Clip the global gradient norm to this value.')
flags.DEFINE_float(
    'keep_prob_emb', 1.0,
    'keep probability on embedding layer. '
    '0.5 is optimal on IMDB with virtual adversarial training.')
flags.DEFINE_float(
    'keep_prob_lstm_out', 1.0, 'keep probability on lstm output.')
flags.DEFINE_float(
    'keep_prob_cl_hidden', 1.0,
    'keep probability on classification hidden layer')
|
|
|
|
|
def get_model():
  """Instantiates the model variant selected by FLAGS.bidir_lstm.

  Returns:
    A VatxtBidirModel when FLAGS.bidir_lstm is truthy, otherwise a VatxtModel.
  """
  model_cls = VatxtBidirModel if FLAGS.bidir_lstm else VatxtModel
  return model_cls()
|
|
|
|
|
class VatxtModel(object):
  """Builds the training and evaluation graphs of the text model.

  Entry points: `classifier_training()`, `language_model_training()` and
  `eval_graph()`.

  Variable sharing is central to this class: the language model and the
  classifier share variables, and the adversarial losses re-run parts of the
  classifier on perturbed embeddings. To make that sharing reliable, every
  stateful layer is created exactly once, in `__init__`, as a callable object
  stored in `self.layers`; the graph-building methods only ever call those
  stored instances.
  """

  def __init__(self, cl_logits_input_dim=None):
    """Creates the global step, the caches and the shared layers.

    Args:
      cl_logits_input_dim: int or None, input width handed to the
        classification logits subgraph. Any falsy value selects
        FLAGS.rnn_cell_size.
    """
    self.global_step = tf.train.get_or_create_global_step()
    self.vocab_freqs = _get_vocab_freqs()

    # VatxtInput objects, filled in by the graph-building methods.
    self.cl_inputs = None
    self.lm_inputs = None

    # Intermediate tensors that later graph pieces look up by key.
    self.tensors = {}

    if not cl_logits_input_dim:
      cl_logits_input_dim = FLAGS.rnn_cell_size
    cl_hidden_sizes = [FLAGS.cl_hidden_size] * FLAGS.cl_num_layers

    # Entries are evaluated top to bottom, so the layers are constructed in the
    # order embedding, lstm, lm_loss, cl_logits.
    self.layers = {
        'embedding':
            layers_lib.Embedding(
                FLAGS.vocab_size, FLAGS.embedding_dims,
                FLAGS.normalize_embeddings, self.vocab_freqs,
                FLAGS.keep_prob_emb),
        'lstm':
            layers_lib.LSTM(
                FLAGS.rnn_cell_size, FLAGS.rnn_num_layers,
                FLAGS.keep_prob_lstm_out),
        'lm_loss':
            layers_lib.SoftmaxLoss(
                FLAGS.vocab_size,
                FLAGS.num_candidate_samples,
                self.vocab_freqs,
                name='LM_loss'),
        'cl_logits':
            layers_lib.cl_logits_subgraph(
                cl_hidden_sizes, cl_logits_input_dim, FLAGS.num_classes,
                FLAGS.keep_prob_cl_hidden),
    }

  @property
  def pretrained_variables(self):
    """Trainable weights of the embedding layer followed by the LSTM's."""
    embedding_vars = self.layers['embedding'].trainable_weights
    lstm_vars = self.layers['lstm'].trainable_weights
    return embedding_vars + lstm_vars

  def classifier_training(self):
    """Builds the classifier graph and its training op.

    Returns:
      (train_op, loss, global_step) tuple.
    """
    loss = self.classifier_graph()
    return optimize(loss, self.global_step), loss, self.global_step

  def language_model_training(self):
    """Builds the language-model graph and its training op.

    Returns:
      (train_op, loss, global_step) tuple.
    """
    loss = self.language_model_graph()
    return optimize(loss, self.global_step), loss, self.global_step

  @staticmethod
  def _last_step_indices(inputs):
    """Pairs every batch row with the index `inputs.length - 1`.

    Args:
      inputs: VatxtInput; only `length` is read.

    Returns:
      Tensor stacking `tf.range(FLAGS.batch_size)` and `inputs.length - 1`
      along axis 1, suitable as `tf.gather_nd` indices.
    """
    return tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)

  @staticmethod
  def _select_steps(values, indices):
    """Gathers `values` at `indices` and re-inserts a length-1 axis 1."""
    return tf.expand_dims(tf.gather_nd(values, indices), 1)

  def _classification_targets(self, inputs):
    """Returns the (labels, weights) pair used for accuracy computation.

    With FLAGS.single_label set, only the entry at position `length - 1` of
    every sequence is kept; otherwise the tensors are returned untouched.
    """
    if not FLAGS.single_label:
      return inputs.labels, inputs.weights
    indices = self._last_step_indices(inputs)
    labels = self._select_steps(inputs.labels, indices)
    weights = self._select_steps(inputs.weights, indices)
    return labels, weights

  def classifier_graph(self):
    """Builds the classifier graph, from inputs through to the total loss.

    Side effects:
      * stores the VatxtInput object in `self.cl_inputs`
      * stores tensors under `cl_embedded`, `cl_logits` and `cl_loss`
      * emits scalar summaries for the classification loss, accuracy,
        adversarial loss and total loss

    Returns:
      loss: scalar float, classification loss plus the adversarial loss
        scaled by FLAGS.adv_reg_coeff.
    """
    inputs = _inputs('train', pretrain=False)
    self.cl_inputs = inputs
    embedded = self.layers['embedding'](inputs.tokens)
    self.tensors['cl_embedded'] = embedded

    _, next_state, logits, cl_loss = self.cl_loss_from_embedding(
        embedded, return_intermediates=True)
    tf.summary.scalar('classification_loss', cl_loss)
    self.tensors['cl_logits'] = logits
    self.tensors['cl_loss'] = cl_loss

    labels, weights = self._classification_targets(inputs)
    accuracy = layers_lib.accuracy(logits, labels, weights)
    tf.summary.scalar('accuracy', accuracy)

    adv_loss = self.adversarial_loss() * tf.constant(
        FLAGS.adv_reg_coeff, name='adv_reg_coeff')
    tf.summary.scalar('adversarial_loss', adv_loss)

    total_loss = cl_loss + adv_loss

    # Evaluating the returned loss also runs the LSTM state-saving op.
    with tf.control_dependencies([inputs.save_state(next_state)]):
      total_loss = tf.identity(total_loss)
      tf.summary.scalar('total_classification_loss', total_loss)
    return total_loss

  def language_model_graph(self, compute_loss=True):
    """Builds the LM graph, from inputs through to the LM loss.

    Side effects:
      * stores the VatxtInput object in `self.lm_inputs`
      * stores the embedded tokens under `lm_embedded`

    Args:
      compute_loss: bool, whether to compute and return the loss or stop after
        the LSTM computation.

    Returns:
      loss: scalar float, or None when `compute_loss` is False.
    """
    lm_inputs = _inputs('train', pretrain=True)
    self.lm_inputs = lm_inputs
    return self._lm_loss(lm_inputs, compute_loss=compute_loss)

  def _lm_loss(self,
               inputs,
               emb_key='lm_embedded',
               lstm_layer='lstm',
               lm_loss_layer='lm_loss',
               loss_name='lm_loss',
               compute_loss=True):
    """Embeds `inputs`, runs one LSTM over them and optionally the LM loss.

    Args:
      inputs: VatxtInput providing tokens, state, length, labels and weights.
      emb_key: key in `self.tensors` under which the embedding is stored.
      lstm_layer: key in `self.layers` of the LSTM to run.
      lm_loss_layer: key in `self.layers` of the loss layer to apply.
      loss_name: name of the scalar summary emitted for the loss.
      compute_loss: bool, when False only the embedding and LSTM are built.

    Returns:
      loss: scalar float tied to the state-saving op, or None when
        `compute_loss` is False.
    """
    embedded = self.layers['embedding'](inputs.tokens)
    self.tensors[emb_key] = embedded
    lstm_out, next_state = self.layers[lstm_layer](
        embedded, inputs.state, inputs.length)

    if not compute_loss:
      return None

    loss = self.layers[lm_loss_layer](
        [lstm_out, inputs.labels, inputs.weights])
    with tf.control_dependencies([inputs.save_state(next_state)]):
      loss = tf.identity(loss)
      tf.summary.scalar(loss_name, loss)
    return loss

  def eval_graph(self, dataset='test'):
    """Builds the classifier evaluation graph.

    Args:
      dataset: the labeled dataset to evaluate, {'train', 'test', 'valid'}.

    Returns:
      eval_ops: dict<metric name, tuple(value, update_op)>
      var_restore_dict: dict mapping variable restoration names to variables.
        Trainable variables will be mapped to their moving average names.
    """
    inputs = _inputs(dataset, pretrain=False)
    embedded = self.layers['embedding'](inputs.tokens)
    _, next_state, logits, _ = self.cl_loss_from_embedding(
        embedded, inputs=inputs, return_intermediates=True)

    labels, weights = self._classification_targets(inputs)
    acc, acc_update = tf.contrib.metrics.streaming_accuracy(
        layers_lib.predictions(logits), labels, weights)

    # Running the metric update also runs the LSTM state-saving op.
    with tf.control_dependencies([inputs.save_state(next_state)]):
      acc_update = tf.identity(acc_update)
    eval_ops = {'accuracy': (acc, acc_update)}

    return eval_ops, make_restore_average_vars_dict()

  def cl_loss_from_embedding(self,
                             embedded,
                             inputs=None,
                             return_intermediates=False):
    """Computes the classification loss starting from embedded tokens.

    Args:
      embedded: 3-D float Tensor [batch_size, num_timesteps, embedding_dim]
      inputs: VatxtInput, defaults to self.cl_inputs.
      return_intermediates: bool, whether to return intermediate tensors or only
        the final loss.

    Returns:
      If return_intermediates is True:
        lstm_out, next_state, logits, loss
      Else:
        loss
    """
    if inputs is None:
      inputs = self.cl_inputs

    lstm_out, next_state = self.layers['lstm'](
        embedded, inputs.state, inputs.length)

    if FLAGS.single_label:
      # Keep only the entry at position `length - 1` of every sequence.
      indices = self._last_step_indices(inputs)
      lstm_out = self._select_steps(lstm_out, indices)
      labels = self._select_steps(inputs.labels, indices)
      weights = self._select_steps(inputs.weights, indices)
    else:
      labels, weights = inputs.labels, inputs.weights

    logits = self.layers['cl_logits'](lstm_out)
    loss = layers_lib.classification_loss(logits, labels, weights)

    if not return_intermediates:
      return loss
    return lstm_out, next_state, logits, loss

  def adversarial_loss(self):
    """Computes the adversarial loss chosen by FLAGS.adv_training_method."""

    def rp_loss():
      return adv_lib.random_perturbation_loss(
          self.tensors['cl_embedded'], self.cl_inputs.length,
          self.cl_loss_from_embedding)

    def at_loss():
      return adv_lib.adversarial_loss(
          self.tensors['cl_embedded'], self.tensors['cl_loss'],
          self.cl_loss_from_embedding)

    def vat_loss():
      """Computes the virtual adversarial loss on the LM inputs.

      Builds the language model graph (without its loss) first if it has not
      been constructed yet, and ties the returned loss to the op that saves
      the LM input states for LSTM state-saving BPTT.

      Returns:
        loss: float scalar.
      """
      if self.lm_inputs is None:
        self.language_model_graph(compute_loss=False)

      def logits_from_embedding(embedded, return_next_state=False):
        _, next_state, logits, _ = self.cl_loss_from_embedding(
            embedded, inputs=self.lm_inputs, return_intermediates=True)
        if return_next_state:
          return next_state, logits
        return logits

      lm_embedded = self.tensors['lm_embedded']
      next_state, lm_cl_logits = logits_from_embedding(
          lm_embedded, return_next_state=True)

      va_loss = adv_lib.virtual_adversarial_loss(
          lm_cl_logits, lm_embedded, self.lm_inputs, logits_from_embedding)

      with tf.control_dependencies([self.lm_inputs.save_state(next_state)]):
        return tf.identity(va_loss)

    def atvat_loss():
      return at_loss() + vat_loss()

    def no_loss():
      return tf.constant(0.)

    loss_builders = {
        'rp': rp_loss,  # random perturbation
        'at': at_loss,  # adversarial
        'vat': vat_loss,  # virtual adversarial
        'atvat': atvat_loss,  # adversarial + virtual adversarial
        '': no_loss,
        None: no_loss,
    }

    with tf.name_scope('adversarial_loss'):
      return loss_builders[FLAGS.adv_training_method]()
|
|
|
|
|
class VatxtBidirModel(VatxtModel):
  """VatxtModel variant that runs a forward and a reverse LSTM.

  Inputs, embeddings and LSTM states are handled as (forward, reverse) pairs.

  NOTE(review): unlike VatxtModel, none of the methods here consult
  FLAGS.single_label -- confirm that is intended.
  """

  def __init__(self):
    """Creates the base layers plus the reverse LSTM and reverse LM loss."""
    super(VatxtBidirModel, self).__init__(
        cl_logits_input_dim=FLAGS.rnn_cell_size * 2)

    reverse_lstm = layers_lib.LSTM(
        FLAGS.rnn_cell_size,
        FLAGS.rnn_num_layers,
        FLAGS.keep_prob_lstm_out,
        name='LSTM_Reverse')
    self.layers['lstm_reverse'] = reverse_lstm

    reverse_lm_loss = layers_lib.SoftmaxLoss(
        FLAGS.vocab_size,
        FLAGS.num_candidate_samples,
        self.vocab_freqs,
        name='LM_loss_reverse')
    self.layers['lm_loss_reverse'] = reverse_lm_loss

  @property
  def pretrained_variables(self):
    """Base class pretrained variables plus the reverse LSTM's weights."""
    pretrained = super(VatxtBidirModel, self).pretrained_variables
    pretrained.extend(self.layers['lstm_reverse'].trainable_weights)
    return pretrained

  def _embed_all(self, inputs):
    """Applies the shared embedding layer to every input's tokens."""
    return [self.layers['embedding'](inp.tokens) for inp in inputs]

  @staticmethod
  def _save_state_ops(inputs, next_states):
    """Pairs each input with its next LSTM state and builds the save ops."""
    return [inp.save_state(state) for inp, state in zip(inputs, next_states)]

  def classifier_graph(self):
    """Builds the classifier graph, from inputs through to the total loss.

    Side effects:
      * stores the pair of VatxtInput objects in `self.cl_inputs`
      * stores tensors under `cl_embedded` (list of forward and reverse),
        `cl_logits` and `cl_loss`
      * emits scalar summaries for the classification loss, accuracy,
        adversarial loss and total loss

    Returns:
      loss: scalar float, classification loss plus the adversarial loss
        scaled by FLAGS.adv_reg_coeff.
    """
    inputs = _inputs('train', pretrain=False, bidir=True)
    self.cl_inputs = inputs
    forward_inputs, _ = inputs

    # Both directions go through the same embedding layer.
    embedded = self._embed_all(inputs)
    self.tensors['cl_embedded'] = embedded

    _, next_states, logits, cl_loss = self.cl_loss_from_embedding(
        embedded, return_intermediates=True)
    tf.summary.scalar('classification_loss', cl_loss)
    self.tensors['cl_logits'] = logits
    self.tensors['cl_loss'] = cl_loss

    # Labels and weights are taken from the forward inputs only.
    accuracy = layers_lib.accuracy(logits, forward_inputs.labels,
                                   forward_inputs.weights)
    tf.summary.scalar('accuracy', accuracy)

    adv_loss = self.adversarial_loss() * tf.constant(
        FLAGS.adv_reg_coeff, name='adv_reg_coeff')
    tf.summary.scalar('adversarial_loss', adv_loss)

    total_loss = cl_loss + adv_loss

    # Evaluating the returned loss also saves both LSTM states.
    with tf.control_dependencies(self._save_state_ops(inputs, next_states)):
      total_loss = tf.identity(total_loss)
      tf.summary.scalar('total_classification_loss', total_loss)
    return total_loss

  def language_model_graph(self, compute_loss=True):
    """Builds the forward and reverse LM graphs.

    Side effects:
      * stores the pair of VatxtInput objects in `self.lm_inputs`
      * stores tensors under `lm_embedded` and `lm_embedded_reverse`

    Args:
      compute_loss: bool, whether to compute and return the loss or stop after
        the LSTM computation.

    Returns:
      loss: scalar float, sum of forward and reverse losses, or None when
        `compute_loss` is False.
    """
    inputs = _inputs('train', pretrain=True, bidir=True)
    self.lm_inputs = inputs
    forward_inputs, reverse_inputs = inputs

    forward_loss = self._lm_loss(forward_inputs, compute_loss=compute_loss)
    reverse_loss = self._lm_loss(
        reverse_inputs,
        emb_key='lm_embedded_reverse',
        lstm_layer='lstm_reverse',
        lm_loss_layer='lm_loss_reverse',
        loss_name='lm_loss_reverse',
        compute_loss=compute_loss)

    if not compute_loss:
      return None
    return forward_loss + reverse_loss

  def eval_graph(self, dataset='test'):
    """Builds the classifier evaluation graph.

    Args:
      dataset: the labeled dataset to evaluate, {'train', 'test', 'valid'}.

    Returns:
      eval_ops: dict<metric name, tuple(value, update_op)>
      var_restore_dict: dict mapping variable restoration names to variables.
        Trainable variables will be mapped to their moving average names.
    """
    inputs = _inputs(dataset, pretrain=False, bidir=True)
    embedded = self._embed_all(inputs)
    _, next_states, logits, _ = self.cl_loss_from_embedding(
        embedded, inputs=inputs, return_intermediates=True)
    forward_inputs, _ = inputs

    acc, acc_update = tf.contrib.metrics.streaming_accuracy(
        layers_lib.predictions(logits), forward_inputs.labels,
        forward_inputs.weights)

    # Running the metric update also saves both LSTM states.
    with tf.control_dependencies(self._save_state_ops(inputs, next_states)):
      acc_update = tf.identity(acc_update)
    eval_ops = {'accuracy': (acc, acc_update)}

    return eval_ops, make_restore_average_vars_dict()

  def cl_loss_from_embedding(self,
                             embedded,
                             inputs=None,
                             return_intermediates=False):
    """Computes the classification loss starting from embedded tokens.

    Args:
      embedded: Length 2 tuple of 3-D float Tensor
        [batch_size, num_timesteps, embedding_dim].
      inputs: Length 2 tuple of VatxtInput, defaults to self.cl_inputs.
      return_intermediates: bool, whether to return intermediate tensors or only
        the final loss.

    Returns:
      If return_intermediates is True:
        lstm_out, next_states, logits, loss
      Else:
        loss
    """
    if inputs is None:
      inputs = self.cl_inputs

    # Forward embeddings feed 'lstm', reverse embeddings feed 'lstm_reverse'.
    per_direction = [
        self.layers[layer_name](emb, inp.state, inp.length)
        for layer_name, emb, inp in zip(['lstm', 'lstm_reverse'], embedded,
                                        inputs)
    ]
    lstm_outs, next_states = zip(*per_direction)

    # Join the forward and reverse LSTM outputs.
    # NOTE(review): the join is along axis 1; whether that is the feature axis
    # depends on the rank of what layers_lib.LSTM returns -- confirm.
    lstm_out = tf.concat(lstm_outs, 1)

    logits = self.layers['cl_logits'](lstm_out)
    forward_inputs, _ = inputs
    loss = layers_lib.classification_loss(logits, forward_inputs.labels,
                                          forward_inputs.weights)

    if not return_intermediates:
      return loss
    return lstm_out, next_states, logits, loss

  def adversarial_loss(self):
    """Computes the adversarial loss chosen by FLAGS.adv_training_method."""

    def rp_loss():
      return adv_lib.random_perturbation_loss_bidir(
          self.tensors['cl_embedded'], self.cl_inputs[0].length,
          self.cl_loss_from_embedding)

    def at_loss():
      return adv_lib.adversarial_loss_bidir(
          self.tensors['cl_embedded'], self.tensors['cl_loss'],
          self.cl_loss_from_embedding)

    def vat_loss():
      """Computes the virtual adversarial loss on the LM inputs.

      Builds the language model graphs (without their losses) first if they
      have not been constructed yet, and ties the returned loss to the ops
      that save the LM input states for LSTM state-saving BPTT.

      Returns:
        loss: float scalar.
      """
      if self.lm_inputs is None:
        self.language_model_graph(compute_loss=False)

      def logits_from_embedding(embedded, return_next_state=False):
        _, next_states, logits, _ = self.cl_loss_from_embedding(
            embedded, inputs=self.lm_inputs, return_intermediates=True)
        if return_next_state:
          return next_states, logits
        return logits

      lm_embedded = (self.tensors['lm_embedded'],
                     self.tensors['lm_embedded_reverse'])
      next_states, lm_cl_logits = logits_from_embedding(
          lm_embedded, return_next_state=True)

      va_loss = adv_lib.virtual_adversarial_loss_bidir(
          lm_cl_logits, lm_embedded, self.lm_inputs, logits_from_embedding)

      save_ops = self._save_state_ops(self.lm_inputs, next_states)
      with tf.control_dependencies(save_ops):
        return tf.identity(va_loss)

    def atvat_loss():
      return at_loss() + vat_loss()

    def no_loss():
      return tf.constant(0.)

    loss_builders = {
        'rp': rp_loss,  # random perturbation
        'at': at_loss,  # adversarial
        'vat': vat_loss,  # virtual adversarial
        'atvat': atvat_loss,  # adversarial + virtual adversarial
        '': no_loss,
        None: no_loss,
    }

    with tf.name_scope('adversarial_loss'):
      return loss_builders[FLAGS.adv_training_method]()
|
|
|
|
|
def _inputs(dataset='train', pretrain=False, bidir=False):
  """Forwards to inputs_lib.inputs with the flag-derived configuration.

  Args:
    dataset: passed through as `phase`; callers in this file use 'train',
      'test' or 'valid'.
    pretrain: bool, passed through; the seq2seq input format is requested only
      when this is truthy and FLAGS.use_seq2seq_autoencoder is set.
    bidir: bool, passed through; callers that set it unpack the result as a
      (forward, reverse) pair.

  Returns:
    Whatever inputs_lib.inputs returns.
  """
  seq2seq_pretraining = pretrain and FLAGS.use_seq2seq_autoencoder
  # The <eos> token is taken to be the last vocabulary index.
  last_vocab_id = FLAGS.vocab_size - 1
  return inputs_lib.inputs(
      data_dir=FLAGS.data_dir,
      phase=dataset,
      bidir=bidir,
      pretrain=pretrain,
      use_seq2seq=seq2seq_pretraining,
      state_size=FLAGS.rnn_cell_size,
      num_layers=FLAGS.rnn_num_layers,
      batch_size=FLAGS.batch_size,
      unroll_steps=FLAGS.num_timesteps,
      eos_id=last_vocab_id)
|
|
|
|
|
def _get_vocab_freqs():
  """Loads the per-token vocabulary frequencies.

  The file at FLAGS.vocab_freq_path (or FLAGS.data_dir/vocab_freq.txt when that
  flag is unset) is parsed as CSV and the last field of each row is read as an
  integer count. Blank lines are ignored. When the file does not exist and no
  explicit path was given, every token is assigned a frequency of 1.

  Returns:
    List of integers, length=FLAGS.vocab_size.

  Raises:
    ValueError: if the number of frequencies read is not equal to the vocab
      size, or if FLAGS.vocab_freq_path is set but the file is not found.
  """
  path = FLAGS.vocab_freq_path or os.path.join(FLAGS.data_dir, 'vocab_freq.txt')

  if not tf.gfile.Exists(path):
    if FLAGS.vocab_freq_path:
      # An explicitly requested file must exist; only the default location is
      # allowed to be absent.
      raise ValueError('vocab_freq_path not found')
    return [1] * FLAGS.vocab_size

  with tf.gfile.Open(path) as f:
    # QUOTE_NONE makes the reader treat quote characters as ordinary data.
    reader = csv.reader(f, quoting=csv.QUOTE_NONE)
    # csv.reader yields an empty list for a blank line; skip those instead of
    # failing with IndexError on row[-1]. The length check below still catches
    # a file with the wrong number of entries.
    freqs = [int(row[-1]) for row in reader if row]

  if len(freqs) != FLAGS.vocab_size:
    raise ValueError('Frequency file length %d != vocab size %d' %
                     (len(freqs), FLAGS.vocab_size))
  return freqs
|
|
|
|
|
def make_restore_average_vars_dict():
  """Returns dict mapping restoration names to variables.

  Every global variable appears once in the result. Trainable variables are
  keyed by the name tf.train.ExponentialMovingAverage assigns to their moving
  average; all other variables are keyed by their own op name.

  Returns:
    dict<str, variable>.
  """
  variable_averages = tf.train.ExponentialMovingAverage(0.999)

  # The trainable-variable lookup does not change during the loop, so do it
  # once. Membership is by object identity, which avoids one collection lookup
  # and one linear scan per global variable.
  trainable_ids = {id(v) for v in tf.trainable_variables()}

  var_restore_dict = {}
  for v in tf.global_variables():
    if id(v) in trainable_ids:
      name = variable_averages.average_name(v)
    else:
      name = v.op.name
    var_restore_dict[name] = v
  return var_restore_dict
|
|
|
|
|
def optimize(loss, global_step):
  """Forwards to layers_lib.optimize with the flag-derived hyperparameters.

  Args:
    loss: loss tensor to minimize.
    global_step: global step variable handed to the optimizer.

  Returns:
    Whatever layers_lib.optimize returns; callers in this file use it as the
    training op.
  """
  # FLAGS.task is not defined in this file; it must be registered elsewhere.
  return layers_lib.optimize(
      loss,
      global_step,
      FLAGS.max_grad_norm,
      FLAGS.learning_rate,
      FLAGS.learning_rate_decay_factor,
      FLAGS.sync_replicas,
      FLAGS.replicas_to_aggregate,
      FLAGS.task)
|
|