# Hosting-page residue (not part of the module): deanna-emery's picture | updates | 93528c6 | raw | history blame | 8.65 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Example proto decoder for object detection.

A decoder to decode string tensors containing serialized tensorflow.Example
protos for object detection.
"""
import tensorflow as tf, tf_keras
from official.vision.dataloaders import decoder
def _generate_source_id(image_bytes):
    """Derives a source id string from a hash of the encoded image bytes.

    Args:
      image_bytes: string tensor holding the encoded image.

    Returns:
      A string tensor holding the decimal hash-bucket index of `image_bytes`.
    """
    # Hashing using 22 bits since float32 has only 23 mantissa bits.
    num_buckets = 2 ** 22 - 1
    bucket_index = tf.strings.to_hash_bucket_fast(image_bytes, num_buckets)
    return tf.strings.as_string(bucket_index)
class TfExampleDecoder(decoder.Decoder):
    """Tensorflow Example proto decoder.

    Parses a serialized tf.Example holding an encoded image plus object
    detection annotations (boxes, classes, areas, crowd flags and, optionally,
    instance masks and per-object attributes) into a dictionary of dense
    tensors.
    """

    def __init__(
        self,
        include_mask=False,
        regenerate_source_id=False,
        mask_binarize_threshold=None,
        attribute_names=None,
    ):
        """Builds the feature spec used to parse each example.

        Args:
          include_mask: if True, `image/object/mask` is parsed and decoded.
          regenerate_source_id: if True, `image/source_id` is not parsed and
            the source id is always derived from a hash of the encoded image.
          mask_binarize_threshold: optional threshold. When not None, decoded
            masks are replaced by 1.0 where the value is greater than the
            threshold and 0.0 elsewhere.
          attribute_names: optional iterable of attribute names. Each name
            adds an int64 `image/object/attribute/<name>` feature.
        """
        self._include_mask = include_mask
        self._regenerate_source_id = regenerate_source_id
        self._mask_binarize_threshold = mask_binarize_threshold
        # Copy so that later mutation of the caller's list cannot make the
        # stored names disagree with the feature spec built below.
        self._attribute_names = list(attribute_names or [])

        self._keys_to_features = {
            'image/encoded': tf.io.FixedLenFeature((), tf.string),
            # -1 marks a missing dimension; decode() replaces it with the
            # shape of the decoded image.
            'image/height': tf.io.FixedLenFeature((), tf.int64, -1),
            'image/width': tf.io.FixedLenFeature((), tf.int64, -1),
            'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
            'image/object/class/label': tf.io.VarLenFeature(tf.int64),
            'image/object/area': tf.io.VarLenFeature(tf.float32),
            'image/object/is_crowd': tf.io.VarLenFeature(tf.int64),
        }
        for attr_name in self._attribute_names:
            self._keys_to_features[f'image/object/attribute/{attr_name}'] = (
                tf.io.VarLenFeature(tf.int64)
            )
        if include_mask:
            self._keys_to_features['image/object/mask'] = (
                tf.io.VarLenFeature(tf.string)
            )
        if not regenerate_source_id:
            self._keys_to_features['image/source_id'] = (
                tf.io.FixedLenFeature((), tf.string)
            )

    def _decode_image(self, parsed_tensors):
        """Decodes the image and set its static shape."""
        image = tf.io.decode_image(parsed_tensors['image/encoded'], channels=3)
        image.set_shape([None, None, 3])
        return image

    def _decode_boxes(self, parsed_tensors):
        """Concat box coordinates in the format of [ymin, xmin, ymax, xmax]."""
        xmin = parsed_tensors['image/object/bbox/xmin']
        xmax = parsed_tensors['image/object/bbox/xmax']
        ymin = parsed_tensors['image/object/bbox/ymin']
        ymax = parsed_tensors['image/object/bbox/ymax']
        return tf.stack([ymin, xmin, ymax, xmax], axis=-1)

    def _decode_classes(self, parsed_tensors):
        """Returns the parsed `image/object/class/label` tensor unchanged."""
        return parsed_tensors['image/object/class/label']

    def _decode_attributes(self, parsed_tensors):
        """Returns a dict mapping each attribute name to its parsed values.

        Each value gets a trailing axis of size 1 appended.
        """
        # TODO(b/269654135): Support decoding of fully 2D attributes.
        return {
            attr_name: tf.expand_dims(
                parsed_tensors[f'image/object/attribute/{attr_name}'], -1)
            for attr_name in self._attribute_names
        }

    def _decode_areas(self, parsed_tensors):
        """Returns the stored object areas, or derives them from the boxes.

        When `image/object/area` is empty, the area is computed as
        box width * box height * image height * image width. The caller must
        therefore resolve `image/height` and `image/width` to the real image
        size (not the -1 placeholder) before calling this method.
        """
        xmin = parsed_tensors['image/object/bbox/xmin']
        xmax = parsed_tensors['image/object/bbox/xmax']
        ymin = parsed_tensors['image/object/bbox/ymin']
        ymax = parsed_tensors['image/object/bbox/ymax']
        height = tf.cast(parsed_tensors['image/height'], dtype=tf.float32)
        width = tf.cast(parsed_tensors['image/width'], dtype=tf.float32)
        return tf.cond(
            tf.greater(tf.shape(parsed_tensors['image/object/area'])[0], 0),
            lambda: parsed_tensors['image/object/area'],
            lambda: (xmax - xmin) * (ymax - ymin) * height * width)

    def _decode_masks(self, parsed_tensors):
        """Decode a set of PNG masks to the tf.float32 tensors.

        Returns a `[0, height, width]` tensor of zeros when the example holds
        no masks.
        """

        def _decode_png_mask(png_bytes):
            mask = tf.squeeze(
                tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8),
                axis=-1)
            mask = tf.cast(mask, dtype=tf.float32)
            mask.set_shape([None, None])
            return mask

        height = parsed_tensors['image/height']
        width = parsed_tensors['image/width']
        masks = parsed_tensors['image/object/mask']
        return tf.cond(
            pred=tf.greater(tf.size(input=masks), 0),
            # `fn_output_signature` replaces the deprecated `dtype` argument.
            true_fn=lambda: tf.map_fn(
                _decode_png_mask, masks, fn_output_signature=tf.float32),
            false_fn=lambda: tf.zeros([0, height, width], dtype=tf.float32))

    def decode(self, serialized_example):
        """Decode the serialized example.

        Args:
          serialized_example: a single serialized tf.Example string.

        Returns:
          decoded_tensors: a dictionary of tensors with the following fields:
            - source_id: a string scalar tensor.
            - image: a uint8 tensor of shape [None, None, 3].
            - height: an integer scalar tensor.
            - width: an integer scalar tensor.
            - groundtruth_classes: a int64 tensor of shape [None].
            - groundtruth_is_crowd: a bool tensor of shape [None].
            - groundtruth_area: a float32 tensor of shape [None].
            - groundtruth_boxes: a float32 tensor of shape [None, 4].
            - groundtruth_attributes: a dict mapping each attribute name to an
                int64 tensor of shape [None, 1]. Only present when attribute
                names were configured.
            - groundtruth_instance_masks: a float32 tensor of shape
                [None, None, None]. Only present when masks are included.
            - groundtruth_instance_masks_png: a string tensor of shape [None].
                Only present when masks are included.
        """
        parsed_tensors = tf.io.parse_single_example(
            serialized=serialized_example, features=self._keys_to_features)

        # VarLenFeature entries are parsed as SparseTensors; densify them so
        # the rest of the pipeline only deals with dense tensors.
        for key, value in parsed_tensors.items():
            if isinstance(value, tf.SparseTensor):
                default_value = '' if value.dtype == tf.string else 0
                parsed_tensors[key] = tf.sparse.to_dense(
                    value, default_value=default_value)

        if self._regenerate_source_id:
            source_id = _generate_source_id(parsed_tensors['image/encoded'])
        else:
            # Fall back to a hash of the image when the stored id is empty.
            source_id = tf.cond(
                tf.greater(
                    tf.strings.length(parsed_tensors['image/source_id']), 0),
                lambda: parsed_tensors['image/source_id'],
                lambda: _generate_source_id(parsed_tensors['image/encoded']))

        image = self._decode_image(parsed_tensors)
        boxes = self._decode_boxes(parsed_tensors)
        classes = self._decode_classes(parsed_tensors)
        attributes = self._decode_attributes(parsed_tensors)

        # If either dimension is missing (-1), take both from the decoded
        # image.
        decode_image_shape = tf.logical_or(
            tf.equal(parsed_tensors['image/height'], -1),
            tf.equal(parsed_tensors['image/width'], -1))
        image_shape = tf.cast(tf.shape(image), dtype=tf.int64)
        parsed_tensors['image/height'] = tf.where(
            decode_image_shape, image_shape[0], parsed_tensors['image/height'])
        parsed_tensors['image/width'] = tf.where(
            decode_image_shape, image_shape[1], parsed_tensors['image/width'])

        # Areas must be computed only after height/width are resolved:
        # otherwise the fallback area would be scaled by (-1) * (-1) = 1
        # instead of the real image size when the dimensions are missing.
        areas = self._decode_areas(parsed_tensors)

        is_crowds = tf.cond(
            tf.greater(
                tf.shape(parsed_tensors['image/object/is_crowd'])[0], 0),
            lambda: tf.cast(
                parsed_tensors['image/object/is_crowd'], dtype=tf.bool),
            lambda: tf.zeros_like(classes, dtype=tf.bool))

        decoded_tensors = {
            'source_id': source_id,
            'image': image,
            'height': parsed_tensors['image/height'],
            'width': parsed_tensors['image/width'],
            'groundtruth_classes': classes,
            'groundtruth_is_crowd': is_crowds,
            'groundtruth_area': areas,
            'groundtruth_boxes': boxes,
        }
        if self._attribute_names:
            decoded_tensors['groundtruth_attributes'] = attributes
        if self._include_mask:
            masks = self._decode_masks(parsed_tensors)
            if self._mask_binarize_threshold is not None:
                masks = tf.cast(
                    masks > self._mask_binarize_threshold, tf.float32)
            decoded_tensors['groundtruth_instance_masks'] = masks
            decoded_tensors['groundtruth_instance_masks_png'] = (
                parsed_tensors['image/object/mask'])
        return decoded_tensors