pengdadaaa's picture
Upload 741 files
786f6a6 verified
raw
history blame
4.6 kB
""" ConvMixer
"""
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d
from ._registry import register_model, generate_default_cfgs
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
__all__ = ['ConvMixer']
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
class ConvMixer(nn.Module):
def __init__(
self,
dim,
depth,
kernel_size=9,
patch_size=7,
in_chans=3,
num_classes=1000,
global_pool='avg',
drop_rate=0.,
act_layer=nn.GELU,
**kwargs,
):
super().__init__()
self.num_classes = num_classes
self.num_features = dim
self.grad_checkpointing = False
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size),
act_layer(),
nn.BatchNorm2d(dim)
)
self.blocks = nn.Sequential(
*[nn.Sequential(
Residual(nn.Sequential(
nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
act_layer(),
nn.BatchNorm2d(dim)
)),
nn.Conv2d(dim, dim, kernel_size=1),
act_layer(),
nn.BatchNorm2d(dim)
) for i in range(depth)]
)
self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity()
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(stem=r'^stem', blocks=r'^blocks\.(\d+)')
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
def forward_head(self, x, pre_logits: bool = False):
x = self.pooling(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _create_convmixer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for ConvMixer models.')
return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .96, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
'first_conv': 'stem.0',
**kwargs
}
default_cfgs = generate_default_cfgs({
'convmixer_1536_20.in1k': _cfg(hf_hub_id='timm/'),
'convmixer_768_32.in1k': _cfg(hf_hub_id='timm/'),
'convmixer_1024_20_ks9_p14.in1k': _cfg(hf_hub_id='timm/')
})
@register_model
def convmixer_1536_20(pretrained=False, **kwargs) -> ConvMixer:
model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs)
return _create_convmixer('convmixer_1536_20', pretrained, **model_args)
@register_model
def convmixer_768_32(pretrained=False, **kwargs) -> ConvMixer:
model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, act_layer=nn.ReLU, **kwargs)
return _create_convmixer('convmixer_768_32', pretrained, **model_args)
@register_model
def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs) -> ConvMixer:
model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs)
return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **model_args)