Spaces:
Sleeping
Sleeping
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from timm.models.layers import trunc_normal_ | |
from collections import OrderedDict | |
def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
    """Resample a position-embedding table to a new patch-grid size.

    Used to rescale position embeddings when loading from a state_dict. Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224

    Args:
        posemb: Rank-3 tensor ``(1, num_extra_tokens + H * W, dim)``. The leading
            ``num_extra_tokens`` entries are copied through unchanged; the rest
            are treated as an ``H x W`` grid in row-major order.
        grid_old_shape: ``(H, W)`` of the grid stored in ``posemb``, or ``None``
            to assume a square grid.
        grid_new_shape: ``(H_new, W_new)`` of the target grid.
        num_extra_tokens: Number of leading non-grid tokens.

    Returns:
        Tensor of shape ``(1, num_extra_tokens + H_new * W_new, dim)`` whose
        grid part is bilinearly interpolated from the old grid.

    Raises:
        ValueError: if the number of grid tokens in ``posemb`` does not equal
            ``H * W`` (including the case of a non-square count when
            ``grid_old_shape`` is ``None``).
    """
    posemb_tok = posemb[:, :num_extra_tokens]
    # The grid is taken from batch entry 0 only; the final torch.cat requires
    # posemb's leading dimension to be 1.
    posemb_grid = posemb[0, num_extra_tokens:]
    num_grid_tokens = posemb_grid.shape[0]

    if grid_old_shape is None:
        # isqrt is exact, unlike int(math.sqrt(...)) on floats.
        gs_old_h = gs_old_w = math.isqrt(num_grid_tokens)
    else:
        gs_old_h, gs_old_w = grid_old_shape
    # Without this check a mismatched token count could be reshaped (via the
    # -1 below) into a wrong embedding width instead of failing.
    if gs_old_h * gs_old_w != num_grid_tokens:
        raise ValueError(
            f"posemb has {num_grid_tokens} grid tokens, which does not match "
            f"an old grid of {gs_old_h}x{gs_old_w}"
        )
    gs_h, gs_w = grid_new_shape

    # (N, dim) -> (1, dim, H, W) so interpolate treats dim as channels.
    posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2)
    # align_corners=False is torch's default; stated explicitly for clarity.
    posemb_grid = F.interpolate(
        posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False
    )
    # Back to (1, H_new * W_new, dim).
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
    return torch.cat([posemb_tok, posemb_grid], dim=1)
def init_weights(m):
    """Initialise a single ``nn.Linear`` or ``nn.LayerNorm`` module in place.

    Takes one module, so it can be passed to ``nn.Module.apply``.

    - ``nn.Linear``: weight drawn with ``trunc_normal_(std=0.02)``; bias, if
      present, set to 0.
    - ``nn.LayerNorm``: weight set to 1.0 and bias set to 0, for whichever of
      those parameters exist.

    Any other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # LayerNorm(elementwise_affine=False) (or bias=False on newer torch)
        # stores None for these parameters; nn.init.constant_(None, ...) raises.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
def checkpoint_filter_fn(state_dict, model):
    """Adapt a pretrained state_dict so its position embedding fits ``model``.

    If ``state_dict`` contains a ``"model"`` key (the DeiT checkpoint layout),
    the dict stored under it is used instead. Every entry is copied through
    unchanged except ``"pos_embed"``. When its shape differs from
    ``model.pos_embed.shape``, it is resampled with ``resize_pos_embed`` to a
    ``(image_size[0] // patch_size, image_size[1] // patch_size)`` grid. The
    checkpoint's own grid is assumed to be square.

    Args:
        state_dict: Mapping of parameter names to tensors, optionally nested
            under ``"model"``.
        model: Object exposing ``patch_size``, ``patch_embed.image_size`` and
            ``pos_embed``.

    Returns:
        A new dict with the same keys as the (unwrapped) state_dict.
    """
    out_dict = {}
    if "model" in state_dict:
        # For deit models
        state_dict = state_dict["model"]
    # Non-grid tokens at the front of pos_embed: the class token, plus a
    # distillation token when the checkpoint contains "dist_token".
    num_extra_tokens = 1 + int("dist_token" in state_dict)
    patch_size = model.patch_size
    image_size = model.patch_embed.image_size
    for k, v in state_dict.items():
        if k == "pos_embed" and v.shape != model.pos_embed.shape:
            # To resize pos embedding when using model at different size from pretrained weights
            v = resize_pos_embed(
                v,
                None,
                (image_size[0] // patch_size, image_size[1] // patch_size),
                num_extra_tokens,
            )
        out_dict[k] = v
    return out_dict
def load_params(ckpt_file, device):
    """Load a checkpoint with ``torch.load`` and return it as an OrderedDict.

    Args:
        ckpt_file: Anything ``torch.load`` accepts (path or file-like object).
        device: Passed to ``torch.load`` as ``map_location`` (e.g. ``"cpu"``).

    Returns:
        An ``OrderedDict`` holding the loaded mapping's top-level key/value
        pairs in their original order. Keys are not renamed.
    """
    # NOTE(review): torch.load unpickles its input; only load trusted
    # checkpoints (consider torch.load's ``weights_only`` option).
    params = torch.load(ckpt_file, map_location=device)
    return OrderedDict(params.items())