|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Helper functions for pretraining (rotator) as described in PTN paper.""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import os |
|
|
|
import numpy as np |
|
from six.moves import xrange |
|
import tensorflow as tf |
|
|
|
import input_generator |
|
import losses |
|
import metrics |
|
import utils |
|
from nets import deeprotator_factory |
|
|
|
# Module-wide alias for TF-Slim (TF 1.x contrib), used for summaries,
# checkpoint restoration and train-op creation below.
slim = tf.contrib.slim
|
|
|
|
|
def _get_data_from_provider(inputs, batch_size, split_name):
  """Batches per-example image/mask tensors via tf.train.batch.

  Args:
    inputs: Dictionary with 'image', 'mask' and 'num_samples' entries, as
      produced by input_generator.get.
    batch_size: Number of examples per batch.
    split_name: Split name used to label the batching queue.

  Returns:
    Dictionary with batched 'images' and 'masks' tensors plus the
    pass-through 'num_samples' count.
  """
  batched_images, batched_masks = tf.train.batch(
      [inputs['image'], inputs['mask']],
      batch_size=batch_size,
      num_threads=8,
      capacity=8 * batch_size,
      name='batching_queues/%s' % (split_name))

  return {
      'images': batched_images,
      'masks': batched_masks,
      'num_samples': inputs['num_samples'],
  }
|
|
|
|
|
def get_inputs(dataset_dir, dataset_name, split_name, batch_size, image_size,
               is_training):
  """Loads the requested dataset split and returns batched input tensors.

  Args:
    dataset_dir: Directory holding the dataset files.
    dataset_name: Name of the dataset to load.
    split_name: Which split to read (e.g. train/val/test).
    batch_size: Number of examples per batch.
    image_size: Unused; kept for interface compatibility with callers.
    is_training: If True, the reader queue shuffles examples.

  Returns:
    Dictionary of batched inputs (see _get_data_from_provider).
  """
  del image_size  # Unused by this loader.
  scope_name = 'data_loading_%s/%s' % (dataset_name, split_name)
  # Everything (reader queues and batching) stays inside this scope so op
  # names match the per-dataset/per-split layout.
  with tf.variable_scope(scope_name):
    raw_inputs = input_generator.get(
        dataset_dir,
        dataset_name,
        split_name,
        shuffle=is_training,
        num_readers=4,
        common_queue_min=50,
        common_queue_capacity=256)
    return _get_data_from_provider(raw_inputs, batch_size, split_name)
|
|
|
|
|
def preprocess(raw_inputs, step_size):
  """Selects the subset of viewpoints to train on.

  For each example in the batch, picks a random input view and a random
  rotation direction, then gathers `step_size` consecutive target views in
  that direction (with wrap-around over the view circle).

  Args:
    raw_inputs: Dictionary with 'images' and 'masks' tensors of shape
      [quantity, num_views, H, W, C] (static shape required).
    step_size: Number of rotation steps (target views) per example.

  Returns:
    Dictionary with 'images_k'/'masks_k' tensors for k in [0, step_size]
    (k=0 is the input view) and an 'actions' [quantity, 3] one-hot tensor
    encoding the chosen rotation (+1 / none / -1).
  """
  shp = raw_inputs['images'].get_shape().as_list()
  quantity = shp[0]
  num_views = shp[1]
  image_size = shp[2]
  del image_size  # Unused; only batch size and view count are needed.

  # One-hot rotation action per example: col 0 = +1 step, col 1 = stay,
  # col 2 = -1 step.
  batch_rot = np.zeros((quantity, 3), dtype=np.float32)
  inputs = dict()
  for n in xrange(step_size + 1):
    inputs['images_%d' % n] = []
    inputs['masks_%d' % n] = []

  for n in xrange(quantity):
    view_in = np.random.randint(0, num_views)
    rng_rot = np.random.randint(0, 2)
    # With a single step the 'no rotation' action (rng_rot == 2) is also
    # allowed; note this draws from the RNG a second time, overwriting the
    # first draw — keep both calls to preserve the random stream.
    if step_size == 1:
      rng_rot = np.random.randint(0, 3)

    delta = 0
    if rng_rot == 0:
      delta = -1
      batch_rot[n, 2] = 1
    elif rng_rot == 1:
      delta = 1
      batch_rot[n, 0] = 1
    else:
      delta = 0
      batch_rot[n, 1] = 1

    # Slot 0 always holds the randomly chosen input view.
    inputs['images_0'].append(raw_inputs['images'][n, view_in, :, :, :])
    inputs['masks_0'].append(raw_inputs['masks'][n, view_in, :, :, :])

    # Walk `step_size` views in the chosen direction, wrapping around the
    # circular view index range [0, num_views).
    view_out = view_in
    for k in xrange(1, step_size + 1):
      view_out += delta
      if view_out >= num_views:
        view_out = 0
      if view_out < 0:
        view_out = num_views - 1

      inputs['images_%d' % k].append(raw_inputs['images'][n, view_out, :, :, :])
      inputs['masks_%d' % k].append(raw_inputs['masks'][n, view_out, :, :, :])

  # Stack the per-example slices back into [quantity, H, W, C] tensors.
  for n in xrange(step_size + 1):
    inputs['images_%d' % n] = tf.stack(inputs['images_%d' % n])
    inputs['masks_%d' % n] = tf.stack(inputs['masks_%d' % n])

  inputs['actions'] = tf.constant(batch_rot, dtype=tf.float32)
  return inputs
