# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
import math
from functools import partial
from typing import Tuple

import timm.models.vision_transformer
import torch
import torch.nn as nn
from safetensors import safe_open
from safetensors.torch import save_file

from mmocr.registry import MODELS


@MODELS.register_module()
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
""" Vision Transformer. | |
Args: | |
global_pool (bool): If True, apply global pooling to the output | |
of the last stage. Default: False. | |
patch_size (int): Patch token size. Default: 8. | |
img_size (tuple[int]): Input image size. Default: (32, 128). | |
embed_dim (int): Number of linear projection output channels. | |
Default: 192. | |
depth (int): Number of blocks. Default: 12. | |
num_heads (int): Number of attention heads. Default: 3. | |
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | |
Default: 4. | |
qkv_bias (bool): If True, add a learnable bias to query, key, | |
value. Default: True. | |
norm_layer (nn.Module): Normalization layer. Default: | |
partial(nn.LayerNorm, eps=1e-6). | |
pretrained (str): Path to pre-trained checkpoint. Default: None. | |
""" | |
    def __init__(self,
                 global_pool: bool = False,
                 patch_size: int = 8,
                 img_size: Tuple[int, int] = (32, 128),
                 embed_dim: int = 192,
                 depth: int = 12,
                 num_heads: int = 3,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 pretrained: str = None,
                 **kwargs):
        super(VisionTransformer, self).__init__(
            patch_size=patch_size,
            img_size=img_size,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            **kwargs)

        self.global_pool = global_pool
        if self.global_pool:
            # norm_layer and embed_dim are explicit arguments of this
            # __init__, so use them directly instead of indexing kwargs
            self.fc_norm = norm_layer(embed_dim)
            del self.norm  # remove the original norm
        self.reset_classifier(0)
        if pretrained:
            checkpoint = torch.load(pretrained, map_location='cpu')
            print(f"Load pre-trained checkpoint from: {pretrained}")
            checkpoint_model = checkpoint['model']
            state_dict = self.state_dict()
            for k in ['head.weight', 'head.bias']:
                # the classifier head was removed above, so drop any head
                # weights that no longer match the current model
                if k in checkpoint_model and (
                        k not in state_dict
                        or checkpoint_model[k].shape != state_dict[k].shape):
                    print(f"Removing key {k} from pretrained checkpoint")
                    del checkpoint_model[k]
            # remove the decoder keys left over from pre-training
            for k in list(checkpoint_model.keys()):
                if 'decoder' in k:
                    del checkpoint_model[k]
            msg = self.load_state_dict(checkpoint_model, strict=False)
            print(msg)

    def forward_features(self, x: torch.Tensor):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(
            B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        # the original norm is deleted when global_pool is enabled, so use
        # fc_norm in that case
        x = self.fc_norm(x) if self.global_pool else self.norm(x)
        return x

    def forward(self, x):
        return self.forward_features(x)

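# A minimal usage sketch (illustrative, not part of the original module): the
# default configuration is a ViT-Tiny-style backbone over 32x128 text images
# with 8x8 patches, so the forward pass returns the full token sequence rather
# than a pooled feature. This assumes a timm version whose patch embedding
# accepts a rectangular ``img_size`` tuple.
def _demo_vit_backbone():
    vit = VisionTransformer(
        global_pool=False,
        patch_size=8,
        img_size=(32, 128),
        embed_dim=192,
        depth=12,
        num_heads=3)
    dummy = torch.zeros(2, 3, 32, 128)
    # 1 cls token + (32 // 8) * (128 // 8) = 65 tokens of dimension 192
    return vit(dummy)  # expected shape: (2, 65, 192)
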
class _LoRA_qkv_timm(nn.Module):
    """LoRA wrapper for the fused qkv projection of a timm Vision Transformer
    block. Only the query and value projections receive low-rank updates.

    Args:
        qkv (nn.Module): The original qkv projection layer of the timm
            Vision Transformer.
        linear_a_q (nn.Module): Low-rank down-projection (matrix A) for the
            query.
        linear_b_q (nn.Module): Low-rank up-projection (matrix B) for the
            query.
        linear_a_v (nn.Module): Low-rank down-projection (matrix A) for the
            value.
        linear_b_v (nn.Module): Low-rank up-projection (matrix B) for the
            value.
    """

    def __init__(
        self,
        qkv: nn.Module,
        linear_a_q: nn.Module,
        linear_b_q: nn.Module,
        linear_a_v: nn.Module,
        linear_b_v: nn.Module,
    ):
        super().__init__()
        self.qkv = qkv
        self.linear_a_q = linear_a_q
        self.linear_b_q = linear_b_q
        self.linear_a_v = linear_a_v
        self.linear_b_v = linear_b_v
        self.dim = qkv.in_features

    def forward(self, x):
        qkv = self.qkv(x)  # B, N, 3*dim
        new_q = self.linear_b_q(self.linear_a_q(x))
        new_v = self.linear_b_v(self.linear_a_v(x))
        qkv[:, :, :self.dim] += new_q
        qkv[:, :, -self.dim:] += new_v
        return qkv

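# Illustrative check (not part of the original module): with zero-initialized
# B matrices, as produced by VisionTransformer_LoRA._init_lora below, the
# wrapped layer reproduces the frozen qkv projection exactly, i.e. the LoRA
# branch starts out as a no-op. The dimensions below are arbitrary example
# values.
def _demo_lora_qkv_starts_as_identity(dim: int = 192, rank: int = 4):
    qkv = nn.Linear(dim, dim * 3)
    a_q = nn.Linear(dim, rank, bias=False)
    b_q = nn.Linear(rank, dim, bias=False)
    a_v = nn.Linear(dim, rank, bias=False)
    b_v = nn.Linear(rank, dim, bias=False)
    nn.init.zeros_(b_q.weight)
    nn.init.zeros_(b_v.weight)
    lora_qkv = _LoRA_qkv_timm(qkv, a_q, b_q, a_v, b_v)
    x = torch.randn(2, 65, dim)
    with torch.no_grad():
        assert torch.allclose(lora_qkv(x), qkv(x))
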
@MODELS.register_module()
class VisionTransformer_LoRA(nn.Module):
"""Vision Transformer with LoRA. For each block, we add a LoRA layer for | |
the linear projection of query and value. | |
Args: | |
vit_config (dict): Config dict for VisionTransformer. | |
rank (int): Rank of LoRA layer. Default: 4. | |
lora_layers (int): Stages to add LoRA layer. Defaults None means | |
add LoRA layer to all stages. | |
pretrained_lora (str): Path to pre-trained checkpoint of LoRA layer. | |
""" | |
    def __init__(self,
                 vit_config: dict,
                 rank: int = 4,
                 lora_layers: list = None,
                 pretrained_lora: str = None):
        super(VisionTransformer_LoRA, self).__init__()
        self.vit = MODELS.build(vit_config)
        assert rank > 0
        if lora_layers:
            self.lora_layers = lora_layers
        else:
            self.lora_layers = list(range(len(self.vit.blocks)))
        # create lists of LoRA layers
        self.query_As = nn.Sequential()  # matrix A for query linear projection
        self.query_Bs = nn.Sequential()  # matrix B for query linear projection
        self.value_As = nn.Sequential()  # matrix A for value linear projection
        self.value_Bs = nn.Sequential()  # matrix B for value linear projection
        # freeze the original vit
        for param in self.vit.parameters():
            param.requires_grad = False
        # compose LoRA layers
        for block_idx, block in enumerate(self.vit.blocks):
            if block_idx not in self.lora_layers:
                continue
            # create LoRA layers
            w_qkv_linear = block.attn.qkv
            self.dim = w_qkv_linear.in_features
            w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
            w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
            w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
            w_b_linear_v = nn.Linear(rank, self.dim, bias=False)
            self.query_As.append(w_a_linear_q)
            self.query_Bs.append(w_b_linear_q)
            self.value_As.append(w_a_linear_v)
            self.value_Bs.append(w_b_linear_v)
            # replace the original qkv layer with the LoRA-wrapped layer
            block.attn.qkv = _LoRA_qkv_timm(
                w_qkv_linear,
                w_a_linear_q,
                w_b_linear_q,
                w_a_linear_v,
                w_b_linear_v,
            )
        self._init_lora()
        if pretrained_lora is not None:
            self._load_lora(pretrained_lora)

    def _init_lora(self):
        """Initialize the LoRA layers so that, before training, they act as an
        identity perturbation: the A matrices use Kaiming init, the B matrices
        are zero, so the adapted projections initially equal the frozen
        ones."""
        for query_A, query_B, value_A, value_B in zip(self.query_As,
                                                      self.query_Bs,
                                                      self.value_As,
                                                      self.value_Bs):
            nn.init.kaiming_uniform_(query_A.weight, a=math.sqrt(5))
            nn.init.kaiming_uniform_(value_A.weight, a=math.sqrt(5))
            nn.init.zeros_(query_B.weight)
            nn.init.zeros_(value_B.weight)

    def _load_lora(self, checkpoint_lora: str):
        """Load pre-trained LoRA checkpoint.

        Args:
            checkpoint_lora (str): Path to pre-trained LoRA checkpoint.
        """
        assert checkpoint_lora.endswith(".safetensors")
        with safe_open(checkpoint_lora, framework="pt") as f:
            for i, q_A, q_B, v_A, v_B in zip(
                    range(len(self.query_As)),
                    self.query_As,
                    self.query_Bs,
                    self.value_As,
                    self.value_Bs,
            ):
                q_A.weight = nn.Parameter(f.get_tensor(f"q_a_{i:03d}"))
                q_B.weight = nn.Parameter(f.get_tensor(f"q_b_{i:03d}"))
                v_A.weight = nn.Parameter(f.get_tensor(f"v_a_{i:03d}"))
                v_B.weight = nn.Parameter(f.get_tensor(f"v_b_{i:03d}"))

    def forward(self, x):
        x = self.vit(x)
        return x

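# A minimal construction sketch (illustrative, not part of the original
# module): ``vit_config`` is a registry config that builds the
# VisionTransformer defined above; the ``type`` name assumes that class is
# registered under its own name in MMOCR's MODELS registry. After
# construction, only the LoRA matrices have requires_grad=True, which is what
# keeps fine-tuning cheap.
def _demo_build_lora_vit():
    vit_config = dict(
        type='VisionTransformer',
        patch_size=8,
        img_size=(32, 128),
        embed_dim=192,
        depth=12,
        num_heads=3)
    model = VisionTransformer_LoRA(vit_config, rank=4)
    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    # expected: only query_As/query_Bs/value_As/value_Bs parameters
    return model, trainable
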
def extract_lora_from_vit(checkpoint_path: str,
                          save_path: str,
                          ckpt_key: str = None):
    """Given a checkpoint of VisionTransformer_LoRA, extract the LoRA weights
    and save them to a new checkpoint.

    Args:
        checkpoint_path (str): Path to a checkpoint of VisionTransformer_LoRA.
        save_path (str): Path to save the extracted LoRA checkpoint.
        ckpt_key (str, optional): Key of the model state dict in the
            checkpoint. Defaults to None, meaning the loaded checkpoint is
            itself the state dict.
    """
    assert save_path.endswith(".safetensors")
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    # walk through the checkpoint to find the LoRA layers
    query_As = []
    query_Bs = []
    value_As = []
    value_Bs = []
    ckpt = ckpt if ckpt_key is None else ckpt[ckpt_key]
    for k, v in ckpt.items():
        if k.startswith("query_As"):
            query_As.append(v)
        elif k.startswith("query_Bs"):
            query_Bs.append(v)
        elif k.startswith("value_As"):
            value_As.append(v)
        elif k.startswith("value_Bs"):
            value_Bs.append(v)
    # save the LoRA layers to a new checkpoint
    ckpt_dict = {}
    for i in range(len(query_As)):
        ckpt_dict[f"q_a_{i:03d}"] = query_As[i]
        ckpt_dict[f"q_b_{i:03d}"] = query_Bs[i]
        ckpt_dict[f"v_a_{i:03d}"] = value_As[i]
        ckpt_dict[f"v_b_{i:03d}"] = value_Bs[i]
    save_file(ckpt_dict, save_path)

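# End-to-end sketch (illustrative, with placeholder paths): after fine-tuning
# a VisionTransformer_LoRA, only the small adapter matrices need to be kept.
# This assumes the entry under ``ckpt_key`` is the state dict of the
# VisionTransformer_LoRA module itself (keys starting with ``query_As`` etc.);
# adjust ``ckpt_key`` to match how the checkpoint was actually saved.
def _demo_extract_and_reload(ckpt_path: str, lora_path: str, vit_config: dict):
    # dump the LoRA matrices into a small .safetensors file
    extract_lora_from_vit(ckpt_path, lora_path, ckpt_key='state_dict')
    # re-attach the extracted adapters to a freshly built model
    return VisionTransformer_LoRA(
        vit_config, rank=4, pretrained_lora=lora_path)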