ASL-MoViNet-T5-translator / official /vision /dataloaders /tf_example_label_map_decoder.py
deanna-emery's picture
updates
93528c6
raw
history blame
2.6 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Example proto decoder for object detection.
Decodes string tensors containing serialized tensorflow.Example protos for
object detection, looking up integer class ids from each object's class-name
text via a CSV label map.
"""
import csv
# Import libraries
import tensorflow as tf, tf_keras
from official.vision.dataloaders import tf_example_decoder
class TfExampleDecoderLabelMap(tf_example_decoder.TfExampleDecoder):
  """Tensorflow Example proto decoder that resolves class names to ids.

  On top of the base `TfExampleDecoder`, this decoder parses the
  'image/object/class/text' feature and converts each class name to an
  integer id through a static lookup table built from a CSV label map.
  Names that are absent from the label map are looked up as -1.
  """

  def __init__(self, label_map, include_mask=False, regenerate_source_id=False,
               mask_binarize_threshold=None):
    """Builds the decoder and its name -> id lookup table.

    Args:
      label_map: Path to the label map file. Must end with '.csv' and contain
        `id,name` rows.
      include_mask: Forwarded unchanged to the base decoder.
      regenerate_source_id: Forwarded unchanged to the base decoder.
      mask_binarize_threshold: Forwarded unchanged to the base decoder.

    Raises:
      ValueError: If `label_map` is not a '.csv' path or the file is malformed.
    """
    super().__init__(
        include_mask=include_mask,
        regenerate_source_id=regenerate_source_id,
        mask_binarize_threshold=mask_binarize_threshold)
    # NOTE(review): `_keys_to_features` is presumably created by the base
    # class constructor -- confirm against tf_example_decoder.
    self._keys_to_features.update({
        'image/object/class/text': tf.io.VarLenFeature(tf.string),
    })
    name_to_id = self._process_label_map(label_map)
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(list(name_to_id.keys()), dtype=tf.string),
        values=tf.constant(list(name_to_id.values()), dtype=tf.int64))
    # Unknown class names resolve to -1 instead of raising at lookup time.
    self._name_to_id_table = tf.lookup.StaticHashTable(
        initializer, default_value=-1)

  def _process_label_map(self, label_map):
    """Loads a name -> id dict from `label_map`; only '.csv' is supported.

    Args:
      label_map: Path to the label map file.

    Returns:
      A dict mapping class name (str) to class id (int).

    Raises:
      ValueError: If the path does not end with '.csv'.
    """
    if not label_map.endswith('.csv'):
      raise ValueError('The label map file is in incorrect format.')
    return self._process_csv(label_map)

  def _process_csv(self, label_map):
    """Parses a CSV label map with `id,name` rows into a name -> id dict.

    Blank lines are ignored. If the same name appears on several rows, the
    last row wins.

    Args:
      label_map: Path to the CSV file, opened through `tf.io.gfile.GFile`.

    Returns:
      A dict mapping class name (str) to class id (int).

    Raises:
      ValueError: If a non-blank row does not have exactly two fields, or if
        its id field is not an integer.
    """
    name_to_id = {}
    with tf.io.gfile.GFile(label_map, 'r') as f:
      reader = csv.reader(f, delimiter=',')
      for row in reader:
        if not row:
          # csv.reader yields an empty list for a blank line (e.g. a trailing
          # newline at the end of the file); that is not a malformed entry.
          continue
        if len(row) != 2:
          raise ValueError('Each row of the csv label map file must be in '
                           '`id,name` format. length = {}'.format(len(row)))
        raw_id, name = row
        try:
          class_id = int(raw_id)
        except ValueError as err:
          # Same exception type as before, but with enough context to locate
          # the offending row (commonly an unexpected header line).
          raise ValueError(
              'Invalid class id {!r} on line {} of the csv label map file; '
              'ids must be integers.'.format(raw_id, reader.line_num)) from err
        name_to_id[name] = class_id
    return name_to_id

  def _decode_classes(self, parsed_tensors):
    """Looks up class ids for the parsed 'image/object/class/text' feature.

    Args:
      parsed_tensors: Dict of parsed features; must contain the key
        'image/object/class/text'.

    Returns:
      The result of looking the class-name tensor up in the name -> id table;
      names missing from the label map yield -1.
    """
    # NOTE(review): presumably overrides the base decoder's class decoding
    # hook -- confirm against tf_example_decoder.
    return self._name_to_id_table.lookup(
        parsed_tensors['image/object/class/text'])