Spaces:
Build error
Build error
''' | |
Implementation of ViTSTR based on timm VisionTransformer. | |
TODO: | |
1) distilled deit backbone | |
2) base deit backbone | |
Copyright 2021 Rowel Atienza | |
''' | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import torch | |
import torch.nn as nn | |
import logging | |
import torch.utils.model_zoo as model_zoo | |
from copy import deepcopy | |
from functools import partial | |
from timm.models.vision_transformer import VisionTransformer, _cfg | |
from timm.models.registry import register_model | |
from timm.models import create_model | |
_logger = logging.getLogger(__name__) | |
# Names exported by ``from <this module> import *``: the three non-distilled
# ViTSTR factory functions.
__all__ = [
    'vitstr_tiny_patch16_224',
    'vitstr_small_patch16_224',
    'vitstr_base_patch16_224',
    # The distilled variants are not exported; the file marks them as work in
    # progress further down.
    #'vitstr_tiny_distilled_patch16_224',
    #'vitstr_small_distilled_patch16_224',
    #'vitstr_base_distilled_patch16_224',
]
def create_vitstr(num_tokens, model=None, checkpoint_path=''):
    """Build a ViTSTR network through timm's ``create_model`` factory.

    Args:
        num_tokens: Size of the output vocabulary. It is passed as
            ``num_classes`` to ``create_model`` and again to
            ``reset_classifier``.
        model: Value handed to ``create_model`` as its first positional
            argument.
        checkpoint_path: Forwarded unchanged to ``create_model``.

    Returns:
        The object returned by ``create_model`` after its classifier has
        been reset.
    """
    # NOTE(review): ``create_model`` is given ``model`` as an identifier, but
    # no factory visible in this file is decorated with ``register_model``.
    # Confirm the ViTSTR names are registered with timm before relying on
    # this lookup.
    factory_options = dict(
        pretrained=True,
        num_classes=num_tokens,
        checkpoint_path=checkpoint_path,
    )
    network = create_model(model, **factory_options)

    # Per the original author: this may be needed to get a zero-initialised
    # head for transfer learning.
    network.reset_classifier(num_classes=num_tokens)
    return network
