# Source: uploaded by duongttr, commit 3d85088 ("Update new app"), 2.78 kB.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from collections import OrderedDict
def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
    """Resize the grid part of a ViT position embedding to a new patch grid.

    Used to rescale the grid of position embeddings when loading from a
    state_dict trained at a different resolution. Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224

    Args:
        posemb: Position embedding tensor, assumed (1, tokens, dim) — only
            batch entry 0 of the grid part is used. The first
            ``num_extra_tokens`` tokens are copied through unchanged; the
            remaining ones are treated as a row-major (h, w) grid.
        grid_old_shape: ``(h, w)`` of the grid stored in ``posemb``, or
            ``None`` to assume a square grid.
        grid_new_shape: Target ``(h, w)`` grid.
        num_extra_tokens: Number of leading non-grid tokens.

    Returns:
        Tensor of shape ``(1, num_extra_tokens + h_new * w_new, dim)``.

    Raises:
        ValueError: If the number of grid tokens does not equal
            ``h_old * w_old`` (or is not a perfect square when
            ``grid_old_shape`` is None).
    """
    posemb_tok = posemb[:, :num_extra_tokens]
    posemb_grid = posemb[0, num_extra_tokens:]
    num_grid = posemb_grid.shape[0]
    dim = posemb_grid.shape[-1]

    if grid_old_shape is None:
        # Exact integer square root; the caller claims the old grid is square.
        gs_old_h = gs_old_w = math.isqrt(num_grid)
    else:
        gs_old_h, gs_old_w = grid_old_shape
    if gs_old_h * gs_old_w != num_grid:
        # Without this check a mismatched grid either fails deep inside
        # reshape or silently yields a wrong channel dimension.
        raise ValueError(
            f"position embedding has {num_grid} grid tokens, which does not "
            f"match an old grid of {gs_old_h}x{gs_old_w}"
        )
    gs_h, gs_w = grid_new_shape

    # (N, D) -> (1, D, H, W) so the grid can be resampled as an image.
    posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, dim).permute(0, 3, 1, 2)
    # align_corners=False is the default; stated explicitly for clarity.
    posemb_grid = F.interpolate(
        posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False
    )
    # (1, D, H', W') -> (1, H' * W', D) to match the token layout again.
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, dim)
    return torch.cat([posemb_tok, posemb_grid], dim=1)
def init_weights(m):
    """Initialize the parameters of a single module in place.

    Can be used as ``model.apply(init_weights)``.

    - ``nn.Linear``: the weight is drawn from timm's ``trunc_normal_`` with
      ``std=0.02``, and the bias (if any) is set to 0.
    - ``nn.LayerNorm``: the weight is set to 1 and the bias to 0, for
      whichever of the two exist.
    - Any other module is left untouched.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # LayerNorm(elementwise_affine=False) has weight and bias set to None,
        # and LayerNorm(bias=False) has no bias; initializing None would crash.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def checkpoint_filter_fn(state_dict, model):
    """Adapt a pretrained checkpoint so its weights fit ``model``.

    Returns a new dict with the same keys and values as the checkpoint.
    The only change is that ``pos_embed`` is passed through
    ``resize_pos_embed`` when its shape differs from ``model.pos_embed``.
    This lets the model run at an image size other than the one it was
    pretrained at. No other weights are converted.

    Args:
        state_dict: Checkpoint mapping. If it contains a ``"model"`` key (as
            DeiT checkpoints do), the mapping stored under that key is used.
        model: Target model; its ``patch_size``, ``patch_embed.image_size``
            and ``pos_embed`` attributes are read.

    Returns:
        dict of parameter name -> value, with ``pos_embed`` possibly resized.
    """
    if "model" in state_dict:
        # DeiT checkpoints nest the actual weights under "model".
        state_dict = state_dict["model"]
    # Leading non-grid tokens: the class token, plus the distillation token
    # when the checkpoint has one.
    extra_tokens = 2 if "dist_token" in state_dict.keys() else 1
    patch = model.patch_size
    img_size = model.patch_embed.image_size

    def _fit(name, tensor):
        # Everything except a mismatched pos_embed passes through untouched.
        if name != "pos_embed" or tensor.shape == model.pos_embed.shape:
            return tensor
        new_grid = (img_size[0] // patch, img_size[1] // patch)
        return resize_pos_embed(tensor, None, new_grid, extra_tokens)

    return {name: _fit(name, tensor) for name, tensor in state_dict.items()}
def load_params(ckpt_file, device):
    """Load a checkpoint with ``torch.load`` and return it as an ``OrderedDict``.

    Args:
        ckpt_file: Anything ``torch.load`` accepts (path or file-like object).
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        OrderedDict holding the checkpoint's key/value pairs in their original
        order. The loaded object must be a mapping (it must provide ``.items()``).
    """
    # SECURITY NOTE: torch.load unpickles the file. Only load trusted
    # checkpoints (PyTorch < 2.6 does not default to weights_only=True).
    params = torch.load(ckpt_file, map_location=device)
    return OrderedDict(params.items())