# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for factory.py and factory_3d.py."""

# Import libraries
from absl.testing import parameterized
import tensorflow as tf, tf_keras
from official.vision.configs import backbones
from official.vision.configs import backbones_3d
from official.vision.configs import image_classification as classification_cfg
from official.vision.configs import maskrcnn as maskrcnn_cfg
from official.vision.configs import retinanet as retinanet_cfg
from official.vision.configs import video_classification as video_classification_cfg
from official.vision.modeling import factory
from official.vision.modeling import factory_3d


class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
  """Smoke tests for `factory.build_classification_model`."""

  # `test_builder` takes extra positional arguments, so it must be
  # parameterized; otherwise the runner calls it with only `self` and the
  # test fails with a TypeError before building anything.
  # NOTE(review): the cases below are representative choices (fixed vs.
  # unspecified spatial size, with and without weight decay) -- confirm that
  # they match the intended coverage.
  @parameterized.parameters(
      ('resnet', (224, 224), 5e-5),
      ('resnet', (224, 224), None),
      ('resnet', (None, None), 5e-5),
      ('resnet', (None, None), None),
  )
  def test_builder(self, backbone_type, input_size, weight_decay):
    """Builds a classification model; passes if construction does not raise.

    Args:
      backbone_type: Value forwarded as `type` to `backbones.Backbone`.
      input_size: Indexable with two entries, used as the 2nd and 3rd
        dimensions of the input spec (presumably height and width).
      weight_decay: L2 regularization factor. Any falsy value (`None` or 0)
        means no regularizer is passed to the builder.
    """
    num_classes = 2
    # The leading dimension is left as None and the trailing one is fixed
    # to 3.
    input_specs = tf_keras.layers.InputSpec(
        shape=[None, input_size[0], input_size[1], 3])
    model_config = classification_cfg.ImageClassificationModel(
        num_classes=num_classes,
        backbone=backbones.Backbone(type=backbone_type))
    l2_regularizer = (
        tf_keras.regularizers.l2(weight_decay) if weight_decay else None)
    # The returned model is intentionally discarded: only successful
    # construction is being checked.
    _ = factory.build_classification_model(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)
class MaskRCNNBuilderTest(parameterized.TestCase, tf.test.TestCase):
  """Smoke tests for `factory.build_maskrcnn`."""

  # `test_builder` takes extra positional arguments, so it must be
  # parameterized; otherwise the runner calls it with only `self` and the
  # test fails with a TypeError before building anything.
  # NOTE(review): the cases below are representative choices (fixed vs.
  # unspecified spatial size) -- confirm that they match the intended
  # coverage.
  @parameterized.parameters(
      ('resnet', (640, 640)),
      ('resnet', (None, None)),
  )
  def test_builder(self, backbone_type, input_size):
    """Builds a Mask R-CNN model; passes if construction does not raise.

    Args:
      backbone_type: Value forwarded as `type` to `backbones.Backbone`.
      input_size: Indexable with two entries, used as the 2nd and 3rd
        dimensions of the input spec (presumably height and width).
    """
    num_classes = 2
    # The leading dimension is left as None and the trailing one is fixed
    # to 3.
    input_specs = tf_keras.layers.InputSpec(
        shape=[None, input_size[0], input_size[1], 3])
    model_config = maskrcnn_cfg.MaskRCNN(
        num_classes=num_classes,
        backbone=backbones.Backbone(type=backbone_type))
    l2_regularizer = tf_keras.regularizers.l2(5e-5)
    # The returned model is intentionally discarded: only successful
    # construction is being checked.
    _ = factory.build_maskrcnn(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)
