File size: 2,598 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tensorflow Example proto decoder for object detection.

A decoder that decodes string tensors containing serialized `tf.train.Example`
protos for object detection.
"""
import csv
# Import libraries
import tensorflow as tf, tf_keras

from official.vision.dataloaders import tf_example_decoder


class TfExampleDecoderLabelMap(tf_example_decoder.TfExampleDecoder):
  """Decodes tf.Example protos whose object classes are stored as text.

  Extends `tf_example_decoder.TfExampleDecoder` by additionally parsing the
  `image/object/class/text` feature and converting each class name into an
  integer class id using a label map loaded from a CSV file. Class names that
  do not appear in the label map are decoded as -1.
  """

  def __init__(self, label_map, include_mask=False, regenerate_source_id=False,
               mask_binarize_threshold=None):
    """Initializes the decoder and builds the name-to-id lookup table.

    Args:
      label_map: Path to the label map file. Only CSV files (`.csv`
        extension, matched case-insensitively) whose rows are `id,name`
        are supported.
      include_mask: Forwarded unchanged to the base decoder.
      regenerate_source_id: Forwarded unchanged to the base decoder.
      mask_binarize_threshold: Forwarded unchanged to the base decoder.

    Raises:
      ValueError: If the label map has an unsupported extension or contains
        a malformed row.
    """
    super().__init__(
        include_mask=include_mask, regenerate_source_id=regenerate_source_id,
        mask_binarize_threshold=mask_binarize_threshold)
    # Class names are stored per object as a variable-length string feature.
    self._keys_to_features.update({
        'image/object/class/text': tf.io.VarLenFeature(tf.string),
    })
    name_to_id = self._process_label_map(label_map)
    self._name_to_id_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(name_to_id.keys()), dtype=tf.string),
            values=tf.constant(list(name_to_id.values()), dtype=tf.int64)),
        # Names missing from the label map decode to -1 rather than failing.
        default_value=-1)

  def _process_label_map(self, label_map):
    """Loads `label_map` into a `{class_name: class_id}` dict.

    Args:
      label_map: Path to the label map file.

    Returns:
      A dict mapping class-name strings to integer class ids.

    Raises:
      ValueError: If `label_map` does not end in `.csv` (case-insensitive),
        or if the CSV content is malformed.
    """
    if label_map.lower().endswith('.csv'):
      return self._process_csv(label_map)
    raise ValueError(
        'The label map file is in incorrect format. Only `.csv` label maps '
        'are supported, got: {}'.format(label_map))

  def _process_csv(self, label_map):
    """Parses a CSV label map with one `id,name` pair per row.

    Empty or whitespace-only rows (e.g. trailing blank lines) are skipped.
    If the same name appears on several rows, the id from the last such row
    is kept.

    Args:
      label_map: Path to the CSV label map file.

    Returns:
      A dict mapping class-name strings to integer class ids.

    Raises:
      ValueError: If a non-blank row does not have exactly two fields, or if
        its id field is not an integer.
    """
    name_to_id = {}
    with tf.io.gfile.GFile(label_map, 'r') as f:
      reader = csv.reader(f, delimiter=',')
      for row in reader:
        # csv.reader yields [] for blank lines (and [' '] for whitespace-only
        # ones); ignore them instead of rejecting the whole file.
        if not any(field.strip() for field in row):
          continue
        if len(row) != 2:
          raise ValueError('Each row of the csv label map file must be in '
                           '`id,name` format. length = {} (line {})'.format(
                               len(row), reader.line_num))
        id_str, name = row
        try:
          id_index = int(id_str)
        except ValueError as err:
          # Typically a header row or a typo; report where it happened.
          raise ValueError(
              'Invalid class id {!r} on line {} of the csv label map file '
              '{}.'.format(id_str, reader.line_num, label_map)) from err
        name_to_id[name] = id_index
    return name_to_id

  def _decode_classes(self, parsed_tensors):
    """Maps the parsed class-name strings to integer class ids.

    NOTE(review): presumably invoked by the base decoder's decode path
    (not visible here) after features are parsed.

    Args:
      parsed_tensors: Dict of parsed features; must contain
        'image/object/class/text'.

    Returns:
      An int64 tensor of class ids; names absent from the label map are -1.
    """
    return self._name_to_id_table.lookup(
        parsed_tensors['image/object/class/text'])