import math
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_

def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    posemb_tok, posemb_grid = (
        posemb[:, :num_extra_tokens],
        posemb[0, num_extra_tokens:],
    )
    if grid_old_shape is None:
        # No old grid shape given: assume the pretrained grid was square.
        gs_old_h = int(math.sqrt(len(posemb_grid)))
        gs_old_w = gs_old_h
    else:
        gs_old_h, gs_old_w = grid_old_shape

    gs_h, gs_w = grid_new_shape
    # Interpolate the grid embeddings in 2D, then flatten back to a token sequence.
    posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb
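
# Usage sketch (the shapes are illustrative, not taken from a specific
# checkpoint): resize a ViT pos_embed with one cls token from a 24x24 grid
# to a 32x32 grid.
#
#   posemb = torch.randn(1, 1 + 24 * 24, 768)
#   resized = resize_pos_embed(posemb, None, (32, 32), num_extra_tokens=1)
#   assert resized.shape == (1, 1 + 32 * 32, 768)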

def init_weights(m):
    # Truncated-normal init for linear weights with zero biases; LayerNorm is
    # reset to identity. Meant to be applied recursively via nn.Module.apply.
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
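
# Usage sketch (the model below is illustrative, not from this repo):
#
#   model = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.LayerNorm(3072))
#   model.apply(init_weights)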

def checkpoint_filter_fn(state_dict, model):
    """Filter a pretrained checkpoint, resizing the position embedding when the
    target model runs at a different image size than the pretrained weights."""
    out_dict = {}
    if "model" in state_dict:
        # DeiT checkpoints nest the weights under a "model" key.
        state_dict = state_dict["model"]
    # One cls token, plus a distillation token for DeiT models.
    num_extra_tokens = 1 + ("dist_token" in state_dict.keys())
    patch_size = model.patch_size
    image_size = model.patch_embed.image_size
    for k, v in state_dict.items():
        if k == "pos_embed" and v.shape != model.pos_embed.shape:
            # Resize the pos embedding when using the model at a different size
            # from the pretrained weights.
            v = resize_pos_embed(
                v,
                None,
                (image_size[0] // patch_size, image_size[1] // patch_size),
                num_extra_tokens,
            )
        out_dict[k] = v
    return out_dict
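
# Usage sketch, under the assumption that `model` exposes `patch_size` and
# `patch_embed.image_size` as the ViT in this repo does; the checkpoint path
# is hypothetical.
#
#   ckpt = torch.load("deit_base_distilled.pth", map_location="cpu")
#   filtered = checkpoint_filter_fn(ckpt, model)
#   model.load_state_dict(filtered, strict=False)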

def load_params(ckpt_file, device):
    # Load a checkpoint onto the given device and return it as an OrderedDict.
    # (An earlier, commented-out variant prepended "module." to each key to
    # match models wrapped in DataParallel/DistributedDataParallel.)
    params = torch.load(ckpt_file, map_location=device)
    return OrderedDict(params.items())
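
# Usage sketch for load_params (the path is hypothetical):
#
#   params = load_params("checkpoints/vit_base.pth", device="cpu")
#   model.load_state_dict(params, strict=False)

if __name__ == "__main__":
    # Minimal smoke test for resize_pos_embed; the shapes are illustrative and
    # not tied to any real checkpoint.
    posemb = torch.randn(1, 1 + 14 * 14, 768)
    resized = resize_pos_embed(posemb, None, (16, 16), num_extra_tokens=1)
    print(resized.shape)  # expected: torch.Size([1, 257, 768])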