# Spaces: Runtime error (page-status banner captured with the source; not part of the module)
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TFDS detection decoders."""
import tensorflow as tf, tf_keras | |
from official.vision.dataloaders import decoder | |
class MSCOCODecoder(decoder.Decoder):
  """Decoder that turns tfds COCO-style examples into detection tensors."""

  def decode(self, serialized_example):
    """Builds the detection feature dictionary from a tfds example.

    Args:
      serialized_example: a dictionary example produced by tfds.

    Returns:
      A dictionary of tensors with the following fields:
        - source_id: a string scalar tensor.
        - image: a uint8 tensor of shape [None, None, 3].
        - height: an integer scalar tensor.
        - width: an integer scalar tensor.
        - groundtruth_classes: a int64 tensor of shape [None].
        - groundtruth_is_crowd: a bool tensor of shape [None].
        - groundtruth_area: a float32 tensor of shape [None].
        - groundtruth_boxes: a float32 tensor of shape [None, 4].
    """
    source_id = tf.strings.as_string(serialized_example['image/id'])

    # Height and width are taken from the first two entries of the image
    # tensor's dynamic shape and widened to int64.
    image = serialized_example['image']
    image_shape = tf.shape(image)
    height = tf.cast(image_shape[0], tf.int64)
    width = tf.cast(image_shape[1], tf.int64)

    # Per-object annotations live under the nested 'objects' entry.
    objects = serialized_example['objects']

    return {
        'source_id': source_id,
        'image': image,
        'height': height,
        'width': width,
        'groundtruth_classes': objects['label'],
        'groundtruth_is_crowd': objects['is_crowd'],
        'groundtruth_area': tf.cast(objects['area'], tf.float32),
        'groundtruth_boxes': objects['bbox'],
    }


# Every TFDS dataset id listed here resolves to the same decoder class.
TFDS_ID_TO_DECODER_MAP = {
    dataset_id: MSCOCODecoder
    for dataset_id in (
        'coco/2017',
        'coco/2014',
        'coco',
        'scenic:objects365',
    )
}