"""Segmentation results visualization on a given set of images. |
|
|
|
See model.py for more details and usage. |
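
An illustrative invocation (a sketch only; flag values and paths below are
placeholders to adapt to your own setup):

  python vis.py \
    --vis_split="val" \
    --model_variant="xception_65" \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --vis_crop_size=513,513 \
    --dataset="pascal_voc_seg" \
    --checkpoint_dir=/path/to/checkpoint \
    --vis_logdir=/path/to/vis_logdir \
    --dataset_dir=/path/to/dataset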
|
""" |
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import time

import numpy as np
from six.moves import range
import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib import training as contrib_training

from deeplab import common
from deeplab import model
from deeplab.datasets import data_generator
from deeplab.utils import save_annotation
|
flags = tf.app.flags

FLAGS = flags.FLAGS
|
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server') |
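
# Settings for log directories.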
flags.DEFINE_string('vis_logdir', None, 'Where to write the event logs.')
|
flags.DEFINE_string('checkpoint_dir', None, 'Directory of model checkpoints.') |
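
# Settings for visualizing the model.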
flags.DEFINE_integer('vis_batch_size', 1,
                     'The number of images in each batch during evaluation.')
|
flags.DEFINE_list('vis_crop_size', '513,513',
                  'Crop size [height, width] for visualization.')
|
flags.DEFINE_integer('eval_interval_secs', 60 * 5,
                     'How often (in seconds) to run evaluation.')
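
# Model dependent flags.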
flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')
|
flags.DEFINE_integer('output_stride', 16,
                     'The ratio of input to output spatial resolution.')
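
# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale test.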
flags.DEFINE_multi_float('eval_scales', [1.0],
                         'The scales to resize images for evaluation.')
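
# Change to True for adding flipped images during test.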
flags.DEFINE_bool('add_flipped_images', False,
                  'Add flipped images for evaluation or not.')
|
flags.DEFINE_integer(
    'quantize_delay_step', -1,
    'Steps to start quantized training. If < 0, will not quantize model.')
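
# Dataset settings.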
flags.DEFINE_string('dataset', 'pascal_voc_seg',
                    'Name of the segmentation dataset.')
|
flags.DEFINE_string('vis_split', 'val',
                    'Which split of the dataset to use for visualization.')
|
flags.DEFINE_string('dataset_dir', None, 'Where the dataset resides.')
|
flags.DEFINE_enum('colormap_type', 'pascal', ['pascal', 'cityscapes', 'ade20k'],
                  'Visualization colormap type.')
|
flags.DEFINE_boolean('also_save_raw_predictions', False,
                     'Also save raw predictions.')
|
flags.DEFINE_integer('max_number_of_iterations', 0,
                     'Maximum number of visualization iterations. Will loop '
                     'indefinitely upon nonpositive values.')
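
# The folder where semantic segmentation predictions are saved.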
_SEMANTIC_PREDICTION_SAVE_FOLDER = 'segmentation_results' |
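
# The folder where raw semantic segmentation predictions are saved.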
_RAW_SEMANTIC_PREDICTION_SAVE_FOLDER = 'raw_segmentation_results' |
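
# Filename format for saved input images.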
_IMAGE_FORMAT = '%06d_image' |
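
# Filename format for saved predictions.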
_PREDICTION_FORMAT = '%06d_prediction' |
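
# To evaluate Cityscapes results on the evaluation server, the labels used
# during training should be mapped to the labels used for evaluation.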
_CITYSCAPES_TRAIN_ID_TO_EVAL_ID = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
                                   23, 24, 25, 26, 27, 28, 31, 32, 33]
|
|
def _convert_train_id_to_eval_id(prediction, train_id_to_eval_id):
  """Converts the predicted label for evaluation.

  There are cases where the training labels are not equal to the evaluation
  labels. This function is used to perform the conversion so that we could
  evaluate the results on the evaluation server.

  Args:
    prediction: Semantic segmentation prediction.
    train_id_to_eval_id: A list mapping from train id to evaluation id.

  Returns:
    Semantic segmentation prediction whose labels have been changed.
  """
  converted_prediction = prediction.copy()
  for train_id, eval_id in enumerate(train_id_to_eval_id):
    converted_prediction[prediction == train_id] = eval_id

  return converted_prediction
|
|
def _process_batch(sess, original_images, semantic_predictions, image_names,
                   image_heights, image_widths, image_id_offset, save_dir,
                   raw_save_dir, train_id_to_eval_id=None):
  """Evaluates one single batch qualitatively.

  Args:
    sess: TensorFlow session.
    original_images: One batch of original images.
    semantic_predictions: One batch of semantic segmentation predictions.
    image_names: Image names.
    image_heights: Image heights.
    image_widths: Image widths.
    image_id_offset: Image id offset for indexing images.
    save_dir: The directory where the predictions will be saved.
    raw_save_dir: The directory where the raw predictions will be saved.
    train_id_to_eval_id: A list mapping from train id to eval id.
  """
  (original_images,
   semantic_predictions,
   image_names,
   image_heights,
   image_widths) = sess.run([original_images, semantic_predictions,
                             image_names, image_heights, image_widths])
|
  num_image = semantic_predictions.shape[0]
  for i in range(num_image):
    image_height = np.squeeze(image_heights[i])
    image_width = np.squeeze(image_widths[i])
    original_image = np.squeeze(original_images[i])
    semantic_prediction = np.squeeze(semantic_predictions[i])
    crop_semantic_prediction = semantic_prediction[:image_height, :image_width]
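
    # Save the input image.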
    save_annotation.save_annotation(
        original_image, save_dir, _IMAGE_FORMAT % (image_id_offset + i),
        add_colormap=False)
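
    # Save the prediction with the chosen colormap.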
    save_annotation.save_annotation(
        crop_semantic_prediction, save_dir,
        _PREDICTION_FORMAT % (image_id_offset + i), add_colormap=True,
        colormap_type=FLAGS.colormap_type)

    if FLAGS.also_save_raw_predictions:
      image_filename = os.path.basename(image_names[i])

      if train_id_to_eval_id is not None:
        crop_semantic_prediction = _convert_train_id_to_eval_id(
            crop_semantic_prediction,
            train_id_to_eval_id)
      save_annotation.save_annotation(
          crop_semantic_prediction, raw_save_dir, image_filename,
          add_colormap=False)
|
|
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
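
  # Get dataset-dependent information.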
  dataset = data_generator.Dataset(
      dataset_name=FLAGS.dataset,
      split_name=FLAGS.vis_split,
      dataset_dir=FLAGS.dataset_dir,
      batch_size=FLAGS.vis_batch_size,
      crop_size=[int(sz) for sz in FLAGS.vis_crop_size],
      min_resize_value=FLAGS.min_resize_value,
      max_resize_value=FLAGS.max_resize_value,
      resize_factor=FLAGS.resize_factor,
      model_variant=FLAGS.model_variant,
      is_training=False,
      should_shuffle=False,
      should_repeat=False)
|
  train_id_to_eval_id = None
  if dataset.dataset_name == data_generator.get_cityscapes_dataset_name():
    tf.logging.info('Cityscapes requires converting train_id to eval_id.')
    train_id_to_eval_id = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID
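
  # Prepare for visualization.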
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  save_dir = os.path.join(FLAGS.vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(save_dir)
  raw_save_dir = os.path.join(
      FLAGS.vis_logdir, _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(raw_save_dir)
|
  tf.logging.info('Visualizing on %s set', FLAGS.vis_split)
|
  with tf.Graph().as_default():
    samples = dataset.get_one_shot_iterator().get_next()
|
    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
        crop_size=[int(sz) for sz in FLAGS.vis_crop_size],
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)
|
    if tuple(FLAGS.eval_scales) == (1.0,):
      tf.logging.info('Performing single-scale test.')
      predictions = model.predict_labels(
          samples[common.IMAGE],
          model_options=model_options,
          image_pyramid=FLAGS.image_pyramid)
    else:
      tf.logging.info('Performing multi-scale test.')
      if FLAGS.quantize_delay_step >= 0:
        raise ValueError(
            'Quantize mode is not supported with multi-scale test.')
      predictions = model.predict_labels_multi_scale(
          samples[common.IMAGE],
          model_options=model_options,
          eval_scales=FLAGS.eval_scales,
          add_flipped_images=FLAGS.add_flipped_images)
    predictions = predictions[common.OUTPUT_TYPE]
|
    if FLAGS.min_resize_value and FLAGS.max_resize_value:
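      # Only support batch_size = 1, since we assume the dimensions of the
      # original image after tf.squeeze are [height, width, 3].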
      assert FLAGS.vis_batch_size == 1
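
      # Reverse the resizing and padding operations performed in preprocessing.
      # First, we slice the valid regions (i.e., remove the padded region) and
      # then we resize the predictions back.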
      original_image = tf.squeeze(samples[common.ORIGINAL_IMAGE])
      original_image_shape = tf.shape(original_image)
      predictions = tf.slice(
          predictions,
          [0, 0, 0],
          [1, original_image_shape[0], original_image_shape[1]])
      resized_shape = tf.to_int32([tf.squeeze(samples[common.HEIGHT]),
                                   tf.squeeze(samples[common.WIDTH])])
      predictions = tf.squeeze(
          tf.image.resize_images(tf.expand_dims(predictions, 3),
                                 resized_shape,
                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                                 align_corners=True), 3)
|
    tf.train.get_or_create_global_step()
    if FLAGS.quantize_delay_step >= 0:
      contrib_quantize.create_eval_graph()
|
    num_iteration = 0
    max_num_iteration = FLAGS.max_number_of_iterations
|
    checkpoints_iterator = contrib_training.checkpoints_iterator(
        FLAGS.checkpoint_dir, min_interval_secs=FLAGS.eval_interval_secs)
    for checkpoint_path in checkpoints_iterator:
      num_iteration += 1
      tf.logging.info(
          'Starting visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                       time.gmtime()))
      tf.logging.info('Visualizing with model %s', checkpoint_path)
|
      scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer())
      session_creator = tf.train.ChiefSessionCreator(
          scaffold=scaffold,
          master=FLAGS.master,
          checkpoint_filename_with_path=checkpoint_path)
      with tf.train.MonitoredSession(
          session_creator=session_creator, hooks=None) as sess:
        batch = 0
        image_id_offset = 0
|
        while not sess.should_stop():
          tf.logging.info('Visualizing batch %d', batch + 1)
          _process_batch(sess=sess,
                         original_images=samples[common.ORIGINAL_IMAGE],
                         semantic_predictions=predictions,
                         image_names=samples[common.IMAGE_NAME],
                         image_heights=samples[common.HEIGHT],
                         image_widths=samples[common.WIDTH],
                         image_id_offset=image_id_offset,
                         save_dir=save_dir,
                         raw_save_dir=raw_save_dir,
                         train_id_to_eval_id=train_id_to_eval_id)
          image_id_offset += FLAGS.vis_batch_size
          batch += 1
|
      tf.logging.info(
          'Finished visualization at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
                                                       time.gmtime()))
      if max_num_iteration > 0 and num_iteration >= max_num_iteration:
        break
|
|
if __name__ == '__main__':
  flags.mark_flag_as_required('checkpoint_dir')
  flags.mark_flag_as_required('vis_logdir')
  flags.mark_flag_as_required('dataset_dir')
  tf.app.run()