class ViTSTR(VisionTransformer):
    """Vision transformer adapted for scene text recognition (STR).

    Per the original author, ViTSTR is basically a ViT that uses DeiT
    weights. The head is applied to a run of output tokens rather than to a
    single one, so the network predicts a sequence of characters.
    """

    def __init__(self, *args, **kwargs):
        """Pass every argument through to ``VisionTransformer`` unchanged."""
        super().__init__(*args, **kwargs)

    def reset_classifier(self, num_classes):
        """Install a fresh prediction head sized for ``num_classes``.

        A positive ``num_classes`` creates a linear layer from ``embed_dim``
        to ``num_classes``; any other value installs ``nn.Identity``.
        """
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = nn.Linear(self.embed_dim, num_classes)
        else:
            self.head = nn.Identity()

    def patch_embed_func(self):
        """Return the ``patch_embed`` attribute of this model."""
        return self.patch_embed

    def forward_features(self, x):
        """Encode an image batch into a normalised token sequence.

        The class token is placed in front of the patch tokens, the
        positional embedding is added, and the result goes through
        ``pos_drop``, every entry of ``blocks`` and finally ``norm``.

        Shapes noted by the original author for one example run (batch of 1,
        embedding size 768, 16x16 patches on a 14x14 grid): 196 patch tokens
        after ``patch_embed`` and 197 tokens once the class token is added.
        """
        batch_size = x.shape[0]
        patches = self.patch_embed(x)

        # One copy of the learned class token per sample
        # (cls_tokens implementation credited to Phil Wang in the original).
        class_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((class_tokens, patches), dim=1)

        tokens = self.pos_drop(tokens + self.pos_embed)
        for block in self.blocks:
            tokens = block(tokens)
        return self.norm(tokens)

    def forward(self, x, seqlen=25):
        """Predict per-position class scores for the first ``seqlen`` tokens.

        Args:
            x: Input batch handed to ``forward_features``.
            seqlen: Number of leading output tokens (the class token is at
                index 0) that are sent through the head.

        Returns:
            Tensor viewed as ``(batch, kept_tokens, num_classes)``.
        """
        kept = self.forward_features(x)[:, :seqlen]
        batch, steps, width = kept.size()

        # Flatten batch and sequence axes so the head sees one row per token,
        # then restore the (batch, sequence) layout.
        flat_tokens = kept.reshape(batch * steps, width)
        scores = self.head(flat_tokens)
        return scores.view(batch, steps, self.num_classes)
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=1, filter_fn=None, strict=True):
    """Download a pretrained checkpoint and load it into ``model``.

    Adapted from an older version of timm.

    Args:
        model: Object that receives the weights via ``load_state_dict``.
        cfg: Pretrained-config mapping. The keys read here are ``url``,
            ``classifier``, ``num_classes`` and, when ``in_chans == 1``,
            ``first_conv``. Defaults to ``model.default_cfg``.
        num_classes: Classifier size of ``model``; compared with
            ``cfg['num_classes']`` to decide whether the checkpoint's
            classifier tensors are kept, trimmed or dropped.
        in_chans: Input channel count of ``model``. Only the value ``1``
            triggers a conversion of the first convolution's weights.
        filter_fn: Optional callable applied to the state dict right after
            it is downloaded.
        strict: Forwarded to ``load_state_dict``; forced to ``False`` when
            the classifier tensors are discarded.

    Returns:
        None. The function returns early, leaving ``model`` untouched, when
        no usable URL is configured or when the first-conv weight is missing
        from the checkpoint.
    """
    if cfg is None:
        # Default of None lets a model without ``default_cfg`` reach the
        # warning below instead of raising AttributeError.
        cfg = getattr(model, 'default_cfg', None)
    if cfg is None or 'url' not in cfg or not cfg['url']:
        _logger.warning("Pretrained model URL is invalid, using random initialization.")
        return

    # NOTE(review): the checkpoint is deserialised by torch's loader; only
    # point ``cfg['url']`` at trusted sources.
    state_dict = model_zoo.load_url(cfg['url'], progress=True, map_location='cpu')
    if "model" in state_dict:
        # Some checkpoints nest the weights under a "model" entry.
        state_dict = state_dict["model"]

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans == 1:
        conv1_name = cfg['first_conv']
        _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel', conv1_name)
        key = conv1_name + '.weight'
        if key not in state_dict:
            # Nothing gets loaded on this path, so make it visible as a
            # warning rather than an info record.
            _logger.warning('(%s) key NOT found in state_dict', key)
            return
        _logger.info('(%s) key found in state_dict', key)
        conv1_weight = state_dict[key]

        # Some weights are in torch.half, ensure it's float for sum on CPU.
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        out_ch, in_ch, kernel_h, kernel_w = conv1_weight.shape
        if in_ch > 3:
            # For models with space2depth stems: sum each group of three
            # input channels.
            assert in_ch % 3 == 0
            conv1_weight = conv1_weight.reshape(out_ch, in_ch // 3, 3, kernel_h, kernel_w)
            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
        else:
            # Sum over the input-channel axis to get a single-channel kernel.
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
        state_dict[key] = conv1_weight.to(conv1_type)

    classifier_name = cfg['classifier']
    weight_key = classifier_name + '.weight'
    bias_key = classifier_name + '.bias'
    if num_classes == 1000 and cfg['num_classes'] == 1001:
        # Special case for imagenet trained models with an extra background
        # class in the pretrained weights: drop the first row/entry.
        state_dict[weight_key] = state_dict[weight_key][1:]
        state_dict[bias_key] = state_dict[bias_key][1:]
    elif num_classes != cfg['num_classes']:
        # Completely discard the fully connected layer for all other
        # differences between pretrained and created model. ``pop`` tolerates
        # checkpoints that do not carry classifier tensors at all.
        state_dict.pop(weight_key, None)
        state_dict.pop(bias_key, None)
        strict = False

    print("Loading pre-trained vision transformer weights from %s ..." % cfg['url'])
    model.load_state_dict(state_dict, strict=strict)
def _conv_filter(state_dict, patch_size=16):
    """Convert patch-embedding weights from manual patchify + linear
    projection layout to conv layout.

    Every value whose key contains ``'patch_embed.proj.weight'`` is reshaped
    to ``(rows, 3, patch_size, patch_size)``, where ``rows`` is the size of
    its first axis. All other entries are carried over as they are. A new
    dict is returned; the input mapping is not modified.
    """
    def to_conv_layout(name, tensor):
        if 'patch_embed.proj.weight' not in name:
            return tensor
        return tensor.reshape((tensor.shape[0], 3, patch_size, patch_size))

    return {name: to_conv_layout(name, tensor) for name, tensor in state_dict.items()}
def vitstr_tiny_patch16_224(pretrained=False, **kwargs):
    """Construct the tiny ViTSTR variant (patch 16, width 192, 12 blocks, 3 heads).

    Args:
        pretrained: When truthy, weights from the DeiT-tiny checkpoint URL
            stored in ``default_cfg`` are loaded through ``load_pretrained``.
        **kwargs: Extra keyword arguments for ``ViTSTR``. ``in_chans`` is
            always overwritten with 1.

    Returns:
        The constructed ``ViTSTR`` model.
    """
    # The input channel count is pinned to 1 regardless of what was passed.
    kwargs['in_chans'] = 1
    architecture = dict(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True)
    model = ViTSTR(**architecture, **kwargs)

    # Alternative URL kept from the original source:
    # https://github.com/roatienza/public/releases/download/v0.1-deit-tiny/deit_tiny_patch16_224-a1311bcf.pth
    model.default_cfg = _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'
    )
    if not pretrained:
        return model

    load_pretrained(
        model,
        num_classes=model.num_classes,
        in_chans=kwargs.get('in_chans', 1),
        filter_fn=_conv_filter,
    )
    return model