|
|
|
|
|
def get_init_fn(scopes, params):
  """Builds a session callback restoring scoped variables from a checkpoint.

  Args:
    scopes: Variable scopes whose trainable model variables are restored.
    params: Hyperparameters; params.init_model is the checkpoint path
      (falsy value disables initialization).

  Returns:
    A function taking a tf.Session that runs the restore assignment, or
    None when params.init_model is unset.
  """
  if not params.init_model:
    return None

  trainables = tf.trainable_variables()
  variables_to_restore = []
  for scope in scopes:
    for var in tf.contrib.framework.get_model_variables(scope):
      # Only restore variables that are actually trainable.
      if var in trainables:
        variables_to_restore.append(var)

  assign_op, feed_dict = slim.assign_from_checkpoint(
      params.init_model, variables_to_restore)

  def init_assign_function(sess):
    sess.run(assign_op, feed_dict)

  return init_assign_function
|
|
|
|
|
def get_model_fn(params, is_training, reuse=False):
  """Returns the rotator network builder from the factory."""
  model_fn = deeprotator_factory.get(params, is_training, reuse)
  return model_fn
|
|
|
|
|
def get_regularization_loss(scopes, params):
  """Returns the regularization loss for variables in the given scopes."""
  reg_loss = losses.regularization_loss(scopes, params)
  return reg_loss
|
|
|
|
|
def get_loss(inputs, outputs, params):
  """Computes the total rotator loss (image and mask terms).

  Args:
    inputs: Input dictionary of the rotator model.
    outputs: Output dictionary returned by the rotator model.
    params: Hyperparameters; optional image_weight / mask_weight attributes
      enable and weight the corresponding loss terms.

  Returns:
    Scalar float32 tensor holding the combined rotator loss.
  """
  total_loss = tf.zeros(dtype=tf.float32, shape=[])

  # Each term is only added when its weight hyperparameter is present.
  if hasattr(params, 'image_weight'):
    total_loss += losses.add_rotator_image_loss(
        inputs, outputs, params.step_size, params.image_weight)

  if hasattr(params, 'mask_weight'):
    total_loss += losses.add_rotator_mask_loss(
        inputs, outputs, params.step_size, params.mask_weight)

  slim.summaries.add_scalar_summary(
      total_loss, 'rotator_loss', prefix='losses')

  return total_loss
|
|
|
|
|
def get_train_op_for_scope(loss, optimizer, scopes, params):
  """Builds the train op restricted to variables in the given scopes.

  Args:
    loss: Scalar loss tensor to minimize.
    optimizer: A tf.train.Optimizer instance.
    scopes: Variable scopes whose trainable variables are optimized.
    params: Hyperparameters; params.clip_gradient_norm bounds gradients.

  Returns:
    The train op created by slim.learning.create_train_op.
  """
  trainables = tf.trainable_variables()

  variables_to_train = []
  update_ops = []
  for scope in scopes:
    for var in tf.contrib.framework.get_model_variables(scope):
      # Restrict optimization to variables that are actually trainable.
      if var in trainables:
        variables_to_train.append(var)
    update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))

  return slim.learning.create_train_op(
      loss,
      optimizer,
      update_ops=update_ops,
      variables_to_train=variables_to_train,
      clip_gradient_norm=params.clip_gradient_norm)
