File size: 2,324 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TFDS detection decoders."""

import tensorflow as tf, tf_keras
from official.vision.dataloaders import decoder


class MSCOCODecoder(decoder.Decoder):
  """Decoder that turns a tfds coco example into detection tensors."""

  def decode(self, serialized_example):
    """Builds the detection tensor dictionary for a single tfds example.

    Despite the argument name, the input is not a serialized proto: it is the
    nested dictionary that tfds yields for one example.

    Args:
      serialized_example: a dictionary example produced by tfds. The keys read
        here are 'image/id', 'image' and 'objects' (with sub-keys 'label',
        'is_crowd', 'area' and 'bbox').

    Returns:
      A dictionary of tensors with the following fields:
        - source_id: a string scalar tensor.
        - image: a uint8 tensor of shape [None, None, 3].
        - height: an integer scalar tensor.
        - width: an integer scalar tensor.
        - groundtruth_classes: a int64 tensor of shape [None].
        - groundtruth_is_crowd: a bool tensor of shape [None].
        - groundtruth_area: a float32 tensor of shape [None].
        - groundtruth_boxes: a float32 tensor of shape [None, 4].
    """
    # The image id is converted to its string representation.
    source_id = tf.strings.as_string(serialized_example['image/id'])

    # Height and width come from the first two dimensions of the image's
    # dynamic shape and are cast to int64.
    image = serialized_example['image']
    height, width = [
        tf.cast(tf.shape(image)[axis], tf.int64) for axis in (0, 1)
    ]

    # Per-object annotations: only the area is cast (to float32); the other
    # fields are forwarded unchanged.
    objects = serialized_example['objects']
    classes = objects['label']
    is_crowd = objects['is_crowd']
    area = tf.cast(objects['area'], tf.float32)
    boxes = objects['bbox']

    return {
        'source_id': source_id,
        'image': image,
        'height': height,
        'width': width,
        'groundtruth_classes': classes,
        'groundtruth_is_crowd': is_crowd,
        'groundtruth_area': area,
        'groundtruth_boxes': boxes,
    }


# Registry from tfds dataset id to the decoder class used for that dataset.
# Every id currently listed is decoded with MSCOCODecoder.
TFDS_ID_TO_DECODER_MAP = dict.fromkeys(
    (
        'coco/2017',
        'coco/2014',
        'coco',
        'scenic:objects365',
    ),
    MSCOCODecoder,
)