# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Example proto decoder for object detection. | |
A decoder to decode string tensors containing serialized tensorflow.Example | |
protos for object detection. | |
""" | |
import tensorflow as tf, tf_keras | |
class TfExampleDecoder(object):
  """TensorFlow Example proto decoder."""

  def __init__(self, include_mask=False):
    self._include_mask = include_mask
    self._keys_to_features = {
        'image/encoded':
            tf.io.FixedLenFeature((), tf.string),
        'image/source_id':
            tf.io.FixedLenFeature((), tf.string),
        'image/height':
            tf.io.FixedLenFeature((), tf.int64),
        'image/width':
            tf.io.FixedLenFeature((), tf.int64),
        'image/object/bbox/xmin':
            tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/xmax':
            tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymin':
            tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymax':
            tf.io.VarLenFeature(tf.float32),
        'image/object/class/label':
            tf.io.VarLenFeature(tf.int64),
        'image/object/area':
            tf.io.VarLenFeature(tf.float32),
        'image/object/is_crowd':
            tf.io.VarLenFeature(tf.int64),
    }
    if include_mask:
      self._keys_to_features.update({
          'image/object/mask':
              tf.io.VarLenFeature(tf.string),
      })

  def _decode_image(self, parsed_tensors):
    """Decodes the image and sets its static shape."""
    image = tf.io.decode_image(parsed_tensors['image/encoded'], channels=3)
    image.set_shape([None, None, 3])
    return image

  def _decode_boxes(self, parsed_tensors):
    """Concatenates box coordinates in the format [ymin, xmin, ymax, xmax]."""
    xmin = parsed_tensors['image/object/bbox/xmin']
    xmax = parsed_tensors['image/object/bbox/xmax']
    ymin = parsed_tensors['image/object/bbox/ymin']
    ymax = parsed_tensors['image/object/bbox/ymax']
    return tf.stack([ymin, xmin, ymax, xmax], axis=-1)

  def _decode_masks(self, parsed_tensors):
    """Decodes a set of PNG masks to tf.float32 tensors."""

    def _decode_png_mask(png_bytes):
      mask = tf.squeeze(
          tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8), axis=-1)
      mask = tf.cast(mask, dtype=tf.float32)
      mask.set_shape([None, None])
      return mask

    height = parsed_tensors['image/height']
    width = parsed_tensors['image/width']
    masks = parsed_tensors['image/object/mask']
    return tf.cond(
        pred=tf.greater(tf.size(input=masks), 0),
        true_fn=lambda: tf.map_fn(_decode_png_mask, masks, dtype=tf.float32),
        false_fn=lambda: tf.zeros([0, height, width], dtype=tf.float32))

  def _decode_areas(self, parsed_tensors):
    """Returns object areas, computing them from the boxes if not provided."""
    xmin = parsed_tensors['image/object/bbox/xmin']
    xmax = parsed_tensors['image/object/bbox/xmax']
    ymin = parsed_tensors['image/object/bbox/ymin']
    ymax = parsed_tensors['image/object/bbox/ymax']
    return tf.cond(
        tf.greater(tf.shape(parsed_tensors['image/object/area'])[0], 0),
        lambda: parsed_tensors['image/object/area'],
        lambda: (xmax - xmin) * (ymax - ymin))

  def decode(self, serialized_example):
    """Decodes the serialized example.

    Args:
      serialized_example: a single serialized tf.Example string.

    Returns:
      decoded_tensors: a dictionary of tensors with the following fields:
        - image: a uint8 tensor of shape [None, None, 3].
        - source_id: a string scalar tensor.
        - height: an integer scalar tensor.
        - width: an integer scalar tensor.
        - groundtruth_classes: an int64 tensor of shape [None].
        - groundtruth_is_crowd: a bool tensor of shape [None].
        - groundtruth_area: a float32 tensor of shape [None].
        - groundtruth_boxes: a float32 tensor of shape [None, 4].
        - groundtruth_instance_masks: a float32 tensor of shape
            [None, None, None]. Only present when include_mask is True.
        - groundtruth_instance_masks_png: a string tensor of shape [None].
            Only present when include_mask is True.
    """
    parsed_tensors = tf.io.parse_single_example(
        serialized=serialized_example, features=self._keys_to_features)
    # Densify sparse features so downstream ops can assume dense tensors.
    for k in parsed_tensors:
      if isinstance(parsed_tensors[k], tf.SparseTensor):
        if parsed_tensors[k].dtype == tf.string:
          parsed_tensors[k] = tf.sparse.to_dense(
              parsed_tensors[k], default_value='')
        else:
          parsed_tensors[k] = tf.sparse.to_dense(
              parsed_tensors[k], default_value=0)

    image = self._decode_image(parsed_tensors)
    boxes = self._decode_boxes(parsed_tensors)
    areas = self._decode_areas(parsed_tensors)
    is_crowds = tf.cond(
        tf.greater(tf.shape(parsed_tensors['image/object/is_crowd'])[0], 0),
        lambda: tf.cast(parsed_tensors['image/object/is_crowd'], dtype=tf.bool),
        lambda: tf.zeros_like(parsed_tensors['image/object/class/label'], dtype=tf.bool))  # pylint: disable=line-too-long
    if self._include_mask:
      masks = self._decode_masks(parsed_tensors)

    decoded_tensors = {
        'image': image,
        'source_id': parsed_tensors['image/source_id'],
        'height': parsed_tensors['image/height'],
        'width': parsed_tensors['image/width'],
        'groundtruth_classes': parsed_tensors['image/object/class/label'],
        'groundtruth_is_crowd': is_crowds,
        'groundtruth_area': areas,
        'groundtruth_boxes': boxes,
    }
    if self._include_mask:
      decoded_tensors.update({
          'groundtruth_instance_masks': masks,
          'groundtruth_instance_masks_png': parsed_tensors['image/object/mask'],
      })
    return decoded_tensors
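

# Example usage: a minimal, self-contained sketch (not part of the original
# module) showing how a serialized tf.Example matching the feature spec above
# can be built and then decoded. All values below (image size, box, label,
# area) are made up for illustration, and the TFRecord path mentioned in the
# trailing comment is hypothetical.
if __name__ == '__main__':

  def _bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

  def _int64_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

  def _float_feature(values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))

  # Encode a small random image as JPEG bytes for the 'image/encoded' field.
  image = tf.cast(
      tf.random.uniform([64, 64, 3], maxval=256, dtype=tf.int32), tf.uint8)
  encoded_image = tf.io.encode_jpeg(image).numpy()

  example = tf.train.Example(
      features=tf.train.Features(
          feature={
              'image/encoded': _bytes_feature([encoded_image]),
              'image/source_id': _bytes_feature([b'0']),
              'image/height': _int64_feature([64]),
              'image/width': _int64_feature([64]),
              'image/object/bbox/xmin': _float_feature([0.1]),
              'image/object/bbox/xmax': _float_feature([0.9]),
              'image/object/bbox/ymin': _float_feature([0.2]),
              'image/object/bbox/ymax': _float_feature([0.8]),
              'image/object/class/label': _int64_feature([1]),
              'image/object/area': _float_feature([0.48]),
              'image/object/is_crowd': _int64_feature([0]),
          }))

  decoder = TfExampleDecoder(include_mask=False)
  decoded = decoder.decode(tf.constant(example.SerializeToString()))
  for key, tensor in decoded.items():
    print(key, tensor.shape, tensor.dtype)

  # In a real input pipeline the decoder is typically mapped over a
  # TFRecordDataset, along the lines of (path is hypothetical):
  #   dataset = tf.data.TFRecordDataset('/path/to/train.tfrecord')
  #   dataset = dataset.map(decoder.decode,
  #                         num_parallel_calls=tf.data.AUTOTUNE)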