"""Contains training plan for the Im2vox model.""" |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf

from tensorflow import app

import model_ptn

flags = tf.app.flags
slim = tf.contrib.slim
|
flags.DEFINE_string('inp_dir', '',
                    'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
    'dataset_name', 'shapenet_chair',
    'Dataset name to be used for training and evaluation.')
flags.DEFINE_integer('z_dim', 512, 'Dimension of the latent representation.')
flags.DEFINE_integer('f_dim', 64, 'Base number of convolutional filters.')
flags.DEFINE_integer('fc_dim', 1024, 'Dimension of fully connected layers.')
flags.DEFINE_integer('num_views', 24, 'Num of viewpoints in the input data.')
flags.DEFINE_integer('image_size', 64,
                     'Input image dimension (pixels) - width & height.')
flags.DEFINE_integer('vox_size', 32, 'Voxel prediction dimension.')
flags.DEFINE_integer('step_size', 24,
                     'Steps to take in rotation to fetch viewpoints.')
flags.DEFINE_integer('batch_size', 6, 'Batch size while training.')
flags.DEFINE_float('focal_length', 0.866,
                   'Focal length parameter used in perspective projection.')
flags.DEFINE_float('focal_range', 1.732,
                   'Focal range parameter used in perspective projection.')
|
flags.DEFINE_string('encoder_name', 'ptn_encoder',
                    'Name of the encoder network being used.')
flags.DEFINE_string('decoder_name', 'ptn_vox_decoder',
                    'Name of the decoder network being used.')
flags.DEFINE_string('projector_name', 'perspective_projector',
                    'Name of the projector network being used.')

flags.DEFINE_string('checkpoint_dir', '/tmp/ptn_train/',
                    'Directory path for saving trained models and other data.')
flags.DEFINE_string('model_name', 'ptn_finetune',
                    'Name of the model used in naming the TF job. '
                    'Must be different for each run.')
flags.DEFINE_string('init_model', None,
                    'Checkpoint path of the model to initialize with.')
flags.DEFINE_integer('save_every', 1000,
                     'Average period of steps after which we save a model.')

flags.DEFINE_float('proj_weight', 10, 'Weighting factor for projection loss.')
flags.DEFINE_float('volume_weight', 0, 'Weighting factor for volume loss.')
flags.DEFINE_float('viewpoint_weight', 1,
                   'Weighting factor for viewpoint loss.')
flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001,
                   'Weight decay parameter while training.')
flags.DEFINE_float('clip_gradient_norm', 0,
                   'Gradient clip norm; leave 0 for no gradient clipping.')
flags.DEFINE_integer('max_number_of_steps', 10000,
                     'Maximum number of steps for training.')

flags.DEFINE_integer('save_summaries_secs', 15,
                     'Seconds interval for dumping TF summaries.')
flags.DEFINE_integer('save_interval_secs', 60 * 5,
                     'Seconds interval to save models.')
|
flags.DEFINE_string('master', '', 'The address of the TensorFlow master.')
flags.DEFINE_bool('sync_replicas', False,
                  'Whether to sync gradients between replicas for optimizer.')
flags.DEFINE_integer('worker_replicas', 1,
                     'Number of worker replicas (train tasks).')
flags.DEFINE_integer('backup_workers', 0, 'Number of backup workers.')
flags.DEFINE_integer('ps_tasks', 0, 'Number of ps tasks.')
flags.DEFINE_integer(
    'task', 0,
    'Task identifier flag to be set for each task running in a distributed '
    'manner. Task number 0 will be chosen as the chief.')

FLAGS = flags.FLAGS
|
|
def main(_):
  train_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
  save_image_dir = os.path.join(train_dir, 'images')
  if not os.path.exists(train_dir):
    os.makedirs(train_dir)
  if not os.path.exists(save_image_dir):
    os.makedirs(save_image_dir)

  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      global_step = slim.get_or_create_global_step()
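
      # Build the Im2vox (PTN) model from the command-line flags.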
      model = model_ptn.model_PTN(FLAGS)
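
      # Input pipeline over the training split.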
      train_data = model.get_inputs(
          FLAGS.inp_dir,
          FLAGS.dataset_name,
          'train',
          FLAGS.batch_size,
          FLAGS.image_size,
          FLAGS.vox_size,
          is_training=True)
      inputs = model.preprocess(train_data, FLAGS.step_size)
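
      # Forward pass on the training batch; run_projection=True also
      # produces the projected masks used by the projection loss.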
      model_fn = model.get_model_fn(
          is_training=True, reuse=False, run_projection=True)
      outputs = model_fn(inputs)
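
      # When initializing from a pretrained model, the encoder is restored
      # and only the decoder is trained; otherwise both train from scratch.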
      if FLAGS.init_model:
        train_scopes = ['decoder']
        init_scopes = ['encoder']
      else:
        train_scopes = ['encoder', 'decoder']
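
      # Total loss: task loss plus regularization over the trainable scopes.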
      task_loss = model.get_loss(inputs, outputs)
      regularization_loss = model.get_regularization_loss(train_scopes)
      loss = task_loss + regularization_loss
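
      # Adam optimizer, optionally wrapped for synchronous replica training.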
      optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
      if FLAGS.sync_replicas:
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            replicas_to_aggregate=FLAGS.worker_replicas - FLAGS.backup_workers,
            total_num_replicas=FLAGS.worker_replicas)
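
      # Training op restricted to the variables in train_scopes.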
      train_op = model.get_train_op_for_scope(loss, optimizer, train_scopes)
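
      # Keep only a handful of recent checkpoints.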
      saver = tf.train.Saver(max_to_keep=np.minimum(5,
                                                    FLAGS.worker_replicas + 1))
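
      # On the chief task only: build a validation pass for visualization.
      # Note: params aliases FLAGS, so these assignments mutate the global
      # flags (batch_size becomes num_views, step_size becomes 1).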
      if FLAGS.task == 0:
        params = FLAGS
        params.batch_size = params.num_views
        params.step_size = 1
        model.set_params(params)
        val_data = model.get_inputs(
            params.inp_dir,
            params.dataset_name,
            'val',
            params.batch_size,
            params.image_size,
            params.vox_size,
            is_training=False)
        val_inputs = model.preprocess(val_data, params.step_size)
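
        # Reuse the training weights for the validation forward pass.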
        reused_model_fn = model.get_model_fn(is_training=False, reuse=True)
        val_outputs = reused_model_fn(val_inputs)
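
        # Build visualization tensors on CPU: scale images to [0, 255] and
        # invert masks/projections so the foreground renders dark on white.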
        with tf.device(tf.DeviceSpec(device_type='CPU')):
          vis_input_images = val_inputs['images_1'] * 255.0
          vis_gt_projs = (val_outputs['masks_1'] * (-1) + 1) * 255.0
          vis_pred_projs = (val_outputs['projs_1'] * (-1) + 1) * 255.0
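
          # Tile the single-channel projections to three channels.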
          vis_gt_projs = tf.concat([vis_gt_projs] * 3, axis=3)
          vis_pred_projs = tf.concat([vis_pred_projs] * 3, axis=3)
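
          # Upsample the projections to the input image resolution.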
          new_size = [FLAGS.image_size] * 2
          vis_gt_projs = tf.image.resize_nearest_neighbor(
              vis_gt_projs, new_size)
          vis_pred_projs = tf.image.resize_nearest_neighbor(
              vis_pred_projs, new_size)
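
          # Write a grid of inputs, ground-truth/predicted projections, and
          # voxels to disk whenever this op runs.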
          write_disk_op = model.write_disk_grid(
              global_step=global_step,
              log_dir=save_image_dir,
              input_images=vis_input_images,
              gt_projs=vis_gt_projs,
              pred_projs=vis_pred_projs,
              input_voxels=val_inputs['voxels'],
              output_voxels=val_outputs['voxels_1'])
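
        # Make every training step on the chief also trigger the disk write.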
        with tf.control_dependencies([write_disk_op]):
          train_op = tf.identity(train_op)
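
      # Restore the encoder from the given checkpoint when one is provided.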
      if FLAGS.init_model:
        init_fn = model.get_init_fn(init_scopes)
      else:
        init_fn = None
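
      # Run the slim training loop; the chief also saves checkpoints and
      # summaries at the configured intervals.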
      slim.learning.train(
          train_op=train_op,
          logdir=train_dir,
          init_fn=init_fn,
          master=FLAGS.master,
          is_chief=(FLAGS.task == 0),
          number_of_steps=FLAGS.max_number_of_steps,
          saver=saver,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
|
|
if __name__ == '__main__':
  app.run()