"""Train the cross convolutional model.""" |
|
import os
import sys

import numpy as np
import tensorflow as tf

import model as cross_conv_model
import reader
|
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('master', '', 'Session address.')
tf.flags.DEFINE_string('log_root', '/tmp/moving_obj', 'The root dir of output.')
tf.flags.DEFINE_string('data_filepattern', '',
                       'Training data file pattern.')
tf.flags.DEFINE_integer('image_size', 64, 'Image height and width.')
tf.flags.DEFINE_integer('batch_size', 1, 'Batch size.')
tf.flags.DEFINE_float('norm_scale', 1.0, 'Normalize the original image.')
tf.flags.DEFINE_float('scale', 10.0,
                      'Scale the image after norm_scale and move the diff '
                      'to the positive range.')
tf.flags.DEFINE_integer('sequence_length', 2, 'tf.SequenceExample length.')
tf.flags.DEFINE_float('learning_rate', 0.8, 'Learning rate.')
tf.flags.DEFINE_bool('l2_loss', True, 'If true, include l2_loss.')
tf.flags.DEFINE_bool('reconstr_loss', False, 'If true, include reconstr_loss.')
tf.flags.DEFINE_bool('kl_loss', True, 'If true, include KL loss.')

slim = tf.contrib.slim
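# Typical invocation (script name and flag values here are illustrative):
#   python train.py --data_filepattern=/tmp/data/train-* --log_root=/tmp/moving_obj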
|
|
def _Train():
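  """Builds the cross conv model and runs the training loop."""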
  # Hyperparameters, assembled from flags.
  params = dict()
  params['batch_size'] = FLAGS.batch_size
  params['seq_len'] = FLAGS.sequence_length
  params['image_size'] = FLAGS.image_size
  params['is_training'] = True
  params['norm_scale'] = FLAGS.norm_scale
  params['scale'] = FLAGS.scale
  params['learning_rate'] = FLAGS.learning_rate
  params['l2_loss'] = FLAGS.l2_loss
  params['reconstr_loss'] = FLAGS.reconstr_loss
  params['kl_loss'] = FLAGS.kl_loss

  train_dir = os.path.join(FLAGS.log_root, 'train')

  images = reader.ReadInput(FLAGS.data_filepattern, shuffle=True, params=params)
  images *= params['scale']

  # Convert each sequence into (image, diff) training pairs.
  image_diff_list = reader.SequenceToImageAndDiff(images)
  model = cross_conv_model.CrossConvModel(image_diff_list, params)
  model.Build()

  # Print parameter statistics for the graph via tfprof.
  tf.contrib.tfprof.model_analyzer.print_model_analysis(tf.get_default_graph())
|
  summary_writer = tf.summary.FileWriter(train_dir)
  # The Supervisor checkpoints to FLAGS.log_root every 60 seconds and
  # restores from the latest checkpoint on restart.
  sv = tf.train.Supervisor(logdir=FLAGS.log_root,
                           summary_op=None,
                           is_chief=True,
                           save_model_secs=60,
                           global_step=model.global_step)
  sess = sv.prepare_or_wait_for_session(
      FLAGS.master, config=tf.ConfigProto(allow_soft_placement=True))
|
  total_loss = 0.0
  step = 0
  # Running sums of the latent z statistics; their averages are written to
  # disk periodically.
  sample_z_mean = np.zeros(model.z_mean.get_shape().as_list())
  sample_z_stddev_log = np.zeros(model.z_stddev_log.get_shape().as_list())
  sample_step = 0
|
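  # Train indefinitely; the Supervisor handles checkpointing and recovery.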
  while True:
    _, loss_val, total_steps, summaries, z_mean, z_stddev_log = sess.run(
        [model.train_op, model.loss, model.global_step, model.summary_op,
         model.z_mean, model.z_stddev_log])

    sample_z_mean += z_mean
    sample_z_stddev_log += z_stddev_log
    total_loss += loss_val
    step += 1
    sample_step += 1

    # Log the loss averaged over the last 100 steps.
    if step % 100 == 0:
      summary_writer.add_summary(summaries, total_steps)
      sys.stderr.write('step: %d, loss: %f\n' %
                       (total_steps, total_loss / step))
      total_loss = 0.0
      step = 0

    # Every 10000 steps, dump the averaged z statistics (np.save writes
    # binary data, so open the files in 'wb' mode) and reset the accumulators.
    if sample_step % 10000 == 0:
      with tf.gfile.Open(
          os.path.join(FLAGS.log_root, 'z_mean.npy'), 'wb') as f:
        np.save(f, sample_z_mean / sample_step)
      with tf.gfile.Open(
          os.path.join(FLAGS.log_root, 'z_stddev_log.npy'), 'wb') as f:
        np.save(f, sample_z_stddev_log / sample_step)
      sample_z_mean = np.zeros(model.z_mean.get_shape().as_list())
      sample_z_stddev_log = np.zeros(
          model.z_stddev_log.get_shape().as_list())
      sample_step = 0
|
|
def main(_):
  _Train()
|
|
if __name__ == '__main__':
  tf.app.run()