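# NOTE(review): the next line appears to be leftover text from a file-hosting
# web page (not part of the module); kept as a comment so the file parses:
# deanna-emery's picture | updates | 93528c6 | raw | history blame | 3.46 kB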
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for VIT."""
import math
from absl.testing import parameterized
import tensorflow as tf, tf_keras
from official.vision.modeling.backbones import vit
class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):
    """Construction and output-structure checks for `vit.VisionTransformer`."""

    @parameterized.parameters(
        (224, 85798656),
        (256, 85844736),
    )
    def test_network_creation(self, input_size, params_count):
        """Builds the default model and compares its parameter count.

        Args:
          input_size: Height and width of the square, 3-channel input.
          params_count: Value `count_params()` must return once the network
            has been called on an input of that size.
        """
        tf_keras.backend.set_image_data_format('channels_last')

        image_shape = (input_size, input_size, 3)
        backbone = vit.VisionTransformer(
            input_specs=tf_keras.layers.InputSpec(shape=[2, *image_shape]))
        # The network is called on a symbolic input before counting
        # parameters; the returned endpoints themselves are not inspected.
        backbone(tf_keras.Input(shape=image_shape, batch_size=1))

        self.assertEqual(backbone.count_params(), params_count)

    @parameterized.product(
        patch_size=[6, 4],
        output_2d_feature_maps=[True, False],
        pooler=['none', 'gap', 'token'],
    )
    def test_network_with_diferent_configs(
            self, patch_size, output_2d_feature_maps, pooler):
        """Checks endpoint names and shapes for a small 24x24 configuration.

        With `pooler == 'none'` the output must carry `encoded_tokens` of
        shape `[1, (24 // patch_size) ** 2, 16]`; with any other pooler it
        must carry `pre_logits` of shape `[1, 1, 1, 16]`. When
        `output_2d_feature_maps` is set, an endpoint keyed by
        `str(round(log2(patch_size)))` must be present in both the output and
        `output_specs`; otherwise that key must be absent from the output.

        Args:
          patch_size: Patch size passed to the model.
          output_2d_feature_maps: Whether the 2D feature-map endpoint is
            requested from the model.
          pooler: Pooler name passed to the model.
        """
        tf_keras.backend.set_image_data_format('channels_last')

        input_size = 24
        image_shape = (input_size, input_size, 3)
        # Number of patches along one side of the image.
        grid_side = input_size // patch_size
        # Key under which the 2D feature map is expected, when requested.
        level_key = str(round(math.log2(patch_size)))

        backbone = vit.VisionTransformer(
            input_specs=tf_keras.layers.InputSpec(shape=[2, *image_shape]),
            patch_size=patch_size,
            pooler=pooler,
            hidden_size=8,
            mlp_dim=8,
            num_layers=1,
            num_heads=2,
            representation_size=16,
            output_2d_feature_maps=output_2d_feature_maps)
        endpoints = backbone(
            tf_keras.Input(shape=image_shape, batch_size=1))

        if pooler != 'none':
            self.assertEqual(endpoints['pre_logits'].shape, [1, 1, 1, 16])
        else:
            self.assertEqual(
                endpoints['encoded_tokens'].shape, [1, grid_side**2, 16])

        if not output_2d_feature_maps:
            self.assertNotIn(level_key, endpoints)
            return

        self.assertIn(level_key, endpoints)
        self.assertIn(level_key, backbone.output_specs)
        # The leading (batch) entry of the spec is excluded from the check.
        self.assertEqual(
            backbone.output_specs[level_key][1:],
            [grid_side, grid_side, 8])

    def test_posembedding_interpolation(self):
        """Runs a 256x256 input through a model with `pos_embed_shape=(14, 14)`.

        The (14, 14) grid matches a 224-pixel image at patch size 16, while
        the input here is 256 pixels wide, so the two grids differ
        (presumably this exercises position-embedding interpolation, as the
        test name suggests). The `pre_logits` endpoint must still have shape
        `[1, 1, 1, 768]`.
        """
        tf_keras.backend.set_image_data_format('channels_last')

        input_size = 256
        image_shape = (input_size, input_size, 3)
        backbone = vit.VisionTransformer(
            input_specs=tf_keras.layers.InputSpec(shape=[2, *image_shape]),
            patch_size=16,
            pooler='gap',
            pos_embed_shape=(14, 14))  # (224 // 16)
        endpoints = backbone(
            tf_keras.Input(shape=image_shape, batch_size=1))
        pre_logits = endpoints['pre_logits']

        self.assertEqual(pre_logits.shape, [1, 1, 1, 768])
if __name__ == '__main__':
    # Run this module's test cases when the file is executed as a script.
    tf.test.main()