class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
  """Smoke tests for `factory.build_retinanet`, with optional attribute heads."""

  # `test_builder` takes extra positional arguments, so it must be
  # parameterized; otherwise the runner calls it with only `self` and the
  # test fails with a TypeError before building anything.
  # NOTE(review): the cases below are representative choices (with and
  # without attribute heads, fixed vs. unspecified spatial size) -- confirm
  # that they match the intended coverage.
  @parameterized.parameters(
      ('resnet', (640, 640), False),
      ('resnet', (None, None), True),
  )
  def test_builder(self, backbone_type, input_size, has_att_heads):
    """Builds a RetinaNet model and checks the attribute-head configs.

    Args:
      backbone_type: Value forwarded as `type` to `backbones.Backbone`.
      input_size: Indexable with two entries, used as the 2nd and 3rd
        dimensions of the input spec (presumably height and width).
      has_att_heads: When true, two attribute heads are configured on the
        RetinaNet head and their configs are verified after the build.
        When false, `attribute_heads` is set to `None` and only successful
        construction is checked.
    """
    num_classes = 2
    # The leading dimension is left as None and the trailing one is fixed
    # to 3.
    input_specs = tf_keras.layers.InputSpec(
        shape=[None, input_size[0], input_size[1], 3])
    if has_att_heads:
      attribute_heads_config = [
          # Only the name is given here; the assertion below expects the
          # remaining fields to read back as type='regression', size=1.
          retinanet_cfg.AttributeHead(name='att1'),
          retinanet_cfg.AttributeHead(
              name='att2', type='classification', size=2),
      ]
    else:
      attribute_heads_config = None
    model_config = retinanet_cfg.RetinaNet(
        num_classes=num_classes,
        backbone=backbones.Backbone(type=backbone_type),
        head=retinanet_cfg.RetinaNetHead(
            attribute_heads=attribute_heads_config))
    l2_regularizer = tf_keras.regularizers.l2(5e-5)
    # The returned model is intentionally discarded; the assertions below
    # inspect the config object instead.
    _ = factory.build_retinanet(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)
    if has_att_heads:
      # Checked after the build, so this also covers the state of the head
      # configs once the model has been constructed from them.
      self.assertEqual(
          model_config.head.attribute_heads[0].as_dict(),
          dict(
              name='att1',
              type='regression',
              size=1,
              prediction_tower_name='',
              num_convs=None,
              num_filters=None,
          ),
      )
      self.assertEqual(
          model_config.head.attribute_heads[1].as_dict(),
          dict(
              name='att2',
              type='classification',
              size=2,
              prediction_tower_name='',
              num_convs=None,
              num_filters=None,
          ),
      )
class VideoClassificationModelBuilderTest(parameterized.TestCase,
                                          tf.test.TestCase):
  """Smoke tests for `factory_3d.build_video_classification_model`."""

  # `test_builder` takes extra positional arguments, so it must be
  # parameterized; otherwise the runner calls it with only `self` and the
  # test fails with a TypeError before building anything.
  # NOTE(review): the cases below are representative choices (fixed vs.
  # unspecified input size) -- confirm that they match the intended
  # coverage.
  @parameterized.parameters(
      ('resnet_3d', (8, 224, 224), 5e-5),
      ('resnet_3d', (None, None, None), 5e-5),
  )
  def test_builder(self, backbone_type, input_size, weight_decay):
    """Builds a video classification model; passes if construction succeeds.

    Args:
      backbone_type: Value forwarded as `type` to `backbones_3d.Backbone3D`.
      input_size: Indexable with three entries, used as the 2nd through 4th
        dimensions of the input spec (presumably frames, height and width).
      weight_decay: L2 regularization factor. Any falsy value (`None` or 0)
        means no regularizer is passed to the builder.
    """
    # The leading dimension is left as None and the trailing one is fixed
    # to 3.
    input_specs = tf_keras.layers.InputSpec(
        shape=[None, input_size[0], input_size[1], input_size[2], 3])
    model_config = video_classification_cfg.VideoClassificationModel(
        backbone=backbones_3d.Backbone3D(type=backbone_type))
    l2_regularizer = (
        tf_keras.regularizers.l2(weight_decay) if weight_decay else None)
    # Here the class count goes to the builder directly, not via the model
    # config. The returned model is intentionally discarded: only successful
    # construction is being checked.
    _ = factory_3d.build_video_classification_model(
        input_specs=input_specs,
        model_config=model_config,
        num_classes=2,
        l2_regularizer=l2_regularizer)
if __name__ == '__main__':
  # Run every test case in this module through the TensorFlow test runner
  # when the file is executed as a script.
  tf.test.main()