strexp / modules /vitstr.py
markytools's picture
added strexp
d61b9c7
raw
history blame
9.01 kB
'''
Implementation of ViTSTR based on timm VisionTransformer.
TODO:
1) distilled deit backbone
2) base deit backbone
Copyright 2021 Rowel Atienza
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import logging
import torch.utils.model_zoo as model_zoo
from copy import deepcopy
from functools import partial
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models import create_model
_logger = logging.getLogger(__name__)
# Public API of this module. Only the non-distilled factories are exported;
# the distilled variants defined further down are marked as work in progress,
# so their names stay commented out here.
__all__ = [
'vitstr_tiny_patch16_224',
'vitstr_small_patch16_224',
'vitstr_base_patch16_224',
#'vitstr_tiny_distilled_patch16_224',
#'vitstr_small_distilled_patch16_224',
#'vitstr_base_distilled_patch16_224',
]
def create_vitstr(num_tokens, model=None, checkpoint_path=''):
    '''
    Build a ViTSTR network through timm's ``create_model`` registry.

    Args:
        num_tokens: size of the output vocabulary; forwarded as ``num_classes``.
        model: registered model name handed to ``create_model``.
        checkpoint_path: forwarded unchanged to ``create_model``.

    Returns:
        The created model, after its classifier head has been reset for
        ``num_tokens`` outputs.
    '''
    factory_options = dict(
        pretrained=True,
        num_classes=num_tokens,
        checkpoint_path=checkpoint_path,
    )
    network = create_model(model, **factory_options)

    # Original author's note: the reset might be needed to get a fresh head
    # for transfer learning.
    network.reset_classifier(num_classes=num_tokens)
    return network
class ViTSTR(VisionTransformer):
    '''
    Vision Transformer adapted for scene text recognition (STR).

    Subclasses timm's VisionTransformer (the factory functions in this file
    initialise it from DeiT checkpoints) and applies the classification head
    to several token positions so a sequence of characters can be predicted.
    '''

    def __init__(self, *args, **kwargs):
        # Construction is delegated entirely to VisionTransformer.
        super().__init__(*args, **kwargs)

    def reset_classifier(self, num_classes):
        '''
        Replace the head with a new projection to ``num_classes`` outputs.

        A positive ``num_classes`` installs ``nn.Linear(embed_dim, num_classes)``;
        otherwise the head becomes ``nn.Identity``.
        '''
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = nn.Linear(self.embed_dim, num_classes)
        else:
            self.head = nn.Identity()

    def patch_embed_func(self):
        '''Return the patch embedding module.'''
        return self.patch_embed

    def forward_features(self, x):
        '''
        Encode an image batch into a normalised token sequence.

        The class token is prepended to the patch tokens, the positional
        embedding is added, and the result goes through dropout, every
        transformer block and the final norm. All token positions are returned.

        Shape example recorded by the original author for the base model:
        patches (1, 196, 768) -> with class token (1, 197, 768); a 16x16 patch
        size gives a 14x14 grid = 196 patches.
        '''
        batch_size = x.shape[0]
        patch_tokens = self.patch_embed(x)

        # Broadcast the single learned class token across the batch
        # (implementation credited to Phil Wang in the original source).
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)

        tokens = self.pos_drop(tokens + self.pos_embed)
        for block in self.blocks:
            tokens = block(tokens)
        return self.norm(tokens)

    def forward(self, x, seqlen=25):
        '''
        Predict per-position class scores for the first ``seqlen`` tokens.

        Returns a tensor viewed as (batch, kept positions, ``self.num_classes``).
        '''
        features = self.forward_features(x)[:, :seqlen]
        batch, steps, width = features.size()

        # Flatten batch and sequence so the head is applied once to all
        # positions, then restore the (batch, sequence) layout.
        flat = features.reshape(batch * steps, width)
        return self.head(flat).view(batch, steps, self.num_classes)
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=1, filter_fn=None, strict=True):
    '''
    Download a pretrained checkpoint and load it into ``model``.

    Adapted from an older version of timm.

    Args:
        model: module that receives the weights via ``load_state_dict``.
        cfg: dict with at least ``url``, ``first_conv``, ``classifier`` and
            ``num_classes``. Defaults to ``model.default_cfg``.
        num_classes: number of outputs of ``model``; compared against
            ``cfg['num_classes']`` to decide how to treat the classifier weights.
        in_chans: number of input channels of ``model``. Only the value 1
            triggers a conversion of the first conv weights.
        filter_fn: optional callable applied to the state dict before any
            other adaptation.
        strict: forwarded to ``load_state_dict``; forced to False when the
            classifier weights are discarded.

    Returns:
        None. Nothing is loaded when no URL is configured, or when
        ``in_chans == 1`` and the first conv weight is missing from the
        checkpoint; both cases are logged as warnings.
    '''
    if cfg is None:
        # Default to None so a model without ``default_cfg`` reaches the
        # warning below instead of raising AttributeError.
        cfg = getattr(model, 'default_cfg', None)
    if cfg is None or 'url' not in cfg or not cfg['url']:
        _logger.warning("Pretrained model URL is invalid, using random initialization.")
        return

    state_dict = model_zoo.load_url(cfg['url'], progress=True, map_location='cpu')

    # Some checkpoints wrap the weights in a top-level "model" entry.
    if "model" in state_dict:
        state_dict = state_dict["model"]

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans == 1:
        conv1_name = cfg['first_conv']
        _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel', conv1_name)
        key = conv1_name + '.weight'
        if key not in state_dict:
            # Abort without loading anything, as the original did, but make
            # the skipped load visible at warning level.
            _logger.warning('(%s) key NOT found in state_dict', key)
            return
        _logger.info('(%s) key found in state_dict', key)
        conv1_weight = state_dict[key]

        # Some weights are in torch.half, ensure it's float for sum on CPU
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I > 3:
            assert conv1_weight.shape[1] % 3 == 0
            # For models with space2depth stems: sum each group of three
            # input channels.
            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
        else:
            # Collapse the input channels into one by summation.
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
        state_dict[key] = conv1_weight.to(conv1_type)

    classifier_name = cfg['classifier']
    weight_key = classifier_name + '.weight'
    bias_key = classifier_name + '.bias'
    if num_classes == 1000 and cfg['num_classes'] == 1001:
        # special case for imagenet trained models with extra background class in pretrained weights
        state_dict[weight_key] = state_dict[weight_key][1:]
        state_dict[bias_key] = state_dict[bias_key][1:]
    elif num_classes != cfg['num_classes']:
        # completely discard fully connected for all other differences between pretrained and created model
        # pop() tolerates checkpoints that carry no classifier weights at all.
        state_dict.pop(weight_key, None)
        state_dict.pop(bias_key, None)
        strict = False

    print("Loading pre-trained vision transformer weights from %s ..." % cfg['url'])
    model.load_state_dict(state_dict, strict=strict)
def _conv_filter(state_dict, patch_size=16):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
out_dict[k] = v
return out_dict
@register_model
def vitstr_tiny_patch16_224(pretrained=False, **kwargs):
    '''
    Build the tiny ViTSTR variant: patch size 16, embedding width 192,
    12 blocks, 3 attention heads.

    The input is always forced to a single channel. When ``pretrained`` is
    true, the DeiT-tiny checkpoint named in ``default_cfg`` is loaded.
    '''
    kwargs.update(in_chans=1)
    architecture = dict(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True)
    model = ViTSTR(**architecture, **kwargs)

    # Alternate URL left commented out by the original author:
    # https://github.com/roatienza/public/releases/download/v0.1-deit-tiny/deit_tiny_patch16_224-a1311bcf.pth
    model.default_cfg = _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'
    )

    if not pretrained:
        return model
    load_pretrained(
        model,
        num_classes=model.num_classes,
        in_chans=kwargs.get('in_chans', 1),
        filter_fn=_conv_filter,
    )
    return model
@register_model
def vitstr_small_patch16_224(pretrained=False, **kwargs):
    '''
    Build the small ViTSTR variant: patch size 16, embedding width 384,
    12 blocks, 6 attention heads.

    The input is always forced to a single channel. When ``pretrained`` is
    true, the DeiT-small checkpoint named in ``default_cfg`` is loaded.
    '''
    kwargs.update(in_chans=1)
    architecture = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True)
    model = ViTSTR(**architecture, **kwargs)

    # Alternate URL left commented out by the original author:
    # https://github.com/roatienza/public/releases/download/v0.1-deit-small/deit_small_patch16_224-cd65a155.pth
    model.default_cfg = _cfg(
        url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth"
    )

    if not pretrained:
        return model
    load_pretrained(
        model,
        num_classes=model.num_classes,
        in_chans=kwargs.get('in_chans', 1),
        filter_fn=_conv_filter,
    )
    return model
@register_model
def vitstr_base_patch16_224(pretrained=False, **kwargs):
    '''
    Build the base ViTSTR variant: patch size 16, embedding width 768,
    12 blocks, 12 attention heads.

    The input is always forced to a single channel. When ``pretrained`` is
    true, the DeiT-base checkpoint named in ``default_cfg`` is loaded.
    '''
    kwargs.update(in_chans=1)
    architecture = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True)
    model = ViTSTR(**architecture, **kwargs)

    # Alternate URL left commented out by the original author:
    # https://github.com/roatienza/public/releases/download/v0.1-deit-base/deit_base_patch16_224-b5f2ef4d.pth
    model.default_cfg = _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'
    )

    if not pretrained:
        return model
    load_pretrained(
        model,
        num_classes=model.num_classes,
        in_chans=kwargs.get('in_chans', 1),
        filter_fn=_conv_filter,
    )
    return model
# below is work in progress
@register_model
def vitstr_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    '''
    Work-in-progress tiny ViTSTR variant pointing at the distilled DeiT-tiny
    checkpoint: patch size 16, embedding width 192, 12 blocks, 3 heads.

    Not listed in ``__all__``. The input is always forced to a single channel.
    '''
    kwargs.update(in_chans=1)
    # Left disabled by the original author:
    # kwargs['distilled'] = True
    architecture = dict(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True)
    model = ViTSTR(**architecture, **kwargs)

    model.default_cfg = _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'
    )

    if not pretrained:
        return model
    load_pretrained(
        model,
        num_classes=model.num_classes,
        in_chans=kwargs.get('in_chans', 1),
        filter_fn=_conv_filter,
    )
    return model
@register_model
def vitstr_small_distilled_patch16_224(pretrained=False, **kwargs):
    '''
    Work-in-progress small ViTSTR variant pointing at the distilled DeiT-small
    checkpoint: patch size 16, embedding width 384, 12 blocks, 6 heads.

    Not listed in ``__all__``. The input is always forced to a single channel
    and ``distilled=True`` is passed on to the ViTSTR constructor.
    '''
    kwargs.update(in_chans=1, distilled=True)
    architecture = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True)
    model = ViTSTR(**architecture, **kwargs)

    model.default_cfg = _cfg(
        url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth"
    )

    if not pretrained:
        return model
    load_pretrained(
        model,
        num_classes=model.num_classes,
        in_chans=kwargs.get('in_chans', 1),
        filter_fn=_conv_filter,
    )
    return model