# deanna-emery's picture
# updates
# 93528c6
# raw
# history blame
# 5.3 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for factory.py."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf, tf_keras
from official.vision.configs import backbones
from official.vision.configs import backbones_3d
from official.vision.configs import image_classification as classification_cfg
from official.vision.configs import maskrcnn as maskrcnn_cfg
from official.vision.configs import retinanet as retinanet_cfg
from official.vision.configs import video_classification as video_classification_cfg
from official.vision.modeling import factory
from official.vision.modeling import factory_3d
class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
  """Smoke tests for `factory.build_classification_model`."""

  @parameterized.parameters(
      ('resnet', (224, 224), 5e-5),
      ('resnet', (224, 224), None),
      ('resnet', (None, None), 5e-5),
      ('resnet', (None, None), None),
  )
  def test_builder(self, backbone_type, input_size, weight_decay):
    """Builds a 2-class image classifier for each parameter combination.

    Args:
      backbone_type: Name of the backbone passed to `backbones.Backbone`.
      input_size: (height, width) pair; entries may be None.
      weight_decay: L2 coefficient, or a falsy value for no regularizer.
    """
    height, width = input_size
    spec = tf_keras.layers.InputSpec(shape=[None, height, width, 3])
    config = classification_cfg.ImageClassificationModel(
        num_classes=2,
        backbone=backbones.Backbone(type=backbone_type))
    # Only construct a regularizer when a non-zero decay was requested.
    regularizer = None
    if weight_decay:
      regularizer = tf_keras.regularizers.l2(weight_decay)
    factory.build_classification_model(
        input_specs=spec,
        model_config=config,
        l2_regularizer=regularizer)
class MaskRCNNBuilderTest(parameterized.TestCase, tf.test.TestCase):
  """Smoke tests for `factory.build_maskrcnn`."""

  @parameterized.parameters(
      ('resnet', (640, 640)),
      ('resnet', (None, None)),
  )
  def test_builder(self, backbone_type, input_size):
    """Builds a 2-class Mask R-CNN for fixed and unspecified input sizes.

    Args:
      backbone_type: Name of the backbone passed to `backbones.Backbone`.
      input_size: (height, width) pair; entries may be None.
    """
    height, width = input_size
    spec = tf_keras.layers.InputSpec(shape=[None, height, width, 3])
    config = maskrcnn_cfg.MaskRCNN(
        num_classes=2,
        backbone=backbones.Backbone(type=backbone_type))
    factory.build_maskrcnn(
        input_specs=spec,
        model_config=config,
        l2_regularizer=tf_keras.regularizers.l2(5e-5))
class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
  """Smoke tests for `factory.build_retinanet`, optionally with attribute heads."""

  @parameterized.parameters(
      ('resnet', (640, 640), False),
      ('resnet', (None, None), True),
  )
  def test_builder(self, backbone_type, input_size, has_att_heads):
    """Builds a 2-class RetinaNet and, if requested, checks attribute heads.

    Args:
      backbone_type: Name of the backbone passed to `backbones.Backbone`.
      input_size: (height, width) pair; entries may be None.
      has_att_heads: Whether to attach two attribute heads to the head config
        and verify their serialized form after the build.
    """
    height, width = input_size
    spec = tf_keras.layers.InputSpec(shape=[None, height, width, 3])

    heads = None
    if has_att_heads:
      heads = [
          retinanet_cfg.AttributeHead(name='att1'),
          retinanet_cfg.AttributeHead(
              name='att2', type='classification', size=2),
      ]

    config = retinanet_cfg.RetinaNet(
        num_classes=2,
        backbone=backbones.Backbone(type=backbone_type),
        head=retinanet_cfg.RetinaNetHead(attribute_heads=heads))
    factory.build_retinanet(
        input_specs=spec,
        model_config=config,
        l2_regularizer=tf_keras.regularizers.l2(5e-5))

    if not has_att_heads:
      return

    def expected_head(name, head_type, size):
      # Fields not passed to AttributeHead above are checked against the
      # values this test expects them to default to.
      return dict(
          name=name,
          type=head_type,
          size=size,
          prediction_tower_name='',
          num_convs=None,
          num_filters=None,
      )

    expectations = [
        expected_head('att1', 'regression', 1),
        expected_head('att2', 'classification', 2),
    ]
    # Index explicitly so a missing head still raises, as a direct lookup would.
    for index, expected in enumerate(expectations):
      self.assertEqual(
          config.head.attribute_heads[index].as_dict(), expected)
class VideoClassificationModelBuilderTest(parameterized.TestCase,
                                          tf.test.TestCase):
  """Smoke tests for `factory_3d.build_video_classification_model`."""

  @parameterized.parameters(
      ('resnet_3d', (8, 224, 224), 5e-5),
      ('resnet_3d', (None, None, None), 5e-5),
      # Without this case the `weight_decay` -> None branch in the test body
      # is never exercised; it covers building with no L2 regularizer.
      ('resnet_3d', (8, 224, 224), None),
  )
  def test_builder(self, backbone_type, input_size, weight_decay):
    """Builds a 2-class video classifier for each parameter combination.

    Args:
      backbone_type: Name of the backbone passed to `backbones_3d.Backbone3D`.
      input_size: (frames, height, width) triple; entries may be None.
      weight_decay: L2 coefficient, or a falsy value for no regularizer.
    """
    input_specs = tf_keras.layers.InputSpec(
        shape=[None, input_size[0], input_size[1], input_size[2], 3])
    model_config = video_classification_cfg.VideoClassificationModel(
        backbone=backbones_3d.Backbone3D(type=backbone_type))
    l2_regularizer = (
        tf_keras.regularizers.l2(weight_decay) if weight_decay else None)
    _ = factory_3d.build_video_classification_model(
        input_specs=input_specs,
        model_config=model_config,
        num_classes=2,
        l2_regularizer=l2_regularizer)
if __name__ == '__main__':
  # Delegate to TensorFlow's test entry point when run as a script.
  tf.test.main()