# gcvit/models/gcvit.py
import numpy as np
import tensorflow as tf
from ..layers import Stem, GCViTLevel, Identity
BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
TAG = 'v1.1.1'
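# Pretrained weights are hosted as GitHub release assets; the factory functions
# below compose each download link as '{BASE_URL}/{TAG}/<model_name>_weights.h5'.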
NAME2CONFIG = {
    'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
                     'dim': 64,
                     'depths': (2, 2, 6, 2),
                     'num_heads': (2, 4, 8, 16),
                     'mlp_ratio': 3.,
                     'path_drop': 0.2},
    'gcvit_xtiny': {'window_size': (7, 7, 14, 7),
                    'dim': 64,
                    'depths': (3, 4, 6, 5),
                    'num_heads': (2, 4, 8, 16),
                    'mlp_ratio': 3.,
                    'path_drop': 0.2},
    'gcvit_tiny': {'window_size': (7, 7, 14, 7),
                   'dim': 64,
                   'depths': (3, 4, 19, 5),
                   'num_heads': (2, 4, 8, 16),
                   'mlp_ratio': 3.,
                   'path_drop': 0.2},
    'gcvit_small': {'window_size': (7, 7, 14, 7),
                    'dim': 96,
                    'depths': (3, 4, 19, 5),
                    'num_heads': (3, 6, 12, 24),
                    'mlp_ratio': 2.,
                    'path_drop': 0.3,
                    'layer_scale': 1e-5},
    'gcvit_base': {'window_size': (7, 7, 14, 7),
                   'dim': 128,
                   'depths': (3, 4, 19, 5),
                   'num_heads': (4, 8, 16, 32),
                   'mlp_ratio': 2.,
                   'path_drop': 0.5,
                   'layer_scale': 1e-5},
    'gcvit_large': {'window_size': (7, 7, 14, 7),
                    'dim': 192,
                    'depths': (3, 4, 19, 5),
                    'num_heads': (6, 12, 24, 48),
                    'mlp_ratio': 2.,
                    'path_drop': 0.5,
                    'layer_scale': 1e-5},
}
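# All variants share the four-stage layout and (7, 7, 14, 7) window sizes; they
# differ in embedding dim, per-stage depths/heads, MLP ratio, stochastic-depth
# rate, and (for small/base/large) a layer-scale init value.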
@tf.keras.utils.register_keras_serializable(package='gcvit')
class GCViT(tf.keras.Model):
    def __init__(self, window_size, dim, depths, num_heads,
                 drop_rate=0., mlp_ratio=3., qkv_bias=True, qk_scale=None, attn_drop=0.,
                 path_drop=0.1, layer_scale=None, resize_query=False,
                 global_pool='avg', num_classes=1000, head_act='softmax', **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size
        self.dim = dim
        self.depths = depths
        self.num_heads = num_heads
        self.drop_rate = drop_rate
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.attn_drop = attn_drop
        self.path_drop = path_drop
        self.layer_scale = layer_scale
        self.resize_query = resize_query
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.head_act = head_act
        self.patch_embed = Stem(dim=dim, name='patch_embed')
        self.pos_drop = tf.keras.layers.Dropout(drop_rate, name='pos_drop')
        # Stochastic-depth rates increase linearly from 0 to path_drop across all blocks.
        path_drops = np.linspace(0., path_drop, sum(depths))
        keep_dims = [(False, False, False), (False, False), (True,), (True,)]
        self.levels = []
        for i in range(len(depths)):
            path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
            level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i],
                               keep_dims=keep_dims[i], downsample=(i < len(depths) - 1),
                               mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop,
                               layer_scale=layer_scale, resize_query=resize_query,
                               name=f'levels/{i}')
            self.levels.append(level)
        self.norm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm')
        if global_pool == 'avg':
            self.pool = tf.keras.layers.GlobalAveragePooling2D(name='pool')
        elif global_pool == 'max':
            self.pool = tf.keras.layers.GlobalMaxPooling2D(name='pool')
        elif global_pool is None:
            self.pool = Identity(name='pool')
        else:
            raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
        self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)

    def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
        super().build((1, 224, 224, in_channels))  # for the head we only need the input channel count

    def forward_features(self, inputs):
        x = self.patch_embed(inputs)
        x = self.pos_drop(x)
        x = tf.cast(x, dtype=tf.float32)
        for level in self.levels:
            x = level(x)
        x = self.norm(x)
        return x

    def forward_head(self, inputs, pre_logits=False):
        x = inputs
        if self.global_pool in ['avg', 'max']:
            x = self.pool(x)
        if not pre_logits:
            x = self.head(x)
        return x

    def call(self, inputs, **kwargs):
        x = self.forward_features(inputs)
        x = self.forward_head(x)
        return x

    def build_graph(self, input_shape=(224, 224, 3)):
        """https://www.kaggle.com/code/ipythonx/tf-hybrid-efficientnet-swin-transformer-gradcam"""
        x = tf.keras.Input(shape=input_shape)
        return tf.keras.Model(inputs=[x], outputs=self.call(x), name=self.name)

    def summary(self, input_shape=(224, 224, 3)):
        return self.build_graph(input_shape).summary()
# load standard models
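# Each factory builds a GCViT from its NAME2CONFIG entry, runs one dummy forward
# pass so the weights get created, and (when pretrain=True) downloads and loads
# the matching release checkpoint.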
def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_xxtiny'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_xtiny'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_tiny'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_small'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_base'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model

def GCViTLarge(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
    name = 'gcvit_large'
    config = NAME2CONFIG[name]
    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model
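
# Minimal usage sketch (not part of the original module): build the Tiny variant
# and run a random image through it. `pretrain=True` would download the release
# checkpoint composed above; it is left False here so the example runs offline.
if __name__ == '__main__':
    model = GCViTTiny(input_shape=(224, 224, 3), pretrain=False)
    dummy = tf.random.uniform((1, 224, 224, 3))  # one random 224x224 RGB image
    probs = model(dummy)                         # (1, num_classes) softmax scores
    print(probs.shape)
    model.summary()  # functional-graph summary via build_graph()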