# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Model utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import numpy as np
import tensorflow as tf

from model_utils import variable_mapping

FLAGS = tf.app.flags.FLAGS


def generate_mask():
  """Generate the mask to be fed into the model.

  Returns:
    p: A [FLAGS.batch_size, FLAGS.sequence_length] boolean NumPy array where
      True marks a token as present (real) and False marks a token to be
      imputed by the Generator.
  """
  if FLAGS.mask_strategy == 'random':
    p = np.random.choice(
        [True, False],
        size=[FLAGS.batch_size, FLAGS.sequence_length],
        p=[FLAGS.is_present_rate, 1. - FLAGS.is_present_rate])

  elif FLAGS.mask_strategy == 'contiguous':
    masked_length = int(
        (1 - FLAGS.is_present_rate) * FLAGS.sequence_length) - 1
    # Determine the location to start masking.
    start_mask = np.random.randint(
        1, FLAGS.sequence_length - masked_length + 1, size=FLAGS.batch_size)
    p = np.full([FLAGS.batch_size, FLAGS.sequence_length], True, dtype=bool)

    # Set the contiguous masked section to False.
    for i, index in enumerate(start_mask):
      p[i, index:index + masked_length] = False

  else:
    raise NotImplementedError

  return p


def assign_percent_real(session, percent_real_update, new_rate, current_rate):
  """Run the assign op that loads the current rate of real tokens into a
  TensorFlow variable.

  Args:
    session: Current tf.Session.
    percent_real_update: tf.assign op.
    new_rate: tf.placeholder for the new rate.
    current_rate: Percent of tokens that are currently real.  Fake tokens
      are the ones being imputed by the Generator.
  """
  session.run(percent_real_update, feed_dict={new_rate: current_rate})


def assign_learning_rate(session, lr_update, lr_placeholder, new_lr):
  """Run the assign op that loads a new learning rate into a TensorFlow
  variable.

  Args:
    session: Current tf.Session.
    lr_update: tf.assign op.
    lr_placeholder: tf.placeholder for the new learning rate.
    new_lr: New learning rate to use.
  """
  session.run(lr_update, feed_dict={lr_placeholder: new_lr})


def clip_weights(variables, c_lower, c_upper):
  """Clip a list of weights to be within a certain range.

  Args:
    variables: List of tf.Variable weights.
    c_lower: Lower bound for weights.
    c_upper: Upper bound for weights.

  Returns:
    A grouped op that assigns the clipped values back to the variables.
  """
  clip_ops = []

  for var in variables:
    clipped_var = tf.clip_by_value(var, c_lower, c_upper)
    clip_ops.append(tf.assign(var, clipped_var))
  return tf.group(*clip_ops)
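# Illustrative usage sketch for the helpers above (not part of the training
# pipeline; it assumes FLAGS have already been parsed, e.g.
# --mask_strategy=contiguous, --batch_size=2, --sequence_length=10,
# --is_present_rate=0.5).  With those settings each row of the mask has a
# contiguous run of int((1 - 0.5) * 10) - 1 = 4 False entries, the positions
# the Generator must impute:
#
#   mask = generate_mask()
#   # e.g. mask ==
#   #   [[ True  True False False False False  True  True  True  True]
#   #    [ True False False False False  True  True  True  True  True]]
#
# clip_weights is typically built once and run after each Discriminator
# update (WGAN-style clipping; the variable prefix and bounds below are
# assumptions, not values fixed by this module):
#
#   dis_vars = [v for v in tf.trainable_variables()
#               if v.op.name.startswith('dis')]
#   clip_op = clip_weights(dis_vars, c_lower=-0.01, c_upper=0.01)
#   ...
#   session.run(clip_op)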
def retrieve_init_savers(hparams):
  """Retrieve a dictionary of all the initial savers for the models.

  Args:
    hparams: MaskGAN hyperparameters.

  Returns:
    Dictionary of tf.train.Saver instances keyed by saver name.
  """
  ## Dictionary of init savers.
  init_savers = {}

  ## Load Generator weights from MaskGAN checkpoint.
  if FLAGS.maskgan_ckpt:
    gen_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('gen')
    ]
    init_saver = tf.train.Saver(var_list=gen_vars)
    init_savers['init_saver'] = init_saver

    ## Load the Discriminator weights from the MaskGAN checkpoint if
    # the weights are compatible.
    if FLAGS.discriminator_model == 'seq2seq_vd':
      dis_variable_maps = variable_mapping.dis_seq2seq_vd(hparams)
      dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
      init_savers['dis_init_saver'] = dis_init_saver

  ## Load weights from language model checkpoint.
  if FLAGS.language_model_ckpt_dir:
    if FLAGS.maskgan_ckpt is None:
      ## Generator Variables/Savers.
      if FLAGS.generator_model == 'rnn_nas':
        gen_variable_maps = variable_mapping.rnn_nas(hparams, model='gen')
        gen_init_saver = tf.train.Saver(var_list=gen_variable_maps)
        init_savers['gen_init_saver'] = gen_init_saver

      elif FLAGS.generator_model == 'seq2seq_nas':
        # Encoder.
        gen_encoder_variable_maps = variable_mapping.gen_encoder_seq2seq_nas(
            hparams)
        gen_encoder_init_saver = tf.train.Saver(
            var_list=gen_encoder_variable_maps)
        # Decoder.
        gen_decoder_variable_maps = variable_mapping.gen_decoder_seq2seq_nas(
            hparams)
        gen_decoder_init_saver = tf.train.Saver(
            var_list=gen_decoder_variable_maps)
        init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
        init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver

      # seq2seq_vd derives from the same code base as seq2seq_zaremba.
      elif (FLAGS.generator_model == 'seq2seq_zaremba' or
            FLAGS.generator_model == 'seq2seq_vd'):
        # Encoder.
        gen_encoder_variable_maps = variable_mapping.gen_encoder_seq2seq(
            hparams)
        gen_encoder_init_saver = tf.train.Saver(
            var_list=gen_encoder_variable_maps)
        # Decoder.
        gen_decoder_variable_maps = variable_mapping.gen_decoder_seq2seq(
            hparams)
        gen_decoder_init_saver = tf.train.Saver(
            var_list=gen_decoder_variable_maps)
        init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
        init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver

      else:
        raise NotImplementedError

      ## Discriminator Variables/Savers.
      if FLAGS.discriminator_model == 'rnn_nas':
        dis_variable_maps = variable_mapping.rnn_nas(hparams, model='dis')
        dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
        init_savers['dis_init_saver'] = dis_init_saver

      # rnn_vd derives from the same code base as rnn_zaremba.
      elif (FLAGS.discriminator_model == 'rnn_zaremba' or
            FLAGS.discriminator_model == 'rnn_vd'):
        dis_variable_maps = variable_mapping.rnn_zaremba(hparams, model='dis')
        dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
        init_savers['dis_init_saver'] = dis_init_saver

      elif (FLAGS.discriminator_model == 'bidirectional_zaremba' or
            FLAGS.discriminator_model == 'bidirectional_vd'):
        dis_fwd_variable_maps = variable_mapping.dis_fwd_bidirectional(hparams)
        dis_bwd_variable_maps = variable_mapping.dis_bwd_bidirectional(hparams)
        # Savers for the forward/backward Discriminator components.
        dis_fwd_init_saver = tf.train.Saver(var_list=dis_fwd_variable_maps)
        dis_bwd_init_saver = tf.train.Saver(var_list=dis_bwd_variable_maps)
        init_savers['dis_fwd_init_saver'] = dis_fwd_init_saver
        init_savers['dis_bwd_init_saver'] = dis_bwd_init_saver

      elif FLAGS.discriminator_model == 'cnn':
        dis_variable_maps = variable_mapping.cnn()
        dis_init_saver = tf.train.Saver(var_list=dis_variable_maps)
        init_savers['dis_init_saver'] = dis_init_saver

      elif FLAGS.discriminator_model == 'seq2seq_vd':
        # Encoder.
        dis_encoder_variable_maps = variable_mapping.dis_encoder_seq2seq(
            hparams)
        dis_encoder_init_saver = tf.train.Saver(
            var_list=dis_encoder_variable_maps)
        # Decoder.
        dis_decoder_variable_maps = variable_mapping.dis_decoder_seq2seq(
            hparams)
        dis_decoder_init_saver = tf.train.Saver(
            var_list=dis_decoder_variable_maps)
        init_savers['dis_encoder_init_saver'] = dis_encoder_init_saver
        init_savers['dis_decoder_init_saver'] = dis_decoder_init_saver

  return init_savers
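# Illustrative sketch of the returned dictionary (assumes a built graph and
# parsed FLAGS/hparams).  For example, with --maskgan_ckpt set and
# --discriminator_model=seq2seq_vd:
#
#   init_savers = retrieve_init_savers(hparams)
#   # init_savers == {
#   #     'init_saver': <tf.train.Saver over the 'gen*' variables>,
#   #     'dis_init_saver': <tf.train.Saver over the Discriminator mapping>}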
def init_fn(init_savers, sess):
  """The init_fn to be passed to the Supervisor.

  Args:
    init_savers: Dictionary of init savers, 'init_saver_name': init_saver.
    sess: tf.Session.
  """
  ## Load Generator weights from MaskGAN checkpoint.
  if FLAGS.maskgan_ckpt:
    print('Restoring Generator from %s.' % FLAGS.maskgan_ckpt)
    tf.logging.info('Restoring Generator from %s.' % FLAGS.maskgan_ckpt)
    print('Asserting Generator is a seq2seq-variant.')
    tf.logging.info('Asserting Generator is a seq2seq-variant.')
    assert FLAGS.generator_model.startswith('seq2seq')
    init_saver = init_savers['init_saver']
    init_saver.restore(sess, FLAGS.maskgan_ckpt)

    ## Load the Discriminator weights from the MaskGAN checkpoint if
    # the weights are compatible.
    if FLAGS.discriminator_model == 'seq2seq_vd':
      print('Restoring Discriminator from %s.' % FLAGS.maskgan_ckpt)
      tf.logging.info('Restoring Discriminator from %s.' % FLAGS.maskgan_ckpt)
      dis_init_saver = init_savers['dis_init_saver']
      dis_init_saver.restore(sess, FLAGS.maskgan_ckpt)

  ## Load weights from language model checkpoint.
  if FLAGS.language_model_ckpt_dir:
    if FLAGS.maskgan_ckpt is None:
      ## Generator Models.
      if FLAGS.generator_model == 'rnn_nas':
        load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
        print('Restoring Generator from %s.' % load_ckpt)
        tf.logging.info('Restoring Generator from %s.' % load_ckpt)
        gen_init_saver = init_savers['gen_init_saver']
        gen_init_saver.restore(sess, load_ckpt)

      elif FLAGS.generator_model.startswith('seq2seq'):
        load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
        print('Restoring Generator from %s.' % load_ckpt)
        tf.logging.info('Restoring Generator from %s.' % load_ckpt)
        gen_encoder_init_saver = init_savers['gen_encoder_init_saver']
        gen_decoder_init_saver = init_savers['gen_decoder_init_saver']
        gen_encoder_init_saver.restore(sess, load_ckpt)
        gen_decoder_init_saver.restore(sess, load_ckpt)

      ## Discriminator Models.
      if (FLAGS.discriminator_model == 'rnn_nas' or
          FLAGS.discriminator_model == 'rnn_zaremba' or
          FLAGS.discriminator_model == 'rnn_vd' or
          FLAGS.discriminator_model == 'cnn'):
        load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
        print('Restoring Discriminator from %s.' % load_ckpt)
        tf.logging.info('Restoring Discriminator from %s.' % load_ckpt)
        dis_init_saver = init_savers['dis_init_saver']
        dis_init_saver.restore(sess, load_ckpt)

      elif (FLAGS.discriminator_model == 'bidirectional_zaremba' or
            FLAGS.discriminator_model == 'bidirectional_vd'):
        assert FLAGS.language_model_ckpt_dir_reversed is not None, (
            'Need a reversed directory to fill in the backward components.')
        load_fwd_ckpt = tf.train.latest_checkpoint(
            FLAGS.language_model_ckpt_dir)
        load_bwd_ckpt = tf.train.latest_checkpoint(
            FLAGS.language_model_ckpt_dir_reversed)
        print('Restoring Discriminator from %s and %s.' %
              (load_fwd_ckpt, load_bwd_ckpt))
        tf.logging.info('Restoring Discriminator from %s and %s.' %
                        (load_fwd_ckpt, load_bwd_ckpt))
        dis_fwd_init_saver = init_savers['dis_fwd_init_saver']
        dis_bwd_init_saver = init_savers['dis_bwd_init_saver']
        dis_fwd_init_saver.restore(sess, load_fwd_ckpt)
        dis_bwd_init_saver.restore(sess, load_bwd_ckpt)

      elif FLAGS.discriminator_model == 'seq2seq_vd':
        load_ckpt = tf.train.latest_checkpoint(FLAGS.language_model_ckpt_dir)
        print('Restoring Discriminator from %s.' % load_ckpt)
        tf.logging.info('Restoring Discriminator from %s.' % load_ckpt)
        dis_encoder_init_saver = init_savers['dis_encoder_init_saver']
        dis_decoder_init_saver = init_savers['dis_decoder_init_saver']
        dis_encoder_init_saver.restore(sess, load_ckpt)
        dis_decoder_init_saver.restore(sess, load_ckpt)
  else:
    return
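# A minimal wiring sketch (assumptions: a fully built MaskGAN graph, parsed
# FLAGS/hparams, and a hypothetical `train_dir`).  tf.train.Supervisor calls
# its `init_fn` with the session as the sole argument, so the saver
# dictionary is bound ahead of time with functools.partial:
#
#   import functools
#
#   init_savers = retrieve_init_savers(hparams)
#   sv = tf.train.Supervisor(
#       logdir=train_dir,
#       init_fn=functools.partial(init_fn, init_savers))
#   with sv.managed_session() as sess:
#     pass  # Training loop; checkpointed weights are already restored.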