deanna-emery's picture
updates
93528c6
raw
history blame
5.63 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for decoder factory functions."""
from absl.testing import parameterized
import tensorflow as tf, tf_keras
from tensorflow.python.distribute import combinations
from official.vision import configs
from official.vision.configs import decoders as decoders_cfg
from official.vision.modeling import decoders
from official.vision.modeling.decoders import factory
class FactoryTest(tf.test.TestCase, parameterized.TestCase):
  """Compares factory-built decoders against directly constructed ones."""

  @combinations.generate(
      combinations.combine(
          num_filters=[128, 256], use_separable_conv=[True, False]))
  def test_fpn_decoder_creation(self, num_filters, use_separable_conv):
    """Factory-built FPN reports the same config as a direct `decoders.FPN`."""
    min_level, max_level = 3, 7
    # One spec per level in [min_level, max_level); the spatial size halves
    # with every level.
    input_specs = {
        str(lvl): tf.TensorShape([1, 128 // (2**lvl), 128 // (2**lvl), 3])
        for lvl in range(min_level, max_level)
    }

    # Reference decoder, constructed without going through the factory.
    direct_decoder = decoders.FPN(
        input_specs=input_specs,
        num_filters=num_filters,
        use_separable_conv=use_separable_conv,
        use_sync_bn=True)

    # Equivalent description of the decoder expressed as a model config.
    fpn_cfg = decoders_cfg.FPN(
        num_filters=num_filters, use_separable_conv=use_separable_conv)
    model_config = configs.retinanet.RetinaNet()
    model_config.min_level = min_level
    model_config.max_level = max_level
    model_config.num_classes = 10
    model_config.input_size = [None, None, 3]
    model_config.decoder = decoders_cfg.Decoder(type='fpn', fpn=fpn_cfg)

    factory_decoder = factory.build_decoder(
        input_specs=input_specs, model_config=model_config)

    self.assertEqual(direct_decoder.get_config(), factory_decoder.get_config())

  @combinations.generate(
      combinations.combine(
          num_filters=[128, 256],
          num_repeats=[3, 5],
          use_separable_conv=[True, False]))
  def test_nasfpn_decoder_creation(self, num_filters, num_repeats,
                                   use_separable_conv):
    """Factory-built NASFPN reports the same config as `decoders.NASFPN`."""
    min_level, max_level = 3, 7
    # One spec per level in [min_level, max_level); the spatial size halves
    # with every level.
    input_specs = {
        str(lvl): tf.TensorShape([1, 128 // (2**lvl), 128 // (2**lvl), 3])
        for lvl in range(min_level, max_level)
    }

    # Reference decoder, constructed without going through the factory.
    direct_decoder = decoders.NASFPN(
        input_specs=input_specs,
        num_filters=num_filters,
        num_repeats=num_repeats,
        use_separable_conv=use_separable_conv,
        use_sync_bn=True)

    # Equivalent description of the decoder expressed as a model config.
    nasfpn_cfg = decoders_cfg.NASFPN(
        num_filters=num_filters,
        num_repeats=num_repeats,
        use_separable_conv=use_separable_conv)
    model_config = configs.retinanet.RetinaNet()
    model_config.min_level = min_level
    model_config.max_level = max_level
    model_config.num_classes = 10
    model_config.input_size = [None, None, 3]
    model_config.decoder = decoders_cfg.Decoder(
        type='nasfpn', nasfpn=nasfpn_cfg)

    factory_decoder = factory.build_decoder(
        input_specs=input_specs, model_config=model_config)

    self.assertEqual(direct_decoder.get_config(), factory_decoder.get_config())

  @combinations.generate(
      combinations.combine(
          level=[3, 4],
          dilation_rates=[[6, 12, 18], [6, 12]],
          num_filters=[128, 256]))
  def test_aspp_decoder_creation(self, level, dilation_rates, num_filters):
    """Factory-built ASPP matches `decoders.ASPP` apart from the layer name."""
    input_specs = {'1': tf.TensorShape([1, 128, 128, 3])}

    # Reference decoder, constructed without going through the factory.
    direct_decoder = decoders.ASPP(
        level=level,
        dilation_rates=dilation_rates,
        num_filters=num_filters,
        use_sync_bn=True)

    # Equivalent description of the decoder expressed as a model config.
    aspp_cfg = decoders_cfg.ASPP(
        level=level, dilation_rates=dilation_rates, num_filters=num_filters)
    model_config = configs.semantic_segmentation.SemanticSegmentationModel()
    model_config.num_classes = 10
    model_config.input_size = [None, None, 3]
    model_config.decoder = decoders_cfg.Decoder(type='aspp', aspp=aspp_cfg)

    factory_decoder = factory.build_decoder(
        input_specs=input_specs, model_config=model_config)

    direct_cfg = direct_decoder.get_config()
    factory_cfg = factory_decoder.get_config()
    # The ASPP layer's config comes from `super().get_config()` and therefore
    # carries the layer name, which differs between the two instances. Copy
    # the name across so the comparison only covers the real settings and
    # does not raise a false alarm.
    factory_cfg['name'] = direct_cfg['name']
    self.assertEqual(direct_cfg, factory_cfg)

  def test_identity_decoder_creation(self):
    """The factory returns `None` when the decoder type is 'identity'."""
    model_config = configs.retinanet.RetinaNet()
    model_config.num_classes = 2
    model_config.input_size = [None, None, 3]
    model_config.decoder = decoders_cfg.Decoder(
        type='identity', identity=decoders_cfg.Identity())

    self.assertIsNone(
        factory.build_decoder(input_specs=None, model_config=model_config))
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()