|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Contains common utility functions and classes for building dataset. |
|
|
|
This script contains utility functions and classes to converts dataset to |
|
TFRecord file format with Example protos. |
|
|
|
The Example proto contains the following fields: |
|
|
|
image/encoded: encoded image content. |
|
image/filename: image filename. |
|
image/format: image file format. |
|
image/height: image height. |
|
image/width: image width. |
|
image/channels: image channels. |
|
image/segmentation/class/encoded: encoded semantic segmentation content. |
|
image/segmentation/class/format: semantic segmentation file format. |
|
""" |
|
import collections |
|
import six |
|
import tensorflow as tf |
|
|
|
FLAGS = tf.app.flags.FLAGS |
|
|
|
tf.app.flags.DEFINE_enum('image_format', 'png', ['jpg', 'jpeg', 'png'], |
|
'Image format.') |
|
|
|
tf.app.flags.DEFINE_enum('label_format', 'png', ['png'], |
|
'Segmentation label format.') |
|
|
|
|
|
_IMAGE_FORMAT_MAP = { |
|
'jpg': 'jpeg', |
|
'jpeg': 'jpeg', |
|
'png': 'png', |
|
} |
|
|
|
|
|
class ImageReader(object): |
|
"""Helper class that provides TensorFlow image coding utilities.""" |
|
|
|
def __init__(self, image_format='jpeg', channels=3): |
|
"""Class constructor. |
|
|
|
Args: |
|
image_format: Image format. Only 'jpeg', 'jpg', or 'png' are supported. |
|
channels: Image channels. |
|
""" |
|
with tf.Graph().as_default(): |
|
self._decode_data = tf.placeholder(dtype=tf.string) |
|
self._image_format = image_format |
|
self._session = tf.Session() |
|
if self._image_format in ('jpeg', 'jpg'): |
|
self._decode = tf.image.decode_jpeg(self._decode_data, |
|
channels=channels) |
|
elif self._image_format == 'png': |
|
self._decode = tf.image.decode_png(self._decode_data, |
|
channels=channels) |
|
|
|
def read_image_dims(self, image_data): |
|
"""Reads the image dimensions. |
|
|
|
Args: |
|
image_data: string of image data. |
|
|
|
Returns: |
|
image_height and image_width. |
|
""" |
|
image = self.decode_image(image_data) |
|
return image.shape[:2] |
|
|
|
def decode_image(self, image_data): |
|
"""Decodes the image data string. |
|
|
|
Args: |
|
image_data: string of image data. |
|
|
|
Returns: |
|
Decoded image data. |
|
|
|
Raises: |
|
ValueError: Value of image channels not supported. |
|
""" |
|
image = self._session.run(self._decode, |
|
feed_dict={self._decode_data: image_data}) |
|
if len(image.shape) != 3 or image.shape[2] not in (1, 3): |
|
raise ValueError('The image channels not supported.') |
|
|
|
return image |
|
|
|
|
|
def _int64_list_feature(values): |
|
"""Returns a TF-Feature of int64_list. |
|
|
|
Args: |
|
values: A scalar or list of values. |
|
|
|
Returns: |
|
A TF-Feature. |
|
""" |
|
if not isinstance(values, collections.Iterable): |
|
values = [values] |
|
|
|
return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) |
|
|
|
|
|
def _bytes_list_feature(values): |
|
"""Returns a TF-Feature of bytes. |
|
|
|
Args: |
|
values: A string. |
|
|
|
Returns: |
|
A TF-Feature. |
|
""" |
|
def norm2bytes(value): |
|
return value.encode() if isinstance(value, str) and six.PY3 else value |
|
|
|
return tf.train.Feature( |
|
bytes_list=tf.train.BytesList(value=[norm2bytes(values)])) |
|
|
|
|
|
def image_seg_to_tfexample(image_data, filename, height, width, seg_data): |
|
"""Converts one image/segmentation pair to tf example. |
|
|
|
Args: |
|
image_data: string of image data. |
|
filename: image filename. |
|
height: image height. |
|
width: image width. |
|
seg_data: string of semantic segmentation data. |
|
|
|
Returns: |
|
tf example of one image/segmentation pair. |
|
""" |
|
return tf.train.Example(features=tf.train.Features(feature={ |
|
'image/encoded': _bytes_list_feature(image_data), |
|
'image/filename': _bytes_list_feature(filename), |
|
'image/format': _bytes_list_feature( |
|
_IMAGE_FORMAT_MAP[FLAGS.image_format]), |
|
'image/height': _int64_list_feature(height), |
|
'image/width': _int64_list_feature(width), |
|
'image/channels': _int64_list_feature(3), |
|
'image/segmentation/class/encoded': ( |
|
_bytes_list_feature(seg_data)), |
|
'image/segmentation/class/format': _bytes_list_feature( |
|
FLAGS.label_format), |
|
})) |
|
|