File size: 8,205 Bytes
39b9986 f6c0c44 39b9986 f6c0c44 39b9986 f6c0c44 39b9986 f6c0c44 39b9986 f6c0c44 39b9986 f6c0c44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import math
import warnings
import torch
from torch import nn
from torch.nn import functional as F
from timm.models import register_model
from timm.models.vision_transformer import (
VisionTransformer,
_create_vision_transformer as _timm_create_vision_transformer,
Mlp,
Block,
LayerScale as TIMMLayerScale,
)
# Import these to also register them
from . import dinov2_arch
@register_model
def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Build a ViT-Tiny with 14x14 patches.

    Defaults: embed_dim=192, depth=12, num_heads=3. Any keyword argument
    supplied by the caller overrides the matching default before everything
    is forwarded to the module-local ``_create_vision_transformer`` wrapper.
    """
    defaults = {'patch_size': 14, 'embed_dim': 192, 'depth': 12, 'num_heads': 3}
    # Caller-supplied kwargs take precedence over the architecture defaults.
    merged = {**defaults, **kwargs}
    return _create_vision_transformer('vit_tiny_patch14_224', pretrained=pretrained, **merged)
@register_model
def vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Build a ViT-Small with 14x14 patches.

    Defaults: embed_dim=384, depth=12, num_heads=6. Any keyword argument
    supplied by the caller overrides the matching default.
    """
    defaults = {'patch_size': 14, 'embed_dim': 384, 'depth': 12, 'num_heads': 6}
    # Caller-supplied kwargs take precedence over the architecture defaults.
    merged = {**defaults, **kwargs}
    # NOTE(review): the variant name is the patch-16 one although patch_size is
    # 14 - presumably to reuse that variant's pretrained config; confirm.
    return _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **merged)
