Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Tests for decoder factory functions.""" | |
from absl.testing import parameterized | |
import tensorflow as tf, tf_keras | |
from tensorflow.python.distribute import combinations | |
from official.vision import configs | |
from official.vision.configs import decoders as decoders_cfg | |
from official.vision.modeling import decoders | |
from official.vision.modeling.decoders import factory | |
class FactoryTest(tf.test.TestCase, parameterized.TestCase):
  """Checks `factory.build_decoder` against directly constructed decoders.

  Each parameterized test builds a decoder class directly, builds the same
  decoder through `factory.build_decoder` from an equivalent model config, and
  asserts that the two `get_config()` dictionaries are equal.
  """

  @staticmethod
  def _level_input_specs(min_level, max_level):
    """Returns per-level input specs keyed by level string.

    Entries are created for levels in `range(min_level, max_level)`, so
    `max_level` itself gets no entry. Each level's spatial size is
    `128 // 2**level` with 3 channels and a batch of 1.

    Args:
      min_level: First level (inclusive) to create a spec for.
      max_level: Upper bound (exclusive) of the levels to create specs for.

    Returns:
      A dict mapping `str(level)` to a `tf.TensorShape`.
    """
    return {
        str(level): tf.TensorShape(
            [1, 128 // (2**level), 128 // (2**level), 3])
        for level in range(min_level, max_level)
    }

  @combinations.generate(
      combinations.combine(
          num_filters=[128, 256], use_separable_conv=[True, False]))
  def test_fpn_decoder_creation(self, num_filters, use_separable_conv):
    """Test creation of FPN decoder.

    Args:
      num_filters: Number of filters given to the FPN decoder.
      use_separable_conv: Whether the FPN decoder uses separable convolutions.
    """
    min_level = 3
    max_level = 7
    input_specs = self._level_input_specs(min_level, max_level)

    network = decoders.FPN(
        input_specs=input_specs,
        num_filters=num_filters,
        use_separable_conv=use_separable_conv,
        use_sync_bn=True)

    model_config = configs.retinanet.RetinaNet()
    model_config.min_level = min_level
    model_config.max_level = max_level
    model_config.num_classes = 10
    model_config.input_size = [None, None, 3]
    model_config.decoder = decoders_cfg.Decoder(
        type='fpn',
        fpn=decoders_cfg.FPN(
            num_filters=num_filters, use_separable_conv=use_separable_conv))

    factory_network = factory.build_decoder(
        input_specs=input_specs, model_config=model_config)

    network_config = network.get_config()
    factory_network_config = factory_network.get_config()

    self.assertEqual(network_config, factory_network_config)

  @combinations.generate(
      combinations.combine(
          num_filters=[128, 256],
          num_repeats=[3, 5],
          use_separable_conv=[True, False]))
  def test_nasfpn_decoder_creation(self, num_filters, num_repeats,
                                   use_separable_conv):
    """Test creation of NASFPN decoder.

    Args:
      num_filters: Number of filters given to the NASFPN decoder.
      num_repeats: Number of repeats given to the NASFPN decoder.
      use_separable_conv: Whether the NASFPN decoder uses separable
        convolutions.
    """
    min_level = 3
    max_level = 7
    input_specs = self._level_input_specs(min_level, max_level)

    network = decoders.NASFPN(
        input_specs=input_specs,
        num_filters=num_filters,
        num_repeats=num_repeats,
        use_separable_conv=use_separable_conv,
        use_sync_bn=True)

    model_config = configs.retinanet.RetinaNet()
    model_config.min_level = min_level
    model_config.max_level = max_level
    model_config.num_classes = 10
    model_config.input_size = [None, None, 3]
    model_config.decoder = decoders_cfg.Decoder(
        type='nasfpn',
        nasfpn=decoders_cfg.NASFPN(
            num_filters=num_filters,
            num_repeats=num_repeats,
            use_separable_conv=use_separable_conv))

    factory_network = factory.build_decoder(
        input_specs=input_specs, model_config=model_config)

    network_config = network.get_config()
    factory_network_config = factory_network.get_config()

    self.assertEqual(network_config, factory_network_config)

  @combinations.generate(
      combinations.combine(
          level=[3, 4],
          dilation_rates=[[6, 12, 18], [6, 12]],
          num_filters=[128, 256]))
  def test_aspp_decoder_creation(self, level, dilation_rates, num_filters):
    """Test creation of ASPP decoder.

    Args:
      level: Level value given to the ASPP decoder.
      dilation_rates: List of dilation rates given to the ASPP decoder.
      num_filters: Number of filters given to the ASPP decoder.
    """
    input_specs = {'1': tf.TensorShape([1, 128, 128, 3])}

    network = decoders.ASPP(
        level=level,
        dilation_rates=dilation_rates,
        num_filters=num_filters,
        use_sync_bn=True)

    model_config = configs.semantic_segmentation.SemanticSegmentationModel()
    model_config.num_classes = 10
    model_config.input_size = [None, None, 3]
    model_config.decoder = decoders_cfg.Decoder(
        type='aspp',
        aspp=decoders_cfg.ASPP(
            level=level, dilation_rates=dilation_rates,
            num_filters=num_filters))

    factory_network = factory.build_decoder(
        input_specs=input_specs, model_config=model_config)

    network_config = network.get_config()
    factory_network_config = factory_network.get_config()
    # The ASPP layer calls `super().get_config()`, so its config includes the
    # layer name. The name differs between the two independently created
    # instances even though every other field matches. Copy the name over so
    # the comparison does not raise a false alarm.
    factory_network_config['name'] = network_config['name']

    self.assertEqual(network_config, factory_network_config)

  def test_identity_decoder_creation(self):
    """Test creation of identity decoder.

    The factory is expected to return `None` for the 'identity' decoder type.
    """
    model_config = configs.retinanet.RetinaNet()
    model_config.num_classes = 2
    model_config.input_size = [None, None, 3]

    model_config.decoder = decoders_cfg.Decoder(
        type='identity', identity=decoders_cfg.Identity())

    factory_network = factory.build_decoder(
        input_specs=None, model_config=model_config)

    self.assertIsNone(factory_network)
if __name__ == '__main__':
  # Run the tests in this module through TensorFlow's test runner.
  tf.test.main()