"""Model optimization.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
|
|
|
|
import tensorflow as tf |
|
|
|
FLAGS = tf.app.flags.FLAGS |
|
|
def create_dis_pretrain_op(hparams, dis_loss, global_step):
  """Create a train op for pretraining the Discriminator."""
  with tf.name_scope('pretrain_discriminator'):
    optimizer = tf.train.AdamOptimizer(hparams.dis_pretrain_learning_rate)
    dis_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('dis')
    ]
    # Optionally also update the embedding shared with the Generator.
    if FLAGS.dis_update_share_embedding and FLAGS.dis_share_embedding:
      shared_embedding = [
          v for v in tf.trainable_variables()
          if v.op.name == 'gen/decoder/rnn/embedding'
      ][0]
      dis_vars.append(shared_embedding)
    dis_grads = tf.gradients(dis_loss, dis_vars)
    dis_grads_clipped, _ = tf.clip_by_global_norm(dis_grads,
                                                  FLAGS.grad_clipping)
    dis_pretrain_op = optimizer.apply_gradients(
        zip(dis_grads_clipped, dis_vars), global_step=global_step)
    return dis_pretrain_op
|
|
def create_gen_pretrain_op(hparams, cross_entropy_loss, global_step):
  """Create a train op for pretraining the Generator."""
  with tf.name_scope('pretrain_generator'):
    optimizer = tf.train.AdamOptimizer(hparams.gen_pretrain_learning_rate)
    gen_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('gen')
    ]
    gen_grads = tf.gradients(cross_entropy_loss, gen_vars)
    gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads,
                                                  FLAGS.grad_clipping)
    gen_pretrain_op = optimizer.apply_gradients(
        zip(gen_grads_clipped, gen_vars), global_step=global_step)
    return gen_pretrain_op
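

# A minimal (hypothetical) usage sketch for the two pretraining ops above;
# `hparams`, `dis_loss`, `cross_entropy_loss`, and `global_step` are assumed
# to have been built elsewhere in the graph:
#
#   dis_pretrain_op = create_dis_pretrain_op(hparams, dis_loss, global_step)
#   gen_pretrain_op = create_gen_pretrain_op(
#       hparams, cross_entropy_loss, global_step)
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     sess.run(gen_pretrain_op)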
|
|
def create_gen_train_op(hparams, learning_rate, gen_loss, global_step, mode):
  """Create Generator train op."""
  del hparams
  with tf.name_scope('train_generator'):
    if FLAGS.generator_optimizer == 'sgd':
      gen_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    elif FLAGS.generator_optimizer == 'adam':
      gen_optimizer = tf.train.AdamOptimizer(learning_rate)
    else:
      raise NotImplementedError
    gen_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('gen')
    ]
    print('\nOptimizing Generator vars:')
    for v in gen_vars:
      print(v)
    if mode == 'MINIMIZE':
      gen_grads = tf.gradients(gen_loss, gen_vars)
    elif mode == 'MAXIMIZE':
      gen_grads = tf.gradients(-gen_loss, gen_vars)
    else:
      raise ValueError("Must be one of 'MINIMIZE' or 'MAXIMIZE'")
    gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads,
                                                  FLAGS.grad_clipping)
    gen_train_op = gen_optimizer.apply_gradients(
        zip(gen_grads_clipped, gen_vars), global_step=global_step)
    return gen_train_op, gen_grads_clipped, gen_vars
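

# Hypothetical example of the two modes: 'MINIMIZE' descends on `gen_loss`
# directly, while 'MAXIMIZE' negates it so the same update ascends the
# objective (e.g. a reward).  `hparams` is unused here; `learning_rate`,
# `gen_loss`, and `global_step` are assumed to exist already:
#
#   gen_train_op, gen_grads, gen_vars = create_gen_train_op(
#       hparams, learning_rate, gen_loss, global_step, mode='MINIMIZE')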
|
|
def create_reinforce_gen_train_op(hparams, learning_rate, final_gen_reward,
                                  averages_op, global_step):
  """Create the Generator train_op when using REINFORCE.

  Args:
    hparams: MaskGAN hyperparameters.
    learning_rate: tf.Variable scalar learning rate.
    final_gen_reward: Scalar final REINFORCE reward for the sequence.
    averages_op: ExponentialMovingAverage apply-average op used to
      maintain the baseline.
    global_step: global_step tf.Variable.

  Returns:
    gen_train_op: Generator training op.
    gen_grads: Unclipped Generator gradients.
    gen_vars: Generator variables the gradients apply to.
  """
  del hparams
  with tf.name_scope('train_generator'):
    if FLAGS.generator_optimizer == 'sgd':
      gen_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    elif FLAGS.generator_optimizer == 'adam':
      gen_optimizer = tf.train.AdamOptimizer(learning_rate)
    else:
      raise NotImplementedError
    gen_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('gen')
    ]
    print('\nOptimizing Generator vars:')
    for v in gen_vars:
      print(v)

    # Maximize the reward by descending on its negation.
    gen_grads = tf.gradients(-final_gen_reward, gen_vars)
    gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads,
                                                  FLAGS.grad_clipping)
    maximize_op = gen_optimizer.apply_gradients(
        zip(gen_grads_clipped, gen_vars), global_step=global_step)

    # Fold in the baseline moving-average update, if one was provided.
    if averages_op:
      gen_train_op = tf.group(maximize_op, averages_op)
    else:
      gen_train_op = maximize_op

  return [gen_train_op, gen_grads, gen_vars]
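

# Hypothetical wiring of the REINFORCE op with a moving-average baseline.
# `reward_baseline` stands for whatever Variable tracks the expected reward;
# it is not defined in this module:
#
#   ema = tf.train.ExponentialMovingAverage(decay=0.99)
#   averages_op = ema.apply([reward_baseline])
#   gen_train_op, gen_grads, gen_vars = create_reinforce_gen_train_op(
#       hparams, learning_rate, final_gen_reward, averages_op, global_step)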
|
|
def create_dis_train_op(hparams, dis_loss, global_step):
  """Create Discriminator train op."""
  with tf.name_scope('train_discriminator'):
    dis_optimizer = tf.train.AdamOptimizer(hparams.dis_learning_rate)
    dis_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('dis')
    ]
    if FLAGS.dis_update_share_embedding and FLAGS.dis_share_embedding:
      shared_embedding = [
          v for v in tf.trainable_variables()
          if v.op.name == 'gen/decoder/rnn/embedding'
      ][0]
      dis_vars.append(shared_embedding)
    print('\nOptimizing Discriminator vars:')
    for v in dis_vars:
      print(v)
    dis_grads = tf.gradients(dis_loss, dis_vars)
    dis_grads_clipped, _ = tf.clip_by_global_norm(dis_grads,
                                                  FLAGS.grad_clipping)
    dis_train_op = dis_optimizer.apply_gradients(
        zip(dis_grads_clipped, dis_vars), global_step=global_step)
    return dis_train_op, dis_grads_clipped, dis_vars
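

# Note on FLAGS.dis_share_embedding / FLAGS.dis_update_share_embedding: when
# both are set, the Generator embedding ('gen/decoder/rnn/embedding') is added
# to `dis_vars`, so Discriminator updates also move the shared embedding.
# Hypothetical call, with `dis_loss` built elsewhere:
#
#   dis_train_op, dis_grads, dis_vars = create_dis_train_op(
#       hparams, dis_loss, global_step)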
|
|
def create_critic_train_op(hparams, critic_loss, global_step):
  """Create Critic train op."""
  with tf.name_scope('train_critic'):
    critic_optimizer = tf.train.AdamOptimizer(hparams.critic_learning_rate)
    output_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('critic')
    ]

    # Optionally also update part of the Discriminator when fitting the Critic.
    if FLAGS.critic_update_dis_vars:
      if FLAGS.discriminator_model == 'bidirectional_vd':
        critic_vars = [
            v for v in tf.trainable_variables()
            if v.op.name.startswith('dis/rnn')
        ]
      elif FLAGS.discriminator_model == 'seq2seq_vd':
        critic_vars = [
            v for v in tf.trainable_variables()
            if v.op.name.startswith('dis/decoder/rnn/multi_rnn_cell')
        ]
      else:
        raise NotImplementedError(
            'Unsupported discriminator_model for Critic updates: %s' %
            FLAGS.discriminator_model)
      critic_vars.extend(output_vars)
    else:
      critic_vars = output_vars
    print('\nOptimizing Critic vars:')
    for v in critic_vars:
      print(v)
    critic_grads = tf.gradients(critic_loss, critic_vars)
    critic_grads_clipped, _ = tf.clip_by_global_norm(critic_grads,
                                                     FLAGS.grad_clipping)
    critic_train_op = critic_optimizer.apply_gradients(
        zip(critic_grads_clipped, critic_vars), global_step=global_step)
    return critic_train_op, critic_grads_clipped, critic_vars
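

# Rough (hypothetical) sketch of how the adversarial-phase ops defined in this
# module might be alternated in a training loop; the real schedule lives in the
# training binary, and `num_steps` is just an illustrative local:
#
#   gen_train_op, _, _ = create_gen_train_op(
#       hparams, learning_rate, gen_loss, global_step, mode='MINIMIZE')
#   dis_train_op, _, _ = create_dis_train_op(hparams, dis_loss, global_step)
#   critic_train_op, _, _ = create_critic_train_op(
#       hparams, critic_loss, global_step)
#
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     for _ in range(num_steps):
#       sess.run(dis_train_op)
#       sess.run(gen_train_op)
#       sess.run(critic_train_op)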
|
|