|
|
|
|
|
def get_metrics(inputs, outputs, params):
  """Aggregate the metrics for rotator model.

  Args:
    inputs: Input dictionary of the rotator model.
    outputs: Output dictionary returned by the rotator model.
    params: Hyperparameters of the rotator model.

  Returns:
    names_to_values: metrics->values (dict).
    names_to_updates: metrics->ops (dict).
  """
  names_to_values = dict()
  names_to_updates = dict()

  # Image metrics are normalized by pixel count times 3 color channels.
  tmp_values, tmp_updates = metrics.add_image_pred_metrics(
      inputs, outputs, params.num_views, 3*params.image_size**2)
  names_to_values.update(tmp_values)
  names_to_updates.update(tmp_updates)

  # Mask metrics are normalized by pixel count (single channel).
  tmp_values, tmp_updates = metrics.add_mask_pred_metrics(
      inputs, outputs, params.num_views, params.image_size**2)
  names_to_values.update(tmp_values)
  names_to_updates.update(tmp_updates)

  # Use items() instead of Python-2-only iteritems() so this also runs
  # under Python 3 (the file otherwise targets 2/3 compat via six).
  for name, value in names_to_values.items():
    slim.summaries.add_scalar_summary(
        value, name, prefix='eval', print_summary=True)

  return names_to_values, names_to_updates
|
|
|
|
|
def write_disk_grid(global_step, summary_freq, log_dir, input_images,
                    output_images, pred_images, pred_masks):
  """Adds ops that periodically save a visualization grid to disk.

  Args:
    global_step: Scalar int tensor with the current training step.
    summary_freq: Save the grid every `summary_freq` steps.
    log_dir: Directory the JPEG grids are written into.
    input_images: Batch of input view images.
    output_images: Batch of ground-truth output view images.
    pred_images: Batch of predicted images.
    pred_masks: Batch of predicted (single-channel) masks.

  Returns:
    The tf.py_func op that performs the (conditional) disk write.
  """

  def write_grid(grid, global_step):
    """Python-side writer invoked through tf.py_func."""
    if global_step % summary_freq == 0:
      utils.save_image(grid, os.path.join(log_dir, '%s.jpg' % str(global_step)))
    return 0

  grid = _build_image_grid(input_images, output_images, pred_images, pred_masks)
  slim.summaries.add_image_summary(
      tf.expand_dims(grid, axis=0), name='grid_vis')
  return tf.py_func(write_grid, [grid, global_step], [tf.int64],
                    'write_grid')[0]
|
|
|
|
|
def _build_image_grid(input_images, output_images, pred_images, pred_masks):
  """Builds a grid image by concatenating the per-example visualizations.

  The batch is laid out as quantity/4 rows of 4 examples. Each example
  contributes four images side by side along the width axis: input view,
  target view, predicted image, and the predicted mask tiled to 3 channels.

  Args:
    input_images: Batch of input view images, [batch, H, W, 3].
    output_images: Batch of target view images, [batch, H, W, 3].
    pred_images: Batch of predicted images, [batch, H, W, 3].
    pred_masks: Batch of predicted masks, [batch, H, W, 1].

  Returns:
    A single grid tensor stacking all rows along the height axis.
  """
  quantity = input_images.get_shape().as_list()[0]
  num_cols = 4

  rows = []
  for row in xrange(int(quantity / num_cols)):
    cells = []
    for col in xrange(num_cols):
      index = row * num_cols + col
      # Tile the 1-channel mask to 3 channels so it concatenates with RGB.
      mask_rgb = tf.tile(pred_masks[index, :, :, :], [1, 1, 3])
      cells.extend([
          input_images[index, :, :, :],
          output_images[index, :, :, :],
          pred_images[index, :, :, :],
          mask_rgb,
      ])
    rows.append(tf.concat(cells, 1))

  return tf.concat(rows, 0)
|
|