File size: 3,404 Bytes
18ddfe2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for FSNS datasets module."""
import collections
import os
import tensorflow as tf
from tensorflow.contrib import slim
from datasets import fsns
from datasets import unittest_utils
# Module-level alias for TensorFlow's flags object. No code visible in this
# file reads it. NOTE(review): possibly kept for flag parsing done by the test
# runner or for parity with sibling test modules -- confirm before removing.
FLAGS = tf.flags.FLAGS
def get_test_split():
  """Returns the FSNS 'test' split backed by the bundled test data.

  A copy of `fsns.DEFAULT_CONFIG` is used, so the 'splits' override below is
  applied to the copy rather than to the default config mapping itself.

  Returns:
    Whatever `fsns.get_split` produces for the split named 'test', read from
    `dataset_dir()` with the overridden config.
  """
  test_config = fsns.DEFAULT_CONFIG.copy()
  # Replace the split table with a single 'test' entry matching the file
  # pattern of the test data.
  test_split_spec = {'size': 5, 'pattern': 'fsns-00000-of-00001'}
  test_config['splits'] = {'test': test_split_spec}
  data_root = dataset_dir()
  return fsns.get_split('test', data_root, test_config)
def dataset_dir():
  """Returns the path of the 'testdata/fsns' directory beside this module."""
  module_dir = os.path.dirname(__file__)
  return os.path.join(module_dir, 'testdata/fsns')
class FsnsTest(tf.test.TestCase):
  """Tests for the FSNS dataset split, its decoder and extra attributes."""

  def test_decodes_example_proto(self):
    """Decoding a hand-built serialized example yields the original values."""
    label_ids = range(37)
    image_np, png_bytes = unittest_utils.create_random_image(
        'PNG', shape=(150, 600, 3))
    # Key order is kept as in the original feature dictionary.
    features = {
        'image/encoded': [png_bytes],
        'image/format': [b'PNG'],
        'image/class': label_ids,
        'image/unpadded_class': range(10),
        'image/text': [b'Raw text'],
        'image/orig_width': [150],
        'image/width': [600]
    }
    serialized = unittest_utils.create_serialized_example(features)

    decoder = fsns.get_split('train', dataset_dir()).decoder
    with self.test_session() as sess:
      # Wrap the decoded tensors in a namedtuple keyed by the decoder's item
      # names so the evaluated results can be read by attribute below.
      decoded_cls = collections.namedtuple('DecodedData', decoder.list_items())
      decoded_tensors = decoder.decode(serialized)
      data = sess.run(decoded_cls(*decoded_tensors))

    self.assertAllEqual(image_np, data.image)
    self.assertAllEqual(label_ids, data.label)
    self.assertEqual([b'Raw text'], data.text)
    self.assertEqual([1], data.num_of_views)

  def test_label_has_shape_defined(self):
    """The decoded 'label' tensor has a static first dimension of 37."""
    decoder = fsns.get_split('train', dataset_dir()).decoder
    # The input is never evaluated here; only the static shape is inspected.
    (label_tf,) = decoder.decode('fake', ['label'])
    first_dim = label_tf.get_shape().dims[0]

    self.assertEqual(first_dim, 37)

  def test_dataset_tuple_has_all_extra_attributes(self):
    """The 'train' split exposes a truthy value for every extra attribute."""
    dataset = fsns.get_split('train', dataset_dir())
    extra_attributes = (
        'charset',
        'num_char_classes',
        'num_of_views',
        'max_sequence_length',
        'null_code',
    )

    for attribute in extra_attributes:
      self.assertTrue(getattr(dataset, attribute))

  def test_can_use_the_test_data(self):
    """One example can be read from the test split through a data provider."""
    batch_size = 1
    provider = slim.dataset_data_provider.DatasetDataProvider(
        get_test_split(),
        shuffle=True,
        common_queue_capacity=2 * batch_size,
        common_queue_min=batch_size)
    image_tf, label_tf = provider.get(['image', 'label'])

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      # The fetch below runs while the queue runners are active.
      with slim.queues.QueueRunners(sess):
        image_np, label_np = sess.run([image_tf, label_tf])

    self.assertEqual((150, 600, 3), image_np.shape)
    self.assertEqual((37,), label_np.shape)
if __name__ == '__main__':
  # Executed as a script: hand control to TensorFlow's test entry point.
  tf.test.main()
|