|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from functools import partial |
|
from typing import Tuple, Optional, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \ |
|
ClassifierHead |
|
from ._builder import build_model_with_cfg |
|
from ._manipulate import checkpoint_seq |
|
from ._registry import register_model, generate_default_cfgs |
|
|
|
|
|
def num_groups(group_size, channels): |
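    # group_size semantics: 0 or None -> a single group (normal conv);
    # group_size == 1 -> depthwise conv (groups == channels); otherwise channels must divide evenly.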
|
if not group_size: |
|
return 1 |
|
else: |
|
|
|
assert channels % group_size == 0 |
|
return channels // group_size |
|
|
|
|
|
class MobileOneBlock(nn.Module): |
|
"""MobileOne building block. |
|
|
|
This block has a multi-branched architecture at train-time |
|
    and a plain-CNN style architecture at inference time.
|
For more details, please refer to our paper: |
|
    `MobileOne: An Improved One millisecond Mobile Backbone` -
|
https://arxiv.org/pdf/2206.04040.pdf |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_chs: int, |
|
out_chs: int, |
|
kernel_size: int, |
|
stride: int = 1, |
|
dilation: int = 1, |
|
group_size: int = 0, |
|
inference_mode: bool = False, |
|
use_se: bool = False, |
|
use_act: bool = True, |
|
use_scale_branch: bool = True, |
|
num_conv_branches: int = 1, |
|
act_layer: nn.Module = nn.GELU, |
|
) -> None: |
|
"""Construct a MobileOneBlock module. |
|
|
|
Args: |
|
in_chs: Number of channels in the input. |
|
out_chs: Number of channels produced by the block. |
|
kernel_size: Size of the convolution kernel. |
|
stride: Stride size. |
|
dilation: Kernel dilation factor. |
|
group_size: Convolution group size. |
|
inference_mode: If True, instantiates model in inference mode. |
|
use_se: Whether to use SE-ReLU activations. |
|
use_act: Whether to use activation. Default: ``True`` |
|
use_scale_branch: Whether to use scale branch. Default: ``True`` |
|
            num_conv_branches: Number of linear conv branches.
            act_layer: Activation layer. Default: ``nn.GELU``
        """
|
super(MobileOneBlock, self).__init__() |
|
self.inference_mode = inference_mode |
|
self.groups = num_groups(group_size, in_chs) |
|
self.stride = stride |
|
self.dilation = dilation |
|
self.kernel_size = kernel_size |
|
self.in_chs = in_chs |
|
self.out_chs = out_chs |
|
self.num_conv_branches = num_conv_branches |
|
|
|
|
|
self.se = SqueezeExcite(out_chs, rd_divisor=1) if use_se else nn.Identity() |
|
|
|
if inference_mode: |
|
self.reparam_conv = create_conv2d( |
|
in_chs, |
|
out_chs, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
dilation=dilation, |
|
groups=self.groups, |
|
bias=True, |
|
) |
|
else: |
|
|
|
self.reparam_conv = None |
|
|
|
self.identity = ( |
|
nn.BatchNorm2d(num_features=in_chs) |
|
if out_chs == in_chs and stride == 1 |
|
else None |
|
) |
|
|
|
|
|
if num_conv_branches > 0: |
|
self.conv_kxk = nn.ModuleList([ |
|
ConvNormAct( |
|
self.in_chs, |
|
self.out_chs, |
|
kernel_size=kernel_size, |
|
stride=self.stride, |
|
groups=self.groups, |
|
apply_act=False, |
|
) for _ in range(self.num_conv_branches) |
|
]) |
|
else: |
|
self.conv_kxk = None |
|
|
|
|
|
self.conv_scale = None |
|
if kernel_size > 1 and use_scale_branch: |
|
self.conv_scale = ConvNormAct( |
|
self.in_chs, |
|
self.out_chs, |
|
kernel_size=1, |
|
stride=self.stride, |
|
groups=self.groups, |
|
apply_act=False |
|
) |
|
|
|
self.act = act_layer() if use_act else nn.Identity() |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Apply forward pass.""" |
|
|
|
if self.reparam_conv is not None: |
|
return self.act(self.se(self.reparam_conv(x))) |
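
        # Multi-branch train-time path: sum of identity (BN) branch, 1x1 scale branch and k x k conv branches,
        # followed by SE and activation.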
|
|
|
|
|
|
|
identity_out = 0 |
|
if self.identity is not None: |
|
identity_out = self.identity(x) |
|
|
|
|
|
scale_out = 0 |
|
if self.conv_scale is not None: |
|
scale_out = self.conv_scale(x) |
|
|
|
|
|
out = scale_out + identity_out |
|
if self.conv_kxk is not None: |
|
for rc in self.conv_kxk: |
|
out += rc(x) |
|
|
|
return self.act(self.se(out)) |
|
|
|
def reparameterize(self): |
|
"""Following works like `RepVGG: Making VGG-style ConvNets Great Again` - |
|
https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched |
|
architecture used at training time to obtain a plain CNN-like structure |
|
for inference. |
|
""" |
|
if self.reparam_conv is not None: |
|
return |
|
|
|
kernel, bias = self._get_kernel_bias() |
|
self.reparam_conv = create_conv2d( |
|
in_channels=self.in_chs, |
|
out_channels=self.out_chs, |
|
kernel_size=self.kernel_size, |
|
stride=self.stride, |
|
dilation=self.dilation, |
|
groups=self.groups, |
|
bias=True, |
|
) |
|
self.reparam_conv.weight.data = kernel |
|
self.reparam_conv.bias.data = bias |
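
        # Detach remaining parameters and drop the train-time branches; only the fused conv is kept.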
|
|
|
|
|
for name, para in self.named_parameters(): |
|
if 'reparam_conv' in name: |
|
continue |
|
para.detach_() |
|
|
|
self.__delattr__("conv_kxk") |
|
self.__delattr__("conv_scale") |
|
if hasattr(self, "identity"): |
|
self.__delattr__("identity") |
|
|
|
self.inference_mode = True |
|
|
|
def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: |
|
"""Method to obtain re-parameterized kernel and bias. |
|
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 |
|
|
|
Returns: |
|
Tuple of (kernel, bias) after fusing branches. |
|
""" |
|
|
|
kernel_scale = 0 |
|
bias_scale = 0 |
|
if self.conv_scale is not None: |
|
kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale) |
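            # Pad the fused 1x1 scale-branch kernel out to kernel_size x kernel_size so it can be
            # summed with the k x k conv branches.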
|
|
|
pad = self.kernel_size // 2 |
|
kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad]) |
|
|
|
|
|
kernel_identity = 0 |
|
bias_identity = 0 |
|
if self.identity is not None: |
|
kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity) |
|
|
|
|
|
kernel_conv = 0 |
|
bias_conv = 0 |
|
if self.conv_kxk is not None: |
|
for ix in range(self.num_conv_branches): |
|
_kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix]) |
|
kernel_conv += _kernel |
|
bias_conv += _bias |
|
|
|
kernel_final = kernel_conv + kernel_scale + kernel_identity |
|
bias_final = bias_conv + bias_scale + bias_identity |
|
return kernel_final, bias_final |
|
|
|
def _fuse_bn_tensor( |
|
self, branch: Union[nn.Sequential, nn.BatchNorm2d] |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
"""Method to fuse batchnorm layer with preceeding conv layer. |
|
Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 |
|
|
|
Args: |
|
branch: Sequence of ops to be fused. |
|
|
|
Returns: |
|
Tuple of (kernel, bias) after fusing batchnorm. |
|
""" |
|
if isinstance(branch, ConvNormAct): |
|
kernel = branch.conv.weight |
|
running_mean = branch.bn.running_mean |
|
running_var = branch.bn.running_var |
|
gamma = branch.bn.weight |
|
beta = branch.bn.bias |
|
eps = branch.bn.eps |
|
else: |
|
assert isinstance(branch, nn.BatchNorm2d) |
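            # Identity (BN-only) branch: build a cached identity kernel (1 at the spatial center of each
            # channel's own filter within its group) so the skip connection can be fused like a conv branch.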
|
if not hasattr(self, "id_tensor"): |
|
input_dim = self.in_chs // self.groups |
|
kernel_value = torch.zeros( |
|
(self.in_chs, input_dim, self.kernel_size, self.kernel_size), |
|
dtype=branch.weight.dtype, |
|
device=branch.weight.device, |
|
) |
|
for i in range(self.in_chs): |
|
kernel_value[ |
|
i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2 |
|
] = 1 |
|
self.id_tensor = kernel_value |
|
kernel = self.id_tensor |
|
running_mean = branch.running_mean |
|
running_var = branch.running_var |
|
gamma = branch.weight |
|
beta = branch.bias |
|
eps = branch.eps |
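        # Fold BN into the conv: w' = w * gamma / sqrt(var + eps), b' = beta - mean * gamma / sqrt(var + eps).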
|
std = (running_var + eps).sqrt() |
|
t = (gamma / std).reshape(-1, 1, 1, 1) |
|
return kernel * t, beta - running_mean * gamma / std |
|
|
|
|
|
class ReparamLargeKernelConv(nn.Module): |
|
"""Building Block of RepLKNet |
|
|
|
This class defines overparameterized large kernel conv block |
|
introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_ |
|
|
|
Reference: https://github.com/DingXiaoH/RepLKNet-pytorch |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_chs: int, |
|
out_chs: int, |
|
kernel_size: int, |
|
stride: int, |
|
group_size: int, |
|
small_kernel: Optional[int] = None, |
|
inference_mode: bool = False, |
|
act_layer: Optional[nn.Module] = None, |
|
) -> None: |
|
"""Construct a ReparamLargeKernelConv module. |
|
|
|
Args: |
|
in_chs: Number of input channels. |
|
out_chs: Number of output channels. |
|
kernel_size: Kernel size of the large kernel conv branch. |
|
            stride: Stride size.
            group_size: Convolution group size.
|
small_kernel: Kernel size of small kernel conv branch. |
|
inference_mode: If True, instantiates model in inference mode. Default: ``False`` |
|
            act_layer: Activation layer. Default: ``None`` (no activation).
|
""" |
|
super(ReparamLargeKernelConv, self).__init__() |
|
self.stride = stride |
|
self.groups = num_groups(group_size, in_chs) |
|
self.in_chs = in_chs |
|
self.out_chs = out_chs |
|
|
|
self.kernel_size = kernel_size |
|
self.small_kernel = small_kernel |
|
if inference_mode: |
|
self.reparam_conv = create_conv2d( |
|
in_chs, |
|
out_chs, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
dilation=1, |
|
groups=self.groups, |
|
bias=True, |
|
) |
|
else: |
|
            self.reparam_conv = None
            self.small_conv = None
|
self.large_conv = ConvNormAct( |
|
in_chs, |
|
out_chs, |
|
kernel_size=kernel_size, |
|
stride=self.stride, |
|
groups=self.groups, |
|
apply_act=False, |
|
) |
|
if small_kernel is not None: |
|
assert ( |
|
small_kernel <= kernel_size |
|
), "The kernel size for re-param cannot be larger than the large kernel!" |
|
self.small_conv = ConvNormAct( |
|
in_chs, |
|
out_chs, |
|
kernel_size=small_kernel, |
|
stride=self.stride, |
|
groups=self.groups, |
|
apply_act=False, |
|
) |
|
|
|
self.act = act_layer() if act_layer is not None else nn.Identity() |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
if self.reparam_conv is not None: |
|
out = self.reparam_conv(x) |
|
else: |
|
out = self.large_conv(x) |
|
if self.small_conv is not None: |
|
out = out + self.small_conv(x) |
|
out = self.act(out) |
|
return out |
|
|
|
def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: |
|
"""Method to obtain re-parameterized kernel and bias. |
|
Reference: https://github.com/DingXiaoH/RepLKNet-pytorch |
|
|
|
Returns: |
|
Tuple of (kernel, bias) after fusing branches. |
|
""" |
|
eq_k, eq_b = self._fuse_bn(self.large_conv.conv, self.large_conv.bn) |
|
        if self.small_conv is not None:
|
small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn) |
|
eq_b += small_b |
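            # Pad the small kernel so it is added to the spatial center of the large kernel.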
|
eq_k += nn.functional.pad( |
|
small_k, [(self.kernel_size - self.small_kernel) // 2] * 4 |
|
) |
|
return eq_k, eq_b |
|
|
|
def reparameterize(self) -> None: |
|
""" |
|
        Following works such as `RepVGG: Making VGG-style ConvNets Great Again` -
|
https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched |
|
architecture used at training time to obtain a plain CNN-like structure |
|
for inference. |
|
""" |
|
eq_k, eq_b = self.get_kernel_bias() |
|
self.reparam_conv = create_conv2d( |
|
self.in_chs, |
|
self.out_chs, |
|
kernel_size=self.kernel_size, |
|
stride=self.stride, |
|
groups=self.groups, |
|
bias=True, |
|
) |
|
|
|
self.reparam_conv.weight.data = eq_k |
|
self.reparam_conv.bias.data = eq_b |
|
self.__delattr__("large_conv") |
|
if hasattr(self, "small_conv"): |
|
self.__delattr__("small_conv") |
|
|
|
@staticmethod |
|
def _fuse_bn( |
|
            conv: nn.Conv2d, bn: nn.BatchNorm2d
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
"""Method to fuse batchnorm layer with conv layer. |
|
|
|
Args: |
|
            conv: Convolution layer whose weights will be fused.
|
bn: Batchnorm 2d layer. |
|
|
|
Returns: |
|
Tuple of (kernel, bias) after fusing batchnorm. |
|
""" |
|
kernel = conv.weight |
|
running_mean = bn.running_mean |
|
running_var = bn.running_var |
|
gamma = bn.weight |
|
beta = bn.bias |
|
eps = bn.eps |
|
std = (running_var + eps).sqrt() |
|
t = (gamma / std).reshape(-1, 1, 1, 1) |
|
return kernel * t, beta - running_mean * gamma / std |
|
|
|
|
|
def convolutional_stem( |
|
in_chs: int, |
|
out_chs: int, |
|
act_layer: nn.Module = nn.GELU, |
|
inference_mode: bool = False |
|
) -> nn.Sequential: |
|
"""Build convolutional stem with MobileOne blocks. |
|
|
|
Args: |
|
in_chs: Number of input channels. |
|
out_chs: Number of output channels. |
|
        act_layer: Activation layer. Default: ``nn.GELU``
        inference_mode: Flag to instantiate model in inference mode. Default: ``False``
|
|
|
Returns: |
|
nn.Sequential object with stem elements. |
|
""" |
|
return nn.Sequential( |
|
MobileOneBlock( |
|
in_chs=in_chs, |
|
out_chs=out_chs, |
|
kernel_size=3, |
|
stride=2, |
|
act_layer=act_layer, |
|
inference_mode=inference_mode, |
|
), |
|
MobileOneBlock( |
|
in_chs=out_chs, |
|
out_chs=out_chs, |
|
kernel_size=3, |
|
stride=2, |
|
group_size=1, |
|
act_layer=act_layer, |
|
inference_mode=inference_mode, |
|
), |
|
MobileOneBlock( |
|
in_chs=out_chs, |
|
out_chs=out_chs, |
|
kernel_size=1, |
|
stride=1, |
|
act_layer=act_layer, |
|
inference_mode=inference_mode, |
|
), |
|
) |
|
|
|
|
|
class Attention(nn.Module): |
|
"""Multi-headed Self Attention module. |
|
|
|
Source modified from: |
|
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py |
|
""" |
|
fused_attn: torch.jit.Final[bool] |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
head_dim: int = 32, |
|
qkv_bias: bool = False, |
|
attn_drop: float = 0.0, |
|
proj_drop: float = 0.0, |
|
) -> None: |
|
"""Build MHSA module that can handle 3D or 4D input tensors. |
|
|
|
Args: |
|
dim: Number of embedding dimensions. |
|
head_dim: Number of hidden dimensions per head. Default: ``32`` |
|
qkv_bias: Use bias or not. Default: ``False`` |
|
attn_drop: Dropout rate for attention tensor. |
|
proj_drop: Dropout rate for projection tensor. |
|
""" |
|
super().__init__() |
|
assert dim % head_dim == 0, "dim should be divisible by head_dim" |
|
self.head_dim = head_dim |
|
self.num_heads = dim // head_dim |
|
self.scale = head_dim ** -0.5 |
|
self.fused_attn = use_fused_attn() |
|
|
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
|
self.attn_drop = nn.Dropout(attn_drop) |
|
self.proj = nn.Linear(dim, dim) |
|
self.proj_drop = nn.Dropout(proj_drop) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
B, C, H, W = x.shape |
|
N = H * W |
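        # Flatten spatial dims to a token sequence: (B, C, H, W) -> (B, N, C); reshaped back after attention.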
|
x = x.flatten(2).transpose(-2, -1) |
|
qkv = ( |
|
self.qkv(x) |
|
.reshape(B, N, 3, self.num_heads, self.head_dim) |
|
.permute(2, 0, 3, 1, 4) |
|
) |
|
q, k, v = qkv.unbind(0) |
|
|
|
if self.fused_attn: |
|
x = torch.nn.functional.scaled_dot_product_attention( |
|
q, k, v, |
|
dropout_p=self.attn_drop.p if self.training else 0., |
|
) |
|
else: |
|
q = q * self.scale |
|
attn = q @ k.transpose(-2, -1) |
|
attn = attn.softmax(dim=-1) |
|
attn = self.attn_drop(attn) |
|
x = attn @ v |
|
|
|
x = x.transpose(1, 2).reshape(B, N, C) |
|
x = self.proj(x) |
|
x = self.proj_drop(x) |
|
x = x.transpose(-2, -1).reshape(B, C, H, W) |
|
|
|
return x |
|
|
|
|
|
class PatchEmbed(nn.Module): |
|
"""Convolutional patch embedding layer.""" |
|
|
|
def __init__( |
|
self, |
|
patch_size: int, |
|
stride: int, |
|
in_chs: int, |
|
embed_dim: int, |
|
act_layer: nn.Module = nn.GELU, |
|
lkc_use_act: bool = False, |
|
inference_mode: bool = False, |
|
) -> None: |
|
"""Build patch embedding layer. |
|
|
|
Args: |
|
patch_size: Patch size for embedding computation. |
|
stride: Stride for convolutional embedding layer. |
|
in_chs: Number of channels of input tensor. |
|
embed_dim: Number of embedding dimensions. |
|
            act_layer: Activation layer. Default: ``nn.GELU``
            lkc_use_act: If True, apply the activation inside the large kernel conv block. Default: ``False``
            inference_mode: Flag to instantiate model in inference mode. Default: ``False``
|
""" |
|
super().__init__() |
|
self.proj = nn.Sequential( |
|
ReparamLargeKernelConv( |
|
in_chs=in_chs, |
|
out_chs=embed_dim, |
|
kernel_size=patch_size, |
|
stride=stride, |
|
group_size=1, |
|
small_kernel=3, |
|
inference_mode=inference_mode, |
|
act_layer=act_layer if lkc_use_act else None, |
|
), |
|
MobileOneBlock( |
|
in_chs=embed_dim, |
|
out_chs=embed_dim, |
|
kernel_size=1, |
|
stride=1, |
|
act_layer=act_layer, |
|
inference_mode=inference_mode, |
|
) |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.proj(x) |
|
return x |
|
|
|
|
|
class LayerScale2d(nn.Module): |
|
def __init__(self, dim, init_values=1e-5, inplace=False): |
|
super().__init__() |
|
self.inplace = inplace |
|
self.gamma = nn.Parameter(init_values * torch.ones(dim, 1, 1)) |
|
|
|
def forward(self, x): |
|
return x.mul_(self.gamma) if self.inplace else x * self.gamma |
|
|
|
|
|
class RepMixer(nn.Module): |
|
"""Reparameterizable token mixer. |
|
|
|
For more details, please refer to our paper: |
|
`FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_ |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim, |
|
kernel_size=3, |
|
layer_scale_init_value=1e-5, |
|
inference_mode: bool = False, |
|
): |
|
"""Build RepMixer Module. |
|
|
|
Args: |
|
dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`. |
|
kernel_size: Kernel size for spatial mixing. Default: 3 |
|
layer_scale_init_value: Initial value for layer scale. Default: 1e-5 |
|
inference_mode: If True, instantiates model in inference mode. Default: ``False`` |
|
""" |
|
super().__init__() |
|
self.dim = dim |
|
self.kernel_size = kernel_size |
|
self.inference_mode = inference_mode |
|
|
|
if inference_mode: |
|
self.reparam_conv = nn.Conv2d( |
|
self.dim, |
|
self.dim, |
|
kernel_size=self.kernel_size, |
|
stride=1, |
|
padding=self.kernel_size // 2, |
|
groups=self.dim, |
|
bias=True, |
|
) |
|
else: |
|
self.reparam_conv = None |
|
self.norm = MobileOneBlock( |
|
dim, |
|
dim, |
|
kernel_size, |
|
group_size=1, |
|
use_act=False, |
|
use_scale_branch=False, |
|
num_conv_branches=0, |
|
) |
|
self.mixer = MobileOneBlock( |
|
dim, |
|
dim, |
|
kernel_size, |
|
group_size=1, |
|
use_act=False, |
|
) |
|
if layer_scale_init_value is not None: |
|
self.layer_scale = LayerScale2d(dim, layer_scale_init_value) |
|
else: |
|
                self.layer_scale = nn.Identity()
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
if self.reparam_conv is not None: |
|
x = self.reparam_conv(x) |
|
else: |
|
x = x + self.layer_scale(self.mixer(x) - self.norm(x)) |
|
return x |
|
|
|
def reparameterize(self) -> None: |
|
"""Reparameterize mixer and norm into a single |
|
convolutional layer for efficient inference. |
|
""" |
|
if self.inference_mode: |
|
return |
|
|
|
self.mixer.reparameterize() |
|
self.norm.reparameterize() |
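
        # Fuse the residual form x + layer_scale * (mixer(x) - norm(x)) into one depthwise conv:
        #   weight = I + gamma * (W_mixer - W_norm), bias = gamma * (b_mixer - b_norm)
        # where gamma = 1 when layer scale is disabled.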
|
|
|
if isinstance(self.layer_scale, LayerScale2d): |
|
w = self.mixer.id_tensor + self.layer_scale.gamma.unsqueeze(-1) * ( |
|
self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight |
|
) |
|
b = torch.squeeze(self.layer_scale.gamma) * ( |
|
self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias |
|
) |
|
else: |
|
w = ( |
|
self.mixer.id_tensor |
|
+ self.mixer.reparam_conv.weight |
|
- self.norm.reparam_conv.weight |
|
) |
|
b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias |
|
|
|
self.reparam_conv = create_conv2d( |
|
self.dim, |
|
self.dim, |
|
kernel_size=self.kernel_size, |
|
stride=1, |
|
groups=self.dim, |
|
bias=True, |
|
) |
|
self.reparam_conv.weight.data = w |
|
self.reparam_conv.bias.data = b |
|
|
|
for name, para in self.named_parameters(): |
|
if 'reparam_conv' in name: |
|
continue |
|
para.detach_() |
|
self.__delattr__("mixer") |
|
self.__delattr__("norm") |
|
self.__delattr__("layer_scale") |
|
|
|
|
|
class ConvMlp(nn.Module): |
|
"""Convolutional FFN Module.""" |
|
|
|
def __init__( |
|
self, |
|
in_chs: int, |
|
hidden_channels: Optional[int] = None, |
|
out_chs: Optional[int] = None, |
|
act_layer: nn.Module = nn.GELU, |
|
drop: float = 0.0, |
|
) -> None: |
|
"""Build convolutional FFN module. |
|
|
|
Args: |
|
in_chs: Number of input channels. |
|
hidden_channels: Number of channels after expansion. Default: None |
|
out_chs: Number of output channels. Default: None |
|
act_layer: Activation layer. Default: ``GELU`` |
|
drop: Dropout rate. Default: ``0.0``. |
|
""" |
|
super().__init__() |
|
out_chs = out_chs or in_chs |
|
hidden_channels = hidden_channels or in_chs |
|
self.conv = ConvNormAct( |
|
in_chs, |
|
out_chs, |
|
kernel_size=7, |
|
groups=in_chs, |
|
apply_act=False, |
|
) |
|
self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1) |
|
self.act = act_layer() |
|
self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1) |
|
self.drop = nn.Dropout(drop) |
|
self.apply(self._init_weights) |
|
|
|
def _init_weights(self, m: nn.Module) -> None: |
|
if isinstance(m, nn.Conv2d): |
|
trunc_normal_(m.weight, std=0.02) |
|
if m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.conv(x) |
|
x = self.fc1(x) |
|
x = self.act(x) |
|
x = self.drop(x) |
|
x = self.fc2(x) |
|
x = self.drop(x) |
|
return x |
|
|
|
|
|
class RepConditionalPosEnc(nn.Module): |
|
"""Implementation of conditional positional encoding. |
|
|
|
For more details refer to paper: |
|
`Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_ |
|
|
|
In our implementation, we can reparameterize this module to eliminate a skip connection. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
dim_out: Optional[int] = None, |
|
spatial_shape: Union[int, Tuple[int, int]] = (7, 7), |
|
inference_mode=False, |
|
) -> None: |
|
"""Build reparameterizable conditional positional encoding |
|
|
|
Args: |
|
dim: Number of input channels. |
|
            dim_out: Number of output channels. Default: ``None`` (same as ``dim``).
|
spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) |
|
inference_mode: Flag to instantiate block in inference mode. Default: ``False`` |
|
""" |
|
super(RepConditionalPosEnc, self).__init__() |
|
if isinstance(spatial_shape, int): |
|
spatial_shape = tuple([spatial_shape] * 2) |
|
        assert isinstance(spatial_shape, tuple), (
            f'"spatial_shape" must be a sequence or int, '
            f"got {type(spatial_shape)} instead."
|
) |
|
assert len(spatial_shape) == 2, ( |
|
f'Length of "spatial_shape" should be 2, ' |
|
f"got {len(spatial_shape)} instead." |
|
) |
|
|
|
self.spatial_shape = spatial_shape |
|
self.dim = dim |
|
self.dim_out = dim_out or dim |
|
self.groups = dim |
|
|
|
if inference_mode: |
|
self.reparam_conv = nn.Conv2d( |
|
self.dim, |
|
self.dim_out, |
|
kernel_size=self.spatial_shape, |
|
stride=1, |
|
padding=spatial_shape[0] // 2, |
|
groups=self.groups, |
|
bias=True, |
|
) |
|
else: |
|
self.reparam_conv = None |
|
self.pos_enc = nn.Conv2d( |
|
self.dim, |
|
self.dim_out, |
|
spatial_shape, |
|
1, |
|
int(spatial_shape[0] // 2), |
|
groups=self.groups, |
|
bias=True, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
if self.reparam_conv is not None: |
|
x = self.reparam_conv(x) |
|
else: |
|
x = self.pos_enc(x) + x |
|
return x |
|
|
|
def reparameterize(self) -> None: |
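        """Fold the skip connection into the positional encoding conv for inference:
        (x + pos_enc(x)) == conv(x) with weight = identity_kernel + W_pos and bias = b_pos.
        """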
|
|
|
input_dim = self.dim // self.groups |
|
kernel_value = torch.zeros( |
|
( |
|
self.dim, |
|
input_dim, |
|
self.spatial_shape[0], |
|
self.spatial_shape[1], |
|
), |
|
dtype=self.pos_enc.weight.dtype, |
|
device=self.pos_enc.weight.device, |
|
) |
|
for i in range(self.dim): |
|
kernel_value[ |
|
i, |
|
i % input_dim, |
|
self.spatial_shape[0] // 2, |
|
self.spatial_shape[1] // 2, |
|
] = 1 |
|
id_tensor = kernel_value |
|
|
|
|
|
w_final = id_tensor + self.pos_enc.weight |
|
b_final = self.pos_enc.bias |
|
|
|
|
|
self.reparam_conv = nn.Conv2d( |
|
self.dim, |
|
self.dim_out, |
|
kernel_size=self.spatial_shape, |
|
stride=1, |
|
padding=int(self.spatial_shape[0] // 2), |
|
groups=self.groups, |
|
bias=True, |
|
) |
|
self.reparam_conv.weight.data = w_final |
|
self.reparam_conv.bias.data = b_final |
|
|
|
for name, para in self.named_parameters(): |
|
if 'reparam_conv' in name: |
|
continue |
|
para.detach_() |
|
self.__delattr__("pos_enc") |
|
|
|
|
|
class RepMixerBlock(nn.Module): |
|
"""Implementation of Metaformer block with RepMixer as token mixer. |
|
|
|
For more details on Metaformer structure, please refer to: |
|
`MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_ |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
kernel_size: int = 3, |
|
mlp_ratio: float = 4.0, |
|
act_layer: nn.Module = nn.GELU, |
|
proj_drop: float = 0.0, |
|
drop_path: float = 0.0, |
|
layer_scale_init_value: float = 1e-5, |
|
inference_mode: bool = False, |
|
): |
|
"""Build RepMixer Block. |
|
|
|
Args: |
|
dim: Number of embedding dimensions. |
|
kernel_size: Kernel size for repmixer. Default: 3 |
|
mlp_ratio: MLP expansion ratio. Default: 4.0 |
|
act_layer: Activation layer. Default: ``nn.GELU`` |
|
proj_drop: Dropout rate. Default: 0.0 |
|
drop_path: Drop path rate. Default: 0.0 |
|
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 |
|
inference_mode: Flag to instantiate block in inference mode. Default: ``False`` |
|
""" |
|
|
|
super().__init__() |
|
|
|
self.token_mixer = RepMixer( |
|
dim, |
|
kernel_size=kernel_size, |
|
layer_scale_init_value=layer_scale_init_value, |
|
inference_mode=inference_mode, |
|
) |
|
|
|
self.mlp = ConvMlp( |
|
in_chs=dim, |
|
hidden_channels=int(dim * mlp_ratio), |
|
act_layer=act_layer, |
|
drop=proj_drop, |
|
) |
|
if layer_scale_init_value is not None: |
|
self.layer_scale = LayerScale2d(dim, layer_scale_init_value) |
|
else: |
|
self.layer_scale = nn.Identity() |
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
|
|
def forward(self, x): |
|
x = self.token_mixer(x) |
|
x = x + self.drop_path(self.layer_scale(self.mlp(x))) |
|
return x |
|
|
|
|
|
class AttentionBlock(nn.Module): |
|
"""Implementation of metaformer block with MHSA as token mixer. |
|
|
|
For more details on Metaformer structure, please refer to: |
|
`MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_ |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
mlp_ratio: float = 4.0, |
|
act_layer: nn.Module = nn.GELU, |
|
norm_layer: nn.Module = nn.BatchNorm2d, |
|
proj_drop: float = 0.0, |
|
drop_path: float = 0.0, |
|
layer_scale_init_value: float = 1e-5, |
|
): |
|
"""Build Attention Block. |
|
|
|
Args: |
|
dim: Number of embedding dimensions. |
|
mlp_ratio: MLP expansion ratio. Default: 4.0 |
|
act_layer: Activation layer. Default: ``nn.GELU`` |
|
norm_layer: Normalization layer. Default: ``nn.BatchNorm2d`` |
|
proj_drop: Dropout rate. Default: 0.0 |
|
drop_path: Drop path rate. Default: 0.0 |
|
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 |
|
""" |
|
|
|
super().__init__() |
|
|
|
self.norm = norm_layer(dim) |
|
self.token_mixer = Attention(dim=dim) |
|
if layer_scale_init_value is not None: |
|
self.layer_scale_1 = LayerScale2d(dim, layer_scale_init_value) |
|
else: |
|
self.layer_scale_1 = nn.Identity() |
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
|
|
self.mlp = ConvMlp( |
|
in_chs=dim, |
|
hidden_channels=int(dim * mlp_ratio), |
|
act_layer=act_layer, |
|
drop=proj_drop, |
|
) |
|
if layer_scale_init_value is not None: |
|
self.layer_scale_2 = LayerScale2d(dim, layer_scale_init_value) |
|
else: |
|
self.layer_scale_2 = nn.Identity() |
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
|
|
def forward(self, x): |
|
x = x + self.drop_path1(self.layer_scale_1(self.token_mixer(self.norm(x)))) |
|
x = x + self.drop_path2(self.layer_scale_2(self.mlp(x))) |
|
return x |
|
|
|
|
|
class FastVitStage(nn.Module): |
|
def __init__( |
|
self, |
|
dim: int, |
|
dim_out: int, |
|
depth: int, |
|
token_mixer_type: str, |
|
downsample: bool = True, |
|
down_patch_size: int = 7, |
|
down_stride: int = 2, |
|
pos_emb_layer: Optional[nn.Module] = None, |
|
kernel_size: int = 3, |
|
mlp_ratio: float = 4.0, |
|
act_layer: nn.Module = nn.GELU, |
|
norm_layer: nn.Module = nn.BatchNorm2d, |
|
proj_drop_rate: float = 0.0, |
|
drop_path_rate: float = 0.0, |
|
layer_scale_init_value: Optional[float] = 1e-5, |
|
lkc_use_act=False, |
|
inference_mode=False, |
|
): |
|
"""FastViT stage. |
|
|
|
Args: |
|
            dim: Number of input embedding dimensions.
            dim_out: Number of output embedding dimensions.
            depth: Number of blocks in the stage.
            token_mixer_type: Token mixer type, one of ``"repmixer"`` or ``"attention"``.
            downsample: If True, downsample the input to the stage with a ``PatchEmbed`` layer.
            down_patch_size: Patch (kernel) size of the downsampling layer.
            down_stride: Stride of the downsampling layer.
            pos_emb_layer: Optional positional encoding layer applied after downsampling.
|
kernel_size: Kernel size for repmixer. |
|
mlp_ratio: MLP expansion ratio. |
|
act_layer: Activation layer. |
|
norm_layer: Normalization layer. |
|
proj_drop_rate: Dropout rate. |
|
            drop_path_rate: Per-block drop path (stochastic depth) rates for this stage.
            layer_scale_init_value: Layer scale value at initialization.
            lkc_use_act: If True, apply the activation inside the large kernel conv of the downsample block.
            inference_mode: Flag to instantiate block in inference mode.
|
""" |
|
super().__init__() |
|
self.grad_checkpointing = False |
|
|
|
if downsample: |
|
self.downsample = PatchEmbed( |
|
patch_size=down_patch_size, |
|
stride=down_stride, |
|
in_chs=dim, |
|
embed_dim=dim_out, |
|
act_layer=act_layer, |
|
lkc_use_act=lkc_use_act, |
|
inference_mode=inference_mode, |
|
) |
|
else: |
|
assert dim == dim_out |
|
self.downsample = nn.Identity() |
|
|
|
if pos_emb_layer is not None: |
|
self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode) |
|
else: |
|
self.pos_emb = nn.Identity() |
|
|
|
blocks = [] |
|
for block_idx in range(depth): |
|
if token_mixer_type == "repmixer": |
|
blocks.append(RepMixerBlock( |
|
dim_out, |
|
kernel_size=kernel_size, |
|
mlp_ratio=mlp_ratio, |
|
act_layer=act_layer, |
|
proj_drop=proj_drop_rate, |
|
drop_path=drop_path_rate[block_idx], |
|
layer_scale_init_value=layer_scale_init_value, |
|
inference_mode=inference_mode, |
|
)) |
|
elif token_mixer_type == "attention": |
|
blocks.append(AttentionBlock( |
|
dim_out, |
|
mlp_ratio=mlp_ratio, |
|
act_layer=act_layer, |
|
norm_layer=norm_layer, |
|
proj_drop=proj_drop_rate, |
|
drop_path=drop_path_rate[block_idx], |
|
layer_scale_init_value=layer_scale_init_value, |
|
)) |
|
else: |
|
raise ValueError( |
|
"Token mixer type: {} not supported".format(token_mixer_type) |
|
) |
|
self.blocks = nn.Sequential(*blocks) |
|
|
|
def forward(self, x): |
|
x = self.downsample(x) |
|
x = self.pos_emb(x) |
|
if self.grad_checkpointing and not torch.jit.is_scripting(): |
|
x = checkpoint_seq(self.blocks, x) |
|
else: |
|
x = self.blocks(x) |
|
return x |
|
|
|
|
|
class FastVit(nn.Module): |
|
    """
    This class implements the `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_.
    """

    fork_feat: torch.jit.Final[bool]
|
|
|
def __init__( |
|
self, |
|
in_chans: int = 3, |
|
layers: Tuple[int, ...] = (2, 2, 6, 2), |
|
token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"), |
|
embed_dims: Tuple[int, ...] = (64, 128, 256, 512), |
|
mlp_ratios: Tuple[float, ...] = (4,) * 4, |
|
downsamples: Tuple[bool, ...] = (False, True, True, True), |
|
repmixer_kernel_size: int = 3, |
|
num_classes: int = 1000, |
|
pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4, |
|
down_patch_size: int = 7, |
|
down_stride: int = 2, |
|
drop_rate: float = 0.0, |
|
proj_drop_rate: float = 0.0, |
|
drop_path_rate: float = 0.0, |
|
layer_scale_init_value: float = 1e-5, |
|
fork_feat: bool = False, |
|
cls_ratio: float = 2.0, |
|
global_pool: str = 'avg', |
|
norm_layer: nn.Module = nn.BatchNorm2d, |
|
act_layer: nn.Module = nn.GELU, |
|
lkc_use_act: bool = False, |
|
inference_mode: bool = False, |
|
) -> None: |
|
super().__init__() |
|
self.num_classes = 0 if fork_feat else num_classes |
|
self.fork_feat = fork_feat |
|
self.global_pool = global_pool |
|
self.feature_info = [] |
|
|
|
|
|
self.stem = convolutional_stem( |
|
in_chans, |
|
embed_dims[0], |
|
act_layer, |
|
inference_mode, |
|
) |
|
|
|
|
|
prev_dim = embed_dims[0] |
|
scale = 1 |
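        # Stochastic depth schedule: linearly increasing drop-path rates, split into one list per stage.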
|
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)] |
|
stages = [] |
|
for i in range(len(layers)): |
|
downsample = downsamples[i] or prev_dim != embed_dims[i] |
|
stage = FastVitStage( |
|
dim=prev_dim, |
|
dim_out=embed_dims[i], |
|
depth=layers[i], |
|
downsample=downsample, |
|
down_patch_size=down_patch_size, |
|
down_stride=down_stride, |
|
pos_emb_layer=pos_embs[i], |
|
token_mixer_type=token_mixers[i], |
|
kernel_size=repmixer_kernel_size, |
|
mlp_ratio=mlp_ratios[i], |
|
act_layer=act_layer, |
|
norm_layer=norm_layer, |
|
proj_drop_rate=proj_drop_rate, |
|
drop_path_rate=dpr[i], |
|
layer_scale_init_value=layer_scale_init_value, |
|
lkc_use_act=lkc_use_act, |
|
inference_mode=inference_mode, |
|
) |
|
stages.append(stage) |
|
prev_dim = embed_dims[i] |
|
if downsample: |
|
scale *= 2 |
|
self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')] |
|
self.stages = nn.Sequential(*stages) |
|
self.num_features = prev_dim |
|
|
|
|
|
if self.fork_feat: |
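            # Dense feature extraction mode: add a norm layer for each output feature map.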
|
|
|
|
|
|
|
self.out_indices = [0, 1, 2, 3] |
|
for i_emb, i_layer in enumerate(self.out_indices): |
|
if i_emb == 0 and os.environ.get("FORK_LAST3", None): |
|
"""For RetinaNet, `start_level=1`. The first norm layer will not used. |
|
cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...` |
|
""" |
|
layer = nn.Identity() |
|
else: |
|
layer = norm_layer(embed_dims[i_emb]) |
|
layer_name = f"norm{i_layer}" |
|
self.add_module(layer_name, layer) |
|
else: |
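            # Classification mode: a final conv expands channels by cls_ratio before the classifier head.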
|
|
|
self.num_features = final_features = int(embed_dims[-1] * cls_ratio) |
|
self.final_conv = MobileOneBlock( |
|
in_chs=embed_dims[-1], |
|
out_chs=final_features, |
|
kernel_size=3, |
|
stride=1, |
|
group_size=1, |
|
inference_mode=inference_mode, |
|
use_se=True, |
|
act_layer=act_layer, |
|
num_conv_branches=1, |
|
) |
|
self.head = ClassifierHead( |
|
final_features, |
|
num_classes, |
|
pool_type=global_pool, |
|
drop_rate=drop_rate, |
|
) |
|
|
|
self.apply(self._init_weights) |
|
|
|
def _init_weights(self, m: nn.Module) -> None: |
|
"""Init. for classification""" |
|
if isinstance(m, nn.Linear): |
|
trunc_normal_(m.weight, std=0.02) |
|
if isinstance(m, nn.Linear) and m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
|
|
@torch.jit.ignore |
|
def no_weight_decay(self): |
|
return set() |
|
|
|
@torch.jit.ignore |
|
def group_matcher(self, coarse=False): |
|
return dict( |
|
stem=r'^stem', |
|
blocks=r'^stages\.(\d+)' if coarse else [ |
|
(r'^stages\.(\d+).downsample', (0,)), |
|
(r'^stages\.(\d+).pos_emb', (0,)), |
|
(r'^stages\.(\d+)\.\w+\.(\d+)', None), |
|
] |
|
) |
|
|
|
@torch.jit.ignore |
|
def set_grad_checkpointing(self, enable=True): |
|
for s in self.stages: |
|
s.grad_checkpointing = enable |
|
|
|
@torch.jit.ignore |
|
def get_classifier(self): |
|
return self.head.fc |
|
|
|
def reset_classifier(self, num_classes, global_pool=None): |
|
self.num_classes = num_classes |
|
self.head.reset(num_classes, global_pool) |
|
|
|
def forward_features(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
x = self.stem(x) |
|
outs = [] |
|
for idx, block in enumerate(self.stages): |
|
x = block(x) |
|
if self.fork_feat: |
|
if idx in self.out_indices: |
|
norm_layer = getattr(self, f"norm{idx}") |
|
x_out = norm_layer(x) |
|
outs.append(x_out) |
|
if self.fork_feat: |
|
|
|
return outs |
|
x = self.final_conv(x) |
|
return x |
|
|
|
def forward_head(self, x: torch.Tensor, pre_logits: bool = False): |
|
return self.head(x, pre_logits=True) if pre_logits else self.head(x) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.forward_features(x) |
|
if self.fork_feat: |
|
return x |
|
x = self.forward_head(x) |
|
return x |
|
|
|
|
|
def _cfg(url="", **kwargs): |
|
return { |
|
"url": url, |
|
"num_classes": 1000, |
|
"input_size": (3, 256, 256), |
|
"pool_size": (8, 8), |
|
"crop_pct": 0.9, |
|
"interpolation": "bicubic", |
|
"mean": IMAGENET_DEFAULT_MEAN, |
|
"std": IMAGENET_DEFAULT_STD, |
|
'first_conv': ('stem.0.conv_kxk.0.conv', 'stem.0.conv_scale.conv'), |
|
"classifier": "head.fc", |
|
**kwargs, |
|
} |
|
|
|
|
|
default_cfgs = generate_default_cfgs({ |
|
"fastvit_t8.apple_in1k": _cfg( |
|
hf_hub_id='timm/'), |
|
"fastvit_t12.apple_in1k": _cfg( |
|
hf_hub_id='timm/'), |
|
|
|
"fastvit_s12.apple_in1k": _cfg( |
|
hf_hub_id='timm/'), |
|
"fastvit_sa12.apple_in1k": _cfg( |
|
hf_hub_id='timm/'), |
|
"fastvit_sa24.apple_in1k": _cfg( |
|
hf_hub_id='timm/'), |
|
"fastvit_sa36.apple_in1k": _cfg( |
|
hf_hub_id='timm/'), |
|
|
|
"fastvit_ma36.apple_in1k": _cfg( |
|
hf_hub_id='timm/', |
|
crop_pct=0.95 |
|
), |
|
|
|
"fastvit_t8.apple_dist_in1k": _cfg( |
|
hf_hub_id='timm/'), |
|
"fastvit_t12.apple_dist_in1k": _cfg( |
|
hf_hub_id='timm/'), |
|
|
|
"fastvit_s12.apple_dist_in1k": _cfg( |
|
hf_hub_id='timm/',), |
|
"fastvit_sa12.apple_dist_in1k": _cfg( |
|
hf_hub_id='timm/',), |
|
"fastvit_sa24.apple_dist_in1k": _cfg( |
|
hf_hub_id='timm/',), |
|
"fastvit_sa36.apple_dist_in1k": _cfg( |
|
hf_hub_id='timm/',), |
|
|
|
"fastvit_ma36.apple_dist_in1k": _cfg( |
|
hf_hub_id='timm/', |
|
crop_pct=0.95 |
|
), |
|
}) |
|
|
|
|
|
def _create_fastvit(variant, pretrained=False, **kwargs): |
|
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) |
|
model = build_model_with_cfg( |
|
FastVit, |
|
variant, |
|
pretrained, |
|
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), |
|
**kwargs |
|
) |
|
return model |
|
|
|
|
|
@register_model |
|
def fastvit_t8(pretrained=False, **kwargs): |
|
"""Instantiate FastViT-T8 model variant.""" |
|
model_args = dict( |
|
layers=(2, 2, 4, 2), |
|
embed_dims=(48, 96, 192, 384), |
|
mlp_ratios=(3, 3, 3, 3), |
|
token_mixers=("repmixer", "repmixer", "repmixer", "repmixer") |
|
) |
|
return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, **kwargs)) |
|
|
|
|
|
@register_model |
|
def fastvit_t12(pretrained=False, **kwargs): |
|
"""Instantiate FastViT-T12 model variant.""" |
|
model_args = dict( |
|
layers=(2, 2, 6, 2), |
|
embed_dims=(64, 128, 256, 512), |
|
mlp_ratios=(3, 3, 3, 3), |
|
token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"), |
|
) |
|
return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs)) |
|
|
|
|
|
@register_model |
|
def fastvit_s12(pretrained=False, **kwargs): |
|
"""Instantiate FastViT-S12 model variant.""" |
|
model_args = dict( |
|
layers=(2, 2, 6, 2), |
|
embed_dims=(64, 128, 256, 512), |
|
mlp_ratios=(4, 4, 4, 4), |
|
token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"), |
|
) |
|
return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs)) |
|
|
|
|
|
@register_model |
|
def fastvit_sa12(pretrained=False, **kwargs): |
|
"""Instantiate FastViT-SA12 model variant.""" |
|
model_args = dict( |
|
layers=(2, 2, 6, 2), |
|
embed_dims=(64, 128, 256, 512), |
|
mlp_ratios=(4, 4, 4, 4), |
|
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), |
|
token_mixers=("repmixer", "repmixer", "repmixer", "attention"), |
|
) |
|
return _create_fastvit('fastvit_sa12', pretrained=pretrained, **dict(model_args, **kwargs)) |
|
|
|
|
|
@register_model |
|
def fastvit_sa24(pretrained=False, **kwargs): |
|
"""Instantiate FastViT-SA24 model variant.""" |
|
model_args = dict( |
|
layers=(4, 4, 12, 4), |
|
embed_dims=(64, 128, 256, 512), |
|
mlp_ratios=(4, 4, 4, 4), |
|
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), |
|
token_mixers=("repmixer", "repmixer", "repmixer", "attention"), |
|
) |
|
return _create_fastvit('fastvit_sa24', pretrained=pretrained, **dict(model_args, **kwargs)) |
|
|
|
|
|
@register_model |
|
def fastvit_sa36(pretrained=False, **kwargs): |
|
"""Instantiate FastViT-SA36 model variant.""" |
|
model_args = dict( |
|
layers=(6, 6, 18, 6), |
|
embed_dims=(64, 128, 256, 512), |
|
mlp_ratios=(4, 4, 4, 4), |
|
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), |
|
token_mixers=("repmixer", "repmixer", "repmixer", "attention"), |
|
) |
|
return _create_fastvit('fastvit_sa36', pretrained=pretrained, **dict(model_args, **kwargs)) |
|
|
|
|
|
@register_model |
|
def fastvit_ma36(pretrained=False, **kwargs): |
|
"""Instantiate FastViT-MA36 model variant.""" |
|
model_args = dict( |
|
layers=(6, 6, 18, 6), |
|
embed_dims=(76, 152, 304, 608), |
|
mlp_ratios=(4, 4, 4, 4), |
|
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), |
|
token_mixers=("repmixer", "repmixer", "repmixer", "attention") |
|
) |
|
return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs)) |
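
# Usage sketch (comment only, not executed on import): build a FastViT variant via the timm
# registry and fuse all re-parameterizable branches for inference. The loop below is an
# illustrative pattern for calling the `reparameterize()` methods defined above, not a timm API.
#
#   import timm, torch
#   model = timm.create_model('fastvit_t8', pretrained=False).eval()
#   for m in list(model.modules()):
#       if hasattr(m, 'reparameterize'):
#           m.reparameterize()
#   with torch.no_grad():
#       logits = model(torch.randn(1, 3, 256, 256))  # default cfg input size -> (1, 1000)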
|
|