def vitstr_small_patch16_224(pretrained=False, **kwargs):
    """Construct the small ViTSTR variant (patch 16, width 384, 12 blocks, 6 heads).

    Args:
        pretrained: When truthy, weights from the DeiT-small checkpoint URL
            stored in ``default_cfg`` are loaded through ``load_pretrained``.
        **kwargs: Extra keyword arguments for ``ViTSTR``. ``in_chans`` is
            always overwritten with 1.

    Returns:
        The constructed ``ViTSTR`` model.
    """
    # The input channel count is pinned to 1 regardless of what was passed.
    kwargs['in_chans'] = 1
    architecture = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True)
    model = ViTSTR(**architecture, **kwargs)

    # Alternative URL kept from the original source:
    # https://github.com/roatienza/public/releases/download/v0.1-deit-small/deit_small_patch16_224-cd65a155.pth
    model.default_cfg = _cfg(
        url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth"
    )
    if not pretrained:
        return model

    load_pretrained(
        model,
        num_classes=model.num_classes,
        in_chans=kwargs.get('in_chans', 1),
        filter_fn=_conv_filter,
    )
    return model
def vitstr_base_patch16_224(pretrained=False, **kwargs):
    """Construct the base ViTSTR variant (patch 16, width 768, 12 blocks, 12 heads).

    Args:
        pretrained: When truthy, weights from the DeiT-base checkpoint URL
            stored in ``default_cfg`` are loaded through ``load_pretrained``.
        **kwargs: Extra keyword arguments for ``ViTSTR``. ``in_chans`` is
            always overwritten with 1.

    Returns:
        The constructed ``ViTSTR`` model.
    """
    # The input channel count is pinned to 1 regardless of what was passed.
    kwargs['in_chans'] = 1
    architecture = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True)
    model = ViTSTR(**architecture, **kwargs)

    # Alternative URL kept from the original source:
    # https://github.com/roatienza/public/releases/download/v0.1-deit-base/deit_base_patch16_224-b5f2ef4d.pth
    model.default_cfg = _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'
    )
    if not pretrained:
        return model

    load_pretrained(
        model,
        num_classes=model.num_classes,
        in_chans=kwargs.get('in_chans', 1),
        filter_fn=_conv_filter,
    )
    return model
# below is work in progress | |
def vitstr_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    """Construct the tiny ViTSTR variant configured with the distilled DeiT-tiny URL.

    Work in progress according to the surrounding file.

    Args:
        pretrained: When truthy, weights from the distilled DeiT-tiny
            checkpoint URL stored in ``default_cfg`` are loaded through
            ``load_pretrained``.
        **kwargs: Extra keyword arguments for ``ViTSTR``. ``in_chans`` is
            always overwritten with 1.

    Returns:
        The constructed ``ViTSTR`` model.
    """
    # The input channel count is pinned to 1 regardless of what was passed.
    kwargs['in_chans'] = 1
    # The distilled flag is left disabled here, as in the original:
    #kwargs['distilled'] = True
    architecture = dict(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True)
    model = ViTSTR(**architecture, **kwargs)

    model.default_cfg = _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'
    )
    if not pretrained:
        return model

    load_pretrained(
        model,
        num_classes=model.num_classes,
        in_chans=kwargs.get('in_chans', 1),
        filter_fn=_conv_filter,
    )
    return model
def vitstr_small_distilled_patch16_224(pretrained=False, **kwargs):
    """Construct the small ViTSTR variant configured with the distilled DeiT-small URL.

    Work in progress according to the surrounding file.

    Args:
        pretrained: When truthy, weights from the distilled DeiT-small
            checkpoint URL stored in ``default_cfg`` are loaded through
            ``load_pretrained``.
        **kwargs: Extra keyword arguments for ``ViTSTR``. ``in_chans`` is
            always overwritten with 1 and ``distilled`` with ``True``.

    Returns:
        The constructed ``ViTSTR`` model.
    """
    # Both values are forced, overriding anything supplied by the caller.
    # NOTE(review): ``distilled`` is forwarded to the ``ViTSTR`` constructor;
    # confirm the installed timm ``VisionTransformer`` accepts that keyword.
    kwargs['in_chans'] = 1
    kwargs['distilled'] = True
    architecture = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True)
    model = ViTSTR(**architecture, **kwargs)

    model.default_cfg = _cfg(
        url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth"
    )
    if not pretrained:
        return model

    load_pretrained(
        model,
        num_classes=model.num_classes,
        in_chans=kwargs.get('in_chans', 1),
        filter_fn=_conv_filter,
    )
    return model