import math
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_


def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    posemb_tok, posemb_grid = (
        posemb[:, :num_extra_tokens],
        posemb[0, num_extra_tokens:],
    )
    if grid_old_shape is None:
        # Assume a square grid when the old shape is not given.
        gs_old_h = int(math.sqrt(len(posemb_grid)))
        gs_old_w = gs_old_h
    else:
        gs_old_h, gs_old_w = grid_old_shape

    gs_h, gs_w = grid_new_shape
    # Reshape the flat grid tokens to (1, dim, H, W), resize, then flatten back.
    posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb


def init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)


def checkpoint_filter_fn(state_dict, model):
    """Convert patch embedding weight from manual patchify + linear proj to conv."""
    out_dict = {}
    if "model" in state_dict:
        # DeiT checkpoints nest the weights under a "model" key.
        state_dict = state_dict["model"]
    # DeiT models carry a distillation token in addition to the class token.
    num_extra_tokens = 1 + ("dist_token" in state_dict.keys())
    patch_size = model.patch_size
    image_size = model.patch_embed.image_size
    for k, v in state_dict.items():
        if k == "pos_embed" and v.shape != model.pos_embed.shape:
            # Resize the position embedding when the model is used at a
            # different input size than the pretrained weights.
            v = resize_pos_embed(
                v,
                None,
                (image_size[0] // patch_size, image_size[1] // patch_size),
                num_extra_tokens,
            )
        out_dict[k] = v
    return out_dict


def load_params(ckpt_file, device):
    # Load a checkpoint onto the given device and return its parameters
    # as an OrderedDict, preserving key order.
    params = torch.load(ckpt_file, map_location=device)
    return OrderedDict(params.items())
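
# Typical usage (a hedged sketch; `model` is any ViT-style module exposing
# .patch_size, .patch_embed.image_size, and .pos_embed -- the model class and
# checkpoint path below are illustrative assumptions, not part of this module):
#
#     state_dict = load_params("checkpoint.pth", device)
#     state_dict = checkpoint_filter_fn(state_dict, model)
#     model.load_state_dict(state_dict, strict=False)

# A minimal, self-contained smoke test for resize_pos_embed. The shapes
# (a 14x14 grid with one class token and embed dim 192, resized to 16x16)
# are illustrative assumptions, not values taken from this repo.
if __name__ == "__main__":
    pos_embed = torch.randn(1, 1 + 14 * 14, 192)  # [batch, cls + grid tokens, dim]
    resized = resize_pos_embed(pos_embed, None, (16, 16), num_extra_tokens=1)
    assert resized.shape == (1, 1 + 16 * 16, 192)
    print("resized pos_embed:", tuple(resized.shape))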