# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VisionTransformer models.""" | |
import math | |
from typing import Optional, Tuple | |
from absl import logging | |
import tensorflow as tf, tf_keras | |
from official.modeling import activations | |
from official.vision.modeling.backbones import factory | |
from official.vision.modeling.backbones.vit_specs import VIT_SPECS | |
from official.vision.modeling.layers import nn_blocks | |
from official.vision.modeling.layers import nn_layers | |
layers = tf_keras.layers | |


class AddPositionEmbs(layers.Layer):
  """Adds (optionally learned) positional embeddings to the inputs."""

  def __init__(self,
               posemb_init: Optional[tf_keras.initializers.Initializer] = None,
               posemb_origin_shape: Optional[Tuple[int, int]] = None,
               posemb_target_shape: Optional[Tuple[int, int]] = None,
               **kwargs):
    """Constructs Positional Embedding module.

    The length of the learnable positional embeddings is determined at
    construction time from `inputs_shape`, or from `posemb_origin_shape` if
    provided. If `posemb_target_shape` is provided and differs from that
    length, the embeddings are interpolated during the forward call.

    Args:
      posemb_init: The positional embedding initializer.
      posemb_origin_shape: The intended positional embedding shape.
      posemb_target_shape: The potential target shape positional embedding may
        be interpolated to.
      **kwargs: other args.
    """
    super().__init__(**kwargs)
    self.posemb_init = posemb_init
    self.posemb_origin_shape = posemb_origin_shape
    self.posemb_target_shape = posemb_target_shape

  def build(self, inputs_shape):
    if self.posemb_origin_shape is not None:
      pos_emb_length = self.posemb_origin_shape[0] * self.posemb_origin_shape[1]
    else:
      pos_emb_length = inputs_shape[1]
    pos_emb_shape = (1, pos_emb_length, inputs_shape[2])
    self.pos_embedding = self.add_weight(
        'pos_embedding', pos_emb_shape, initializer=self.posemb_init)

  def _interpolate(self, pos_embedding: tf.Tensor, from_shape: Tuple[int, int],
                   to_shape: Tuple[int, int]) -> tf.Tensor:
    """Interpolates the positional embeddings."""
    logging.info('Interpolating positional embedding from length: %s to %s',
                 from_shape, to_shape)
    grid_emb = tf.reshape(pos_embedding, [1] + list(from_shape) + [-1])
    # NOTE: Using BILINEAR interpolation by default.
    grid_emb = tf.image.resize(grid_emb, to_shape)
    return tf.reshape(grid_emb, [1, to_shape[0] * to_shape[1], -1])

  def call(self, inputs, inputs_positions=None):
    del inputs_positions
    pos_embedding = self.pos_embedding
    # inputs.shape is (batch_size, seq_len, emb_dim).
    if inputs.shape[1] != pos_embedding.shape[1]:
      pos_embedding = self._interpolate(
          pos_embedding,
          from_shape=self.posemb_origin_shape,
          to_shape=self.posemb_target_shape)
    pos_embedding = tf.cast(pos_embedding, inputs.dtype)
    return inputs + pos_embedding
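
# Example usage (illustrative sketch, not part of the original file): adding
# learned positional embeddings built for a 14x14 patch grid to a sequence of
# 16x16 = 256 patch tokens, which triggers the bilinear interpolation path in
# `call`. The batch size (2) and embedding dimension (768) are arbitrary
# assumptions.
#
#   pos_embed = AddPositionEmbs(
#       posemb_init=tf_keras.initializers.RandomNormal(stddev=0.02),
#       posemb_origin_shape=(14, 14),
#       posemb_target_shape=(16, 16))
#   tokens = tf.zeros([2, 16 * 16, 768])  # (batch_size, seq_len, emb_dim)
#   out = pos_embed(tokens)               # shape: (2, 256, 768)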


class TokenLayer(layers.Layer):
  """A simple layer to wrap token parameters."""

  def build(self, inputs_shape):
    self.cls = self.add_weight(
        'cls', (1, 1, inputs_shape[-1]), initializer='zeros')

  def call(self, inputs):
    cls = tf.cast(self.cls, inputs.dtype)
    cls = cls + tf.zeros_like(inputs[:, 0:1])  # A hacky way to tile.
    x = tf.concat([cls, inputs], axis=1)
    return x
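
# Example usage (illustrative sketch, not part of the original file): the
# layer prepends one learned class token to every sequence in the batch, so a
# (batch_size, seq_len, hidden_size) input becomes
# (batch_size, seq_len + 1, hidden_size). Shapes below are arbitrary.
#
#   tokens = tf.zeros([2, 196, 768])
#   with_cls = TokenLayer(name='cls')(tokens)  # shape: (2, 197, 768)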


class Encoder(layers.Layer):
  """Transformer Encoder."""

  def __init__(self,
               num_layers,
               mlp_dim,
               num_heads,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               kernel_regularizer=None,
               inputs_positions=None,
               init_stochastic_depth_rate=0.0,
               kernel_initializer='glorot_uniform',
               add_pos_embed=True,
               pos_embed_origin_shape=None,
               pos_embed_target_shape=None,
               layer_scale_init_value=0.0,
               transformer_partition_dims=None,
               **kwargs):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._kernel_regularizer = kernel_regularizer
    self._inputs_positions = inputs_positions
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._kernel_initializer = kernel_initializer
    self._add_pos_embed = add_pos_embed
    self._pos_embed_origin_shape = pos_embed_origin_shape
    self._pos_embed_target_shape = pos_embed_target_shape
    self._layer_scale_init_value = layer_scale_init_value
    self._transformer_partition_dims = transformer_partition_dims

  def build(self, input_shape):
    if self._add_pos_embed:
      self._pos_embed = AddPositionEmbs(
          posemb_init=tf_keras.initializers.RandomNormal(stddev=0.02),
          posemb_origin_shape=self._pos_embed_origin_shape,
          posemb_target_shape=self._pos_embed_target_shape,
          name='posembed_input')
    self._dropout = layers.Dropout(rate=self._dropout_rate)

    self._encoder_layers = []
    # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
    # https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
    for i in range(self._num_layers):
      encoder_layer = nn_blocks.TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          kernel_initializer=self._kernel_initializer,
          norm_first=True,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 1, self._num_layers),
          norm_epsilon=1e-6,
          layer_scale_init_value=self._layer_scale_init_value,
          transformer_partition_dims=self._transformer_partition_dims)
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
    super().build(input_shape)

  def call(self, inputs, training=None):
    x = inputs
    if self._add_pos_embed:
      x = self._pos_embed(x, inputs_positions=self._inputs_positions)
    x = self._dropout(x, training=training)
    for encoder_layer in self._encoder_layers:
      x = encoder_layer(x, training=training)
    x = self._norm(x)
    return x

  def get_config(self):
    config = super().get_config()
    updates = {
        'num_layers': self._num_layers,
        'mlp_dim': self._mlp_dim,
        'num_heads': self._num_heads,
        'dropout_rate': self._dropout_rate,
        'attention_dropout_rate': self._attention_dropout_rate,
        'kernel_regularizer': self._kernel_regularizer,
        'inputs_positions': self._inputs_positions,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'kernel_initializer': self._kernel_initializer,
        'add_pos_embed': self._add_pos_embed,
        'pos_embed_origin_shape': self._pos_embed_origin_shape,
        'pos_embed_target_shape': self._pos_embed_target_shape,
        'layer_scale_init_value': self._layer_scale_init_value,
        'transformer_partition_dims': self._transformer_partition_dims,
    }
    config.update(updates)
    return config
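
# Example usage (illustrative sketch, not part of the original file): running
# a small Encoder over a batch of patch tokens. With `add_pos_embed=True` (the
# default), positional embeddings matching the sequence length are created and
# added before the transformer blocks. All sizes below are arbitrary
# assumptions.
#
#   encoder = Encoder(num_layers=2, mlp_dim=1536, num_heads=6)
#   tokens = tf.zeros([2, 196, 384])            # (batch, seq_len, hidden_size)
#   encoded = encoder(tokens, training=False)   # shape: (2, 196, 384)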


class VisionTransformer(tf_keras.Model):
  """Class to build VisionTransformer family model."""

  def __init__(
      self,
      mlp_dim=3072,
      num_heads=12,
      num_layers=12,
      attention_dropout_rate=0.0,
      dropout_rate=0.1,
      init_stochastic_depth_rate=0.0,
      input_specs=layers.InputSpec(shape=[None, None, None, 3]),
      patch_size=16,
      hidden_size=768,
      representation_size=0,
      pooler='token',
      kernel_regularizer=None,
      original_init: bool = True,
      output_encoded_tokens: bool = True,
      output_2d_feature_maps: bool = False,
      pos_embed_shape: Optional[Tuple[int, int]] = None,
      layer_scale_init_value: float = 0.0,
      transformer_partition_dims: Optional[Tuple[int, int, int, int]] = None,
  ):
    """VisionTransformer initialization function."""
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._num_layers = num_layers
    self._hidden_size = hidden_size
    self._patch_size = patch_size

    inputs = tf_keras.Input(shape=input_specs.shape[1:])

    x = layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid',
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
            inputs)
    if tf_keras.backend.image_data_format() == 'channels_last':
      rows_axis, cols_axis = (1, 2)
    else:
      rows_axis, cols_axis = (2, 3)
      # The reshape below assumes the data_format is 'channels_last', so
      # transpose to that. Once the data is flattened by the reshape, the
      # data_format is irrelevant, so no need to update
      # tf_keras.backend.image_data_format.
      x = tf.transpose(x, perm=[0, 2, 3, 1])

    pos_embed_target_shape = (x.shape[rows_axis], x.shape[cols_axis])
    feat_h = input_specs.shape[rows_axis] // patch_size
    feat_w = input_specs.shape[cols_axis] // patch_size
    seq_len = feat_h * feat_w
    x = tf.reshape(x, [-1, seq_len, hidden_size])

    # If we want to add a class token, add it here.
    if pooler == 'token':
      x = TokenLayer(name='cls')(x)

    x = Encoder(
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='glorot_uniform' if original_init else dict(
            class_name='TruncatedNormal', config=dict(stddev=.02)),
        init_stochastic_depth_rate=init_stochastic_depth_rate,
        pos_embed_origin_shape=pos_embed_shape,
        pos_embed_target_shape=pos_embed_target_shape,
        layer_scale_init_value=layer_scale_init_value)(
            x)

    if pooler == 'token':
      output_feature = x[:, 1:]
      x = x[:, 0]
    elif pooler == 'gap':
      output_feature = x
      x = tf.reduce_mean(x, axis=1)
    elif pooler == 'none':
      output_feature = x
      x = tf.identity(x, name='encoded_tokens')
    else:
      raise ValueError(f'unrecognized pooler type: {pooler}')

    endpoints = {}
    if output_2d_feature_maps:
      # Use the closest feature level.
      feat_level = round(math.log2(patch_size))
      logging.info(
          'VisionTransformer patch size %d and feature level: %d',
          patch_size,
          feat_level,
      )
      endpoints[str(feat_level)] = tf.reshape(
          output_feature, [-1, feat_h, feat_w, x.shape.as_list()[-1]])

    # Don't include `pre_logits` or `encoded_tokens` to support decoders.
    self._output_specs = {k: v.shape for k, v in endpoints.items()}

    if representation_size:
      x = layers.Dense(
          representation_size,
          kernel_regularizer=kernel_regularizer,
          name='pre_logits',
          kernel_initializer='lecun_normal' if original_init else 'he_uniform',
      )(x)
      x = tf.nn.tanh(x)
    else:
      x = tf.identity(x, name='pre_logits')

    if pooler == 'none':
      if output_encoded_tokens:
        endpoints['encoded_tokens'] = x
    else:
      endpoints['pre_logits'] = tf.reshape(
          x, [-1, 1, 1, representation_size or hidden_size])

    super().__init__(inputs=inputs, outputs=endpoints)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
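
# Example usage (illustrative sketch, not part of the original file): building
# a ViT-B/16-style backbone for 224x224 images. With patch_size=16, the patch
# grid is 14x14 (seq_len = 196) and the 2D feature level is
# round(log2(16)) = 4, so the endpoints dict carries a
# (batch, 14, 14, hidden_size) map under the key '4'. The values below mirror
# the constructor defaults and are illustrative, not prescribed by this file.
#
#   model = VisionTransformer(
#       input_specs=layers.InputSpec(shape=[None, 224, 224, 3]),
#       patch_size=16,
#       hidden_size=768,
#       pooler='token',
#       output_2d_feature_maps=True)
#   endpoints = model(tf.zeros([2, 224, 224, 3]))
#   print(model.output_specs)  # e.g. {'4': TensorShape([None, 14, 14, 768])}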


# Register this builder with the backbone factory under the 'vit' key.
@factory.register_backbone_builder('vit')
def build_vit(input_specs,
              backbone_config,
              norm_activation_config,
              l2_regularizer=None):
  """Build ViT model."""
  del norm_activation_config
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'vit', (f'Inconsistent backbone type '
                                  f'{backbone_type}')
  backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])
  logging.info(
      (
          'ViT specs: mlp_dim=%d, num_heads=%d, num_layers=%d, '
          'patch_size=%d, hidden_size=%d, representation_size=%d.'
      ),
      backbone_cfg.transformer.mlp_dim,
      backbone_cfg.transformer.num_heads,
      backbone_cfg.transformer.num_layers,
      backbone_cfg.patch_size,
      backbone_cfg.hidden_size,
      backbone_cfg.representation_size,
  )

  return VisionTransformer(
      mlp_dim=backbone_cfg.transformer.mlp_dim,
      num_heads=backbone_cfg.transformer.num_heads,
      num_layers=backbone_cfg.transformer.num_layers,
      attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
      dropout_rate=backbone_cfg.transformer.dropout_rate,
      init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
      input_specs=input_specs,
      patch_size=backbone_cfg.patch_size,
      hidden_size=backbone_cfg.hidden_size,
      representation_size=backbone_cfg.representation_size,
      pooler=backbone_cfg.pooler,
      kernel_regularizer=l2_regularizer,
      original_init=backbone_cfg.original_init,
      output_encoded_tokens=backbone_cfg.output_encoded_tokens,
      output_2d_feature_maps=backbone_cfg.output_2d_feature_maps,
      layer_scale_init_value=backbone_cfg.layer_scale_init_value,
      pos_embed_shape=backbone_cfg.pos_embed_shape,
      transformer_partition_dims=backbone_cfg.transformer_partition_dims)
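
# Example usage (illustrative sketch, not part of the original file): building
# a registered ViT backbone from a config. The `backbones` config module, its
# `Backbone`/`VisionTransformer` config classes, and the 'vit-b16' spec name
# are assumptions about the surrounding Model Garden configs rather than
# something defined here.
#
#   from official.vision.configs import backbones as backbones_cfg
#
#   backbone_config = backbones_cfg.Backbone(
#       type='vit',
#       vit=backbones_cfg.VisionTransformer(model_name='vit-b16'))
#   model = build_vit(
#       input_specs=layers.InputSpec(shape=[None, 224, 224, 3]),
#       backbone_config=backbone_config,
#       norm_activation_config=None,
#       l2_regularizer=None)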