"""Model utilities.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
|
|
import numpy as np |
|
|
|
import tensorflow as tf |
|
from model_utils import variable_mapping |
|
|
|
FLAGS = tf.app.flags.FLAGS |
|
|
|
|
|
def generate_mask():
  """Generate the mask to be fed into the model.

  Returns:
    p: A [batch_size, sequence_length] boolean np.array where True marks a
      token as present (real) and False marks a token to be imputed by the
      Generator.
  """
  if FLAGS.mask_strategy == 'random':
    p = np.random.choice(
        [True, False],
        size=[FLAGS.batch_size, FLAGS.sequence_length],
        p=[FLAGS.is_present_rate, 1. - FLAGS.is_present_rate])

  elif FLAGS.mask_strategy == 'contiguous':
    masked_length = int(
        (1 - FLAGS.is_present_rate) * FLAGS.sequence_length) - 1
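    # Choose a random start index so the masked span fits in the sequence.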
    start_mask = np.random.randint(
        1, FLAGS.sequence_length - masked_length + 1, size=FLAGS.batch_size)
    p = np.full([FLAGS.batch_size, FLAGS.sequence_length], True, dtype=bool)

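    # Set the contiguous masked span in each batch row to False.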
    for i, index in enumerate(start_mask):
      p[i, index:index + masked_length] = False

  else:
    raise NotImplementedError

  return p


def assign_percent_real(session, percent_real_update, new_rate, current_rate):
  """Run the assign op that loads current_rate into the percent-real variable.

  Args:
    session: Current tf.Session.
    percent_real_update: tf.assign operation.
    new_rate: tf.placeholder for the new rate.
    current_rate: Percent of tokens that are currently real.  Fake tokens
      are the ones being imputed by the Generator.
  """
  session.run(percent_real_update, feed_dict={new_rate: current_rate})


def assign_learning_rate(session, lr_update, lr_placeholder, new_lr):
  """Run the assign op that loads the new learning rate into its variable.

  Args:
    session: Current tf.Session.
    lr_update: tf.assign operation.
    lr_placeholder: tf.placeholder for the new learning rate.
    new_lr: New learning rate to use.
  """
  session.run(lr_update, feed_dict={lr_placeholder: new_lr})


def clip_weights(variables, c_lower, c_upper):
  """Clip a list of weights to be within a certain range.

  Args:
    variables: List of tf.Variable weights.
    c_lower: Lower bound for weights.
    c_upper: Upper bound for weights.

  Returns:
    A grouped op that assigns each variable its clipped value.
  """
  clip_ops = []

  for var in variables:
    clipped_var = tf.clip_by_value(var, c_lower, c_upper)
    clip_ops.append(tf.assign(var, clipped_var))
  return tf.group(*clip_ops)


def retrieve_init_savers(hparams):
  """Retrieve a dictionary of all the initial savers for the models.

  Args:
    hparams: MaskGAN hyperparameters.

  Returns:
    init_savers: Dictionary mapping saver names to tf.train.Saver objects.
  """
  init_savers = {}

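  # Load Generator weights from the MaskGAN checkpoint.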
  if FLAGS.maskgan_ckpt:
    gen_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('gen')
    ]
    init_saver = tf.train.Saver(var_list=gen_vars)
    init_savers['init_saver'] = init_saver

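    # When the Discriminator weights are compatible (seq2seq_vd), restore
    # them from the MaskGAN checkpoint as well.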
    if FLAGS.discriminator_model == 'seq2seq_vd':
      dis_variable_maps = variable_mapping.dis_seq2seq_vd(hparams)
      dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
      init_savers['dis_init_saver'] = dis_init_saver

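  # Load weights from pretrained language model checkpoints.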
  if FLAGS.language_model_ckpt_dir:
    if FLAGS.maskgan_ckpt is None:
      if FLAGS.generator_model == 'rnn_nas':
        gen_variable_maps = variable_mapping.rnn_nas(hparams, model='gen')
        gen_init_saver = tf.train.Saver(var_list=gen_variable_maps)
        init_savers['gen_init_saver'] = gen_init_saver

      elif FLAGS.generator_model == 'seq2seq_nas':
        gen_encoder_variable_maps = variable_mapping.gen_encoder_seq2seq_nas(
            hparams)
        gen_encoder_init_saver = tf.train.Saver(
            var_list=gen_encoder_variable_maps)

        gen_decoder_variable_maps = variable_mapping.gen_decoder_seq2seq_nas(
            hparams)
        gen_decoder_init_saver = tf.train.Saver(
            var_list=gen_decoder_variable_maps)
        init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
        init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver

      elif (FLAGS.generator_model == 'seq2seq_zaremba' or
            FLAGS.generator_model == 'seq2seq_vd'):
        gen_encoder_variable_maps = variable_mapping.gen_encoder_seq2seq(
            hparams)
        gen_encoder_init_saver = tf.train.Saver(
            var_list=gen_encoder_variable_maps)

        gen_decoder_variable_maps = variable_mapping.gen_decoder_seq2seq(
            hparams)
        gen_decoder_init_saver = tf.train.Saver(
            var_list=gen_decoder_variable_maps)
        init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
        init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver

      else:
        raise NotImplementedError

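    # Discriminator models.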
    if FLAGS.discriminator_model == 'rnn_nas':
      dis_variable_maps = variable_mapping.rnn_nas(hparams, model='dis')
      dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
      init_savers['dis_init_saver'] = dis_init_saver

    elif (FLAGS.discriminator_model == 'rnn_zaremba' or
          FLAGS.discriminator_model == 'rnn_vd'):
      dis_variable_maps = variable_mapping.rnn_zaremba(hparams, model='dis')
      dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
      init_savers['dis_init_saver'] = dis_init_saver

    elif (FLAGS.discriminator_model == 'bidirectional_zaremba' or
          FLAGS.discriminator_model == 'bidirectional_vd'):
      dis_fwd_variable_maps = variable_mapping.dis_fwd_bidirectional(hparams)
      dis_bwd_variable_maps = variable_mapping.dis_bwd_bidirectional(hparams)

      dis_fwd_init_saver = tf.train.Saver(var_list=dis_fwd_variable_maps)
      dis_bwd_init_saver = tf.train.Saver(var_list=dis_bwd_variable_maps)
      init_savers['dis_fwd_init_saver'] = dis_fwd_init_saver
      init_savers['dis_bwd_init_saver'] = dis_bwd_init_saver

    elif FLAGS.discriminator_model == 'cnn':
      dis_variable_maps = variable_mapping.cnn()
      dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
      init_savers['dis_init_saver'] = dis_init_saver

    elif FLAGS.discriminator_model == 'seq2seq_vd':
      dis_encoder_variable_maps = variable_mapping.dis_encoder_seq2seq(hparams)
      dis_encoder_init_saver = tf.train.Saver(
          var_list=dis_encoder_variable_maps)

      dis_decoder_variable_maps = variable_mapping.dis_decoder_seq2seq(hparams)
      dis_decoder_init_saver = tf.train.Saver(
          var_list=dis_decoder_variable_maps)
      init_savers['dis_encoder_init_saver'] = dis_encoder_init_saver
      init_savers['dis_decoder_init_saver'] = dis_decoder_init_saver

  return init_savers


def init_fn(init_savers, sess):
  """The init_fn to be passed to the Supervisor.

  Args:
    init_savers: Dictionary of init_savers.  'init_saver_name': init_saver.
    sess: tf.Session.
  """
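  # Restore the Generator from the MaskGAN checkpoint.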
  if FLAGS.maskgan_ckpt:
    print('Restoring Generator from %s.' % FLAGS.maskgan_ckpt)
    tf.logging.info('Restoring Generator from %s.' % FLAGS.maskgan_ckpt)
    print('Asserting Generator is a seq2seq-variant.')
    tf.logging.info('Asserting Generator is a seq2seq-variant.')
    assert FLAGS.generator_model.startswith('seq2seq')
    init_saver = init_savers['init_saver']
    init_saver.restore(sess, FLAGS.maskgan_ckpt)

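    # Restore the Discriminator from the same checkpoint when compatible.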
    if FLAGS.discriminator_model == 'seq2seq_vd':
      print('Restoring Discriminator from %s.' % FLAGS.maskgan_ckpt)
      tf.logging.info('Restoring Discriminator from %s.' % FLAGS.maskgan_ckpt)
      dis_init_saver = init_savers['dis_init_saver']
      dis_init_saver.restore(sess, FLAGS.maskgan_ckpt)

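  # Restore from pretrained language model checkpoints.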
  if FLAGS.language_model_ckpt_dir:
    if FLAGS.maskgan_ckpt is None:
      if FLAGS.generator_model == 'rnn_nas':
        load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
        print('Restoring Generator from %s.' % load_ckpt)
        tf.logging.info('Restoring Generator from %s.' % load_ckpt)
        gen_init_saver = init_savers['gen_init_saver']
        gen_init_saver.restore(sess, load_ckpt)

      elif FLAGS.generator_model.startswith('seq2seq'):
        load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
        print('Restoring Generator from %s.' % load_ckpt)
        tf.logging.info('Restoring Generator from %s.' % load_ckpt)
        gen_encoder_init_saver = init_savers['gen_encoder_init_saver']
        gen_decoder_init_saver = init_savers['gen_decoder_init_saver']
        gen_encoder_init_saver.restore(sess, load_ckpt)
        gen_decoder_init_saver.restore(sess, load_ckpt)

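    # Discriminator models.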
    if (FLAGS.discriminator_model == 'rnn_nas' or
        FLAGS.discriminator_model == 'rnn_zaremba' or
        FLAGS.discriminator_model == 'rnn_vd' or
        FLAGS.discriminator_model == 'cnn'):
      load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
      print('Restoring Discriminator from %s.' % load_ckpt)
      tf.logging.info('Restoring Discriminator from %s.' % load_ckpt)
      dis_init_saver = init_savers['dis_init_saver']
      dis_init_saver.restore(sess, load_ckpt)

    elif (FLAGS.discriminator_model == 'bidirectional_zaremba' or
          FLAGS.discriminator_model == 'bidirectional_vd'):
      assert FLAGS.language_model_ckpt_dir_reversed is not None, (
          'Need a reversed directory to fill in the backward components.')
      load_fwd_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
      load_bwd_ckpt = tf.train.latest_checkpoint(
          FLAGS.language_model_ckpt_dir_reversed)
      print('Restoring Discriminator from %s and %s.' % (load_fwd_ckpt,
                                                         load_bwd_ckpt))
      tf.logging.info('Restoring Discriminator from %s and %s.' %
                      (load_fwd_ckpt, load_bwd_ckpt))
      dis_fwd_init_saver = init_savers['dis_fwd_init_saver']
      dis_bwd_init_saver = init_savers['dis_bwd_init_saver']
      dis_fwd_init_saver.restore(sess, load_fwd_ckpt)
      dis_bwd_init_saver.restore(sess, load_bwd_ckpt)

    elif FLAGS.discriminator_model == 'seq2seq_vd':
      load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
      print('Restoring Discriminator from %s.' % load_ckpt)
      tf.logging.info('Restoring Discriminator from %s.' % load_ckpt)
      dis_encoder_init_saver = init_savers['dis_encoder_init_saver']
      dis_decoder_init_saver = init_savers['dis_decoder_init_saver']
      dis_encoder_init_saver.restore(sess, load_ckpt)
      dis_decoder_init_saver.restore(sess, load_ckpt)

  else:
    return