# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import copy
from functools import partial
from typing import Optional, Union

import torch
from torch import nn

from fourm.utils.timm.registry import register_model
from huggingface_hub import PyTorchModelHubMixin

from .encoder_embeddings import ImageEncoderEmbedding
from .fm_utils import Block, LayerNorm
from fourm.data.modality_info import MODALITY_INFO

__all__ = [
    # GELU models
    'fm_vit_tiny_6e_gelu',
    'fm_vit_small_8e_gelu',
    'fm_vit_base_12e_gelu',
    'fm_vit_large_24e_gelu',
    'fm_vit_xlarge_24e_gelu',
    # SwiGLU models
    'fm_vit_tiny_6e_swiglu_nobias',
    'fm_vit_small_8e_swiglu_nobias',
    'fm_vit_base_12e_swiglu_nobias',
    'fm_vit_large_24e_swiglu_nobias',
    'fm_vit_xlarge_24e_swiglu_nobias',
    # SwiGLU + QKNorm models
    'fm_vit_base_12e_swiglu_qknorm_nobias',
    'fm_vit_large_24e_swiglu_qknorm_nobias',
    'fm_vit_xlarge_24e_swiglu_qknorm_nobias',
]


class FourMViT(nn.Module):
    """Modified 4M model, adapted to behave as a simple RGB-only ViT.

    Args:
        img_size (int): Input image size.
        patch_size (int): Patch size.
        in_chans (int): Number of input image channels.
        dim (int): Patch embedding dimension.
        encoder_depth (int): Depth of ViT / number of encoder blocks.
        num_heads (int): Number of attention heads in each ViT block.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        qkv_bias (bool): If True, adds a learnable bias to query, key, value.
        proj_bias (bool): If True, adds a bias to the attention out proj layer.
        mlp_bias (bool): If True, adds a learnable bias for the feedforward.
        drop_path_rate (float): Stochastic depth rate.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate.
        act_layer (nn.Module): Activation layer.
        norm_layer (nn.Module): Normalization layer.
        gated_mlp (bool): If True, makes the feedforward gated (e.g., for SwiGLU).
        qk_norm (bool): If True, normalizes the queries and keys (as in ViT-22B).
        encoder_norm (bool): If True, adds a norm layer after the last encoder block.
        output_head (Optional[nn.Module]): Optional output head after the encoder.
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        dim=768,
        encoder_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        mlp_bias: bool = True,
        drop_path_rate: float = 0.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: Union[partial, nn.Module] = partial(LayerNorm, eps=1e-6),
        gated_mlp: bool = False,  # Make the feedforward gated for e.g. SwiGLU
        qk_norm: bool = False,
        encoder_norm: bool = True,
        output_head: Optional[nn.Module] = None,
    ):
        super().__init__()

        self.img_size = img_size
        self.init_std = 0.02

        rgb_embedding = ImageEncoderEmbedding(num_channels=in_chans, patch_size=patch_size,
                                              dim_tokens=dim, sincos_pos_emb=True, image_size=img_size)
        self.num_patches = rgb_embedding.num_patches
        self.encoder_embeddings = nn.ModuleDict({f"rgb@{img_size}": rgb_embedding})

        # Stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_depth)]

        self.encoder = nn.ModuleList([
            Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias,
                  drop_path=dpr[i], drop=drop_rate, attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer,
                  gated_mlp=gated_mlp, qk_norm=qk_norm)
            for i in range(encoder_depth)
        ])

        self.encoder_norm = norm_layer(dim) if encoder_norm else nn.Identity()

        # Weight init
        self.init_weights()

        # Classification head is initialized after init_weights() to allow for special init scale
        if output_head is not None:
            self.output_head = output_head
            if hasattr(self.output_head, 'init'):
                self.output_head.init(dim)
        else:
            self.output_head = nn.Identity()

    def init_weights(self):
        """Weight initialization following MAE's initialization scheme"""
        for name, m in self.named_modules():
            # Skipping tokenizers to avoid reinitializing them
            if "tokenizer" in name:
                continue
            # Linear
            elif isinstance(m, nn.Linear):
                if 'qkv' in name:
                    # Treat the weights of Q, K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                elif 'kv' in name:
                    # Treat the weights of K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                else:
                    nn.init.xavier_uniform_(m.weight)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # LayerNorm
            elif isinstance(m, nn.LayerNorm) or isinstance(m, LayerNorm):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)
            # Embedding
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=self.init_std)
            # Conv2d
            elif isinstance(m, nn.Conv2d):
                if '.proj' in name:
                    # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
                    w = m.weight.data
                    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def get_num_layers_encoder(self):
        return len(self.encoder)

    def get_num_layers(self):
        return self.get_num_layers_encoder()

    def no_weight_decay(self):
        no_wd_set = set()
        for mod, emb_module in self.encoder_embeddings.items():
            if hasattr(emb_module, 'no_weight_decay'):
                to_skip = emb_module.no_weight_decay()
                to_skip = set([f'encoder_embeddings.{mod}.{name}' for name in to_skip])
                no_wd_set = no_wd_set | to_skip
        return no_wd_set

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor. Shape (B, C, H, W).

        Returns:
            torch.Tensor: Output tensor. Shape (B, num_classes) if an output head
                is set, otherwise the encoder tokens of shape (B, N, D).
        """
        rgb_dict = {'tensor': x}
        rgb_dict = self.encoder_embeddings[f'rgb@{self.img_size}'](rgb_dict)

        # Add embeddings to patchified RGB image
        x = rgb_dict['x'] + rgb_dict['emb']  # Shape: (B, N, D) with N = num_patches

        for blk in self.encoder:
            x = blk(x)

        x = self.encoder_norm(x)  # Shape: (B, N, D)

        out = self.output_head(x)

        return out

    def freeze_encoder(self, freeze_embeddings=True):
        for param in self.encoder.parameters():
            param.requires_grad = False

        for param in self.encoder_norm.parameters():
            param.requires_grad = False

        if freeze_embeddings:
            for param in self.encoder_embeddings.parameters():
                param.requires_grad = False

    def unfreeze_encoder(self, unfreeze_embeddings=True):
        for param in self.encoder.parameters():
            param.requires_grad = True

        for param in self.encoder_norm.parameters():
            param.requires_grad = True

        if unfreeze_embeddings:
            for param in self.encoder_embeddings.parameters():
                param.requires_grad = True
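

# Example (illustrative, kept as a comment): direct use of FourMViT. With the
# default output_head=None (i.e. nn.Identity), the forward pass returns the
# encoder tokens of shape (B, N, D), e.g. (2, 196, 768) for a 224x224 input
# with 16x16 patches and dim=768.
#
#   model = FourMViT(img_size=224, patch_size=16, dim=768, encoder_depth=12, num_heads=12)
#   x = torch.randn(2, 3, 224, 224)   # (B, C, H, W)
#   tokens = model(x)                 # (2, 196, 768)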


################################################
# Wrapper for easy loading with Huggingface Hub

class FMViT(FourMViT, PyTorchModelHubMixin):
    """Wrapper around FourMViT for easy loading with Huggingface Hub.

    Args:
        config (dict): Dictionary containing the model and modality configuration,
            used for loading from Huggingface Hub.
        output_head (nn.Module): Optional output head.
    """
    def __init__(self, config: dict, output_head: Optional[nn.Module] = None):
        config = copy.deepcopy(config)

        config['norm_layer'] = partial(LayerNorm, eps=1e-6, bias=config['norm_bias'])
        config['act_layer'] = getattr(torch.nn, config['act_layer'])

        img_size = config['image_size']
        config['img_size'] = img_size
        config['patch_size'] = MODALITY_INFO[f'rgb@{img_size}'].get('patch_size', config['patch_size'])
        config['in_chans'] = MODALITY_INFO[f'rgb@{img_size}'].get('num_channels', 3)

        for key in ['image_size', 'norm_bias', 'domains_in', 'domains_out', 'decoder_depth', 'share_modality_embeddings']:
            if key in config:
                del config[key]

        super().__init__(
            output_head=output_head,
            **config
        )
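

# Example (illustrative; the values below are made up): the config dict that
# FMViT expects mirrors the FourMViT constructor arguments, plus the Hub-side
# keys that __init__ above translates or drops ('image_size', 'norm_bias',
# 'act_layer' given as a string name, and 4M-only keys such as 'decoder_depth').
#
#   config = {
#       'image_size': 224, 'patch_size': 16, 'dim': 768, 'encoder_depth': 12,
#       'num_heads': 12, 'mlp_ratio': 4.0,
#       'qkv_bias': False, 'proj_bias': False, 'mlp_bias': False,
#       'norm_bias': False, 'act_layer': 'SiLU', 'gated_mlp': True, 'qk_norm': True,
#   }
#   model = FMViT(config)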


################################################
# Model definitions
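# Naming convention (inferred from the factories below): fm_vit_<size>_<N>e_<variant>,
# where <N>e is the number of encoder blocks and the suffix names the activation /
# feedforward variant (GELU, SwiGLU without biases, or SwiGLU + QKNorm without biases).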

# GELU variants

@register_model
def fm_vit_tiny_6e_gelu(**kwargs):
    model = FourMViT(
        encoder_depth=6,
        dim=384,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


@register_model
def fm_vit_small_8e_gelu(**kwargs):
    model = FourMViT(
        encoder_depth=8,
        dim=512,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


@register_model
def fm_vit_base_12e_gelu(**kwargs):
    model = FourMViT(
        encoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


@register_model
def fm_vit_large_24e_gelu(**kwargs):
    model = FourMViT(
        encoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


@register_model
def fm_vit_xlarge_24e_gelu(**kwargs):
    model = FourMViT(
        encoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


# SwiGLU variants

@register_model
def fm_vit_tiny_6e_swiglu_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=6,
        dim=384,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        **kwargs
    )
    return model


@register_model
def fm_vit_small_8e_swiglu_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=8,
        dim=512,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        **kwargs
    )
    return model


@register_model
def fm_vit_base_12e_swiglu_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        **kwargs
    )
    return model


@register_model
def fm_vit_large_24e_swiglu_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        **kwargs
    )
    return model


@register_model
def fm_vit_xlarge_24e_swiglu_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        **kwargs
    )
    return model


# SwiGLU + QKNorm variants

@register_model
def fm_vit_base_12e_swiglu_qknorm_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
        **kwargs
    )
    return model


@register_model
def fm_vit_large_24e_swiglu_qknorm_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
        **kwargs
    )
    return model


@register_model
def fm_vit_xlarge_24e_swiglu_qknorm_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
        **kwargs
    )
    return model
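

# Minimal smoke test (illustrative; assumes the fourm package and its
# dependencies are installed). With the default identity output head the
# model returns per-patch tokens of shape (B, N, D).
if __name__ == "__main__":
    model = fm_vit_tiny_6e_gelu()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: torch.Size([1, 196, 384])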