@register_model
def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Build a ViT-Base with 14x14 patches.

    Architecture follows the original ViT paper
    (https://arxiv.org/abs/2010.11929): embed_dim=768, depth=12,
    num_heads=12. Any keyword argument supplied by the caller overrides the
    matching default.
    """
    defaults = {'patch_size': 14, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    # Caller-supplied kwargs take precedence over the architecture defaults.
    merged = {**defaults, **kwargs}
    return _create_vision_transformer('vit_base_patch14_224', pretrained=pretrained, **merged)
@register_model
def vit_base_patch16_v2_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Build a ViT-Base with 16x16 patches, 4 register tokens and LayerScale.

    The model is created under the variant name
    ``'vit_base_patch14_reg4_dinov2'`` while overriding the patch size to 16.
    Any keyword argument supplied by the caller overrides the matching
    default.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 768,
        'depth': 12,
        'num_heads': 12,
        'init_values': 1e-5,
        'reg_tokens': 4,
        'no_embed_class': True,
        # 518 rescaled by 16/14 (= 592).
        # NOTE(review): presumably keeps the patch grid of a 518px / patch-14
        # model - confirm.
        'img_size': 518 * 16 // 14,
    }
    merged = {**defaults, **kwargs}
    return _create_vision_transformer('vit_base_patch14_reg4_dinov2', pretrained=pretrained, **merged)
@register_model
def vit_large_patch16_v2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build a ViT-Large with 16x16 patches, 4 register tokens and LayerScale.

    The model is created under the variant name
    ``'vit_large_patch14_reg4_dinov2'`` while overriding the patch size to 16.
    Any keyword argument supplied by the caller overrides the matching
    default.
    """
    defaults = {
        'patch_size': 16,
        'embed_dim': 1024,
        'depth': 24,
        'num_heads': 16,
        'init_values': 1e-5,
        'reg_tokens': 4,
        'no_embed_class': True,
        # 518 rescaled by 16/14 (= 592).
        # NOTE(review): presumably keeps the patch grid of a 518px / patch-14
        # model - confirm.
        'img_size': 518 * 16 // 14,
    }
    merged = {**defaults, **kwargs}
    variant = 'vit_large_patch14_reg4_dinov2'
    return _create_vision_transformer(variant, pretrained=pretrained, **merged)
@register_model
def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Build a ViT-Huge with 16x16 patches (https://arxiv.org/abs/2010.11929).

    Defaults: embed_dim=1280, depth=32, num_heads=16. Any keyword argument
    supplied by the caller overrides the matching default.

    When ``pretrained`` is truthy the model is created under the
    ``'vit_huge_patch14_224'`` variant name, because there is no pretrained
    ViT-H/16 and the ViT-H/14 one is adapted instead.
    """
    defaults = {'patch_size': 16, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16}
    merged = {**defaults, **kwargs}
    # Only the variant name differs between the two cases; the flag is passed
    # on as a strict boolean.
    variant = 'vit_huge_patch14_224' if pretrained else 'vit_huge_patch16_224'
    return _create_vision_transformer(variant, pretrained=bool(pretrained), **merged)
@register_model
def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransformer:
    """Build ``vit_huge_patch16_224`` and give every ``Mlp`` a LayerNorm.

    Each ``Mlp`` submodule whose ``norm`` attribute is not already an
    ``nn.LayerNorm`` gets a newly constructed ``nn.LayerNorm`` sized to the
    output width of its first linear layer (``fc1.out_features``).
    """
    model = vit_huge_patch16_224(pretrained=pretrained, **kwargs)
    for module in model.modules():
        if not isinstance(module, Mlp):
            continue
        if isinstance(module.norm, nn.LayerNorm):
            # Already normalised; leave the existing layer in place.
            continue
        hidden_width = module.fc1.out_features
        module.norm = nn.LayerNorm(hidden_width)
    return model
@register_model
def vit_giant_patch16_224(pretrained=False, scaled_ln: bool = False, **kwargs) -> VisionTransformer:
    """Build a ViT-giant with 16x16 patches (https://arxiv.org/abs/2010.11929).

    Defaults: embed_dim=1536, depth=40, num_heads=24. Any keyword argument
    supplied by the caller overrides the matching default.

    ``pretrained`` is accepted but not forwarded: the model is always created
    with ``pretrained=False``. When ``scaled_ln`` is true, the block
    LayerNorms are swapped via ``_apply_scaled_ln``.
    """
    defaults = {'patch_size': 16, 'embed_dim': 1536, 'depth': 40, 'num_heads': 24}
    merged = {**defaults, **kwargs}
    model = _create_vision_transformer('vit_giant_patch16_224', pretrained=False, **merged)
    if not scaled_ln:
        return model
    _apply_scaled_ln(model)
    return model
@register_model
def vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Build a ViT-bigG with 14x14 patches.

    Defaults: embed_dim=1664, depth=48, num_heads=16, init_values=1e-6. Any
    keyword argument supplied by the caller overrides the matching default.

    ``pretrained`` is accepted but not forwarded: the model is always created
    with ``pretrained=False``.
    """
    defaults = {
        'patch_size': 14,
        'embed_dim': 1664,
        'depth': 48,
        'num_heads': 16,
        'init_values': 1e-6,
    }
    merged = {**defaults, **kwargs}
    return _create_vision_transformer('vit_bigG_patch14', pretrained=False, **merged)
def _create_vision_transformer(*args, **kwargs):
    """Create a model through timm's factory, then patch its LayerScale modules.

    All positional and keyword arguments are passed through unchanged to
    timm's ``_create_vision_transformer``; the result is post-processed by
    ``_patch_layer_scale`` before being returned.
    """
    vit = _timm_create_vision_transformer(*args, **kwargs)
    _patch_layer_scale(vit)
    return vit
def _patch_layer_scale(model: VisionTransformer):
    """Swap timm ``LayerScale`` modules inside every ``Block`` in place.

    Monkey patch: timm's LayerScale is replaced by the modified DINOv2 one,
    which uses a parameter name other than ``gamma`` so that HFHub doesn't
    mess with it. The existing values are carried over through the state dict.
    """
    def _convert(timm_ls: TIMMLayerScale):
        # Same width and in-place setting as the module being replaced.
        width = timm_ls.gamma.shape[0]
        converted = dinov2_arch.LayerScale(width, inplace=timm_ls.inplace)
        converted.load_state_dict(timm_ls.state_dict())
        return converted

    for module in model.modules():
        if not isinstance(module, Block):
            continue
        for attr in ('ls1', 'ls2'):
            current = getattr(module, attr)
            if isinstance(current, TIMMLayerScale):
                setattr(module, attr, _convert(current))
class ScaledLayerNorm(nn.LayerNorm):
    """LayerNorm whose output is multiplied by ``1 / sqrt(depth)``.

    Reference: https://arxiv.org/pdf/2502.05795v1

    The instance copies the shape, epsilon, affine setting and parameter
    values of an existing ``nn.LayerNorm``. The scale factor is stored in the
    non-persistent buffer ``ln_scale``, so it is not part of ``state_dict()``.

    Args:
        ln_base: The LayerNorm to copy the configuration and weights from.
        depth: Positive value the output is scaled by (as ``1 / sqrt(depth)``).
            The default of 0 is kept for signature compatibility but is
            rejected, since the scale would be undefined.

    Raises:
        ValueError: If ``depth`` is not strictly positive.
    """

    def __init__(self, ln_base: nn.LayerNorm, depth: int = 0):
        # Fail early with a clear message instead of an opaque
        # ZeroDivisionError / "math domain error" from the scale computation.
        if depth <= 0:
            raise ValueError(f'ScaledLayerNorm requires depth >= 1, got {depth}')

        super().__init__(
            ln_base.normalized_shape,
            eps=ln_base.eps,
            elementwise_affine=ln_base.elementwise_affine,
        )
        # A base LayerNorm may have a weight but no bias. Mirror that here so
        # the strict state-dict load below does not fail on a missing 'bias'.
        if ln_base.bias is None and self.bias is not None:
            self.register_parameter('bias', None)
        self.load_state_dict(ln_base.state_dict())
        self.register_buffer('ln_scale', torch.tensor(1.0 / math.sqrt(depth)), persistent=False)

    def forward(self, x):
        """Apply the regular LayerNorm, then multiply by ``ln_scale``."""
        y = super().forward(x)
        y = y * self.ln_scale
        return y
class DyT(nn.Module):
    """Element-wise ``gamma * tanh(alpha * x) + beta`` with learnable parameters.

    ``alpha`` is a single learnable value (shape ``(1,)``) initialised to
    ``init_alpha``; ``gamma`` and ``beta`` are length-``C`` vectors
    initialised to ones and zeros respectively.

    Args:
        C: Length of the ``gamma`` and ``beta`` vectors.
        init_alpha: Initial value of ``alpha``.
    """

    def __init__(self, C: int, init_alpha: float):
        super().__init__()
        # Registration order (alpha, gamma, beta) fixes the state_dict key order.
        alpha_init = torch.full((1,), init_alpha)
        self.alpha = nn.Parameter(alpha_init)
        self.gamma = nn.Parameter(torch.ones(C))
        self.beta = nn.Parameter(torch.zeros(C))

    def forward(self, x: torch.Tensor):
        """Squash the scaled input with tanh, then apply the affine transform."""
        squashed = F.tanh(self.alpha * x)
        scaled = self.gamma * squashed
        return scaled + self.beta
@register_model
def vit_large_dyt_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """Build a ViT-Large (16x16 patches) whose block LayerNorms are ``DyT`` layers.

    Defaults: embed_dim=1024, depth=24, num_heads=16. Any keyword argument
    supplied by the caller overrides the matching default. After creation,
    ``_replace_ln`` swaps the ``norm1`` / ``norm2`` LayerNorms of each block
    for ``DyT`` layers initialised with ``alpha = 0.9``.
    """
    defaults = {'patch_size': 16, 'embed_dim': 1024, 'depth': 24, 'num_heads': 16}
    merged = {**defaults, **kwargs}
    model = _create_vision_transformer('vit_large_dyt_patch16_224', pretrained=pretrained, **merged)

    def _to_dyt(ln: nn.LayerNorm, depth: int):
        # `depth` is passed by keyword from _replace_ln; DyT does not use it.
        channels = ln.normalized_shape[0]
        return DyT(channels, init_alpha=0.9)

    _replace_ln(model, _to_dyt)
    return model
def _apply_scaled_ln(model: VisionTransformer):
    """Replace the block LayerNorms of ``model`` with ``ScaledLayerNorm``.

    Emits a warning first, then delegates to ``_replace_ln``, wrapping each
    LayerNorm together with the depth value it is handed.
    """
    warnings.warn('Post-LayerNorm scaling activated!')

    def _wrap(ln, depth):
        return ScaledLayerNorm(ln, depth=depth)

    _replace_ln(model, _wrap)
def _replace_ln(model: VisionTransformer, fn):
    """Replace ``norm1`` / ``norm2`` LayerNorms of every block in ``model.blocks``.

    Args:
        model: Model exposing an iterable ``blocks`` attribute.
        fn: Callable invoked as ``fn(layer_norm, depth=depth)`` where ``depth``
            is the 1-based position of the block; its return value replaces
            the LayerNorm. Attributes that are not ``nn.LayerNorm`` instances
            are left untouched.
    """
    for depth, block in enumerate(model.blocks, start=1):
        for attr in ('norm1', 'norm2'):
            existing = getattr(block, attr)
            if isinstance(existing, nn.LayerNorm):
                setattr(block, attr, fn(existing, depth=depth))
|