import torch
from torch import nn
from copy import deepcopy

from .base import FM_to_MD_Util
from utils.common.log import logger
from utils.dl.common.model import set_module, get_module, get_super_module


def _l1_top_indexes(p: torch.Tensor, reducing_width_ratio, dim: int = 0) -> torch.Tensor:
    """Return indexes of the ``int(n // reducing_width_ratio)`` slices of ``p`` with the largest L1 norm.

    Slices are taken along ``dim`` (0 = rows, 1 = columns), and ``n`` is the
    size of ``p`` along that dim. The result is sorted ascending, so kept
    channels preserve their original relative order.
    """
    assert dim in [0, 1]
    assert p.dim() in [1, 2, 4]
    if dim == 1:
        p = p.T
    p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1)
    n_keep = int(p.size(0) // reducing_width_ratio)
    return p_norm.argsort(descending=True)[:n_keep].sort()[0]


def _slim_linear_out(linear: nn.Linear, reducing_width_ratio) -> nn.Linear:
    """Return a new Linear keeping only the output features (weight rows) of ``linear`` with the largest L1 norm."""
    indexes = _l1_top_indexes(linear.weight.data, reducing_width_ratio, dim=0)
    # Match the source layer's device AND dtype: building the layer with the
    # default dtype would leave an fp32 layer inside an fp16/bf16 model.
    new_linear = nn.Linear(linear.in_features, indexes.numel(), bias=linear.bias is not None,
                           device=linear.weight.device, dtype=linear.weight.dtype)
    with torch.no_grad():
        new_linear.weight.copy_(linear.weight[indexes])
        if linear.bias is not None:
            new_linear.bias.copy_(linear.bias[indexes])
    return new_linear


def _slim_linear_in(linear: nn.Linear, reducing_width_ratio) -> nn.Linear:
    """Return a new Linear keeping only the input features (weight columns) of ``linear`` with the largest L1 norm."""
    indexes = _l1_top_indexes(linear.weight.data, reducing_width_ratio, dim=1)
    new_linear = nn.Linear(indexes.numel(), linear.out_features, bias=linear.bias is not None,
                           device=linear.weight.device, dtype=linear.weight.dtype)
    with torch.no_grad():
        new_linear.weight.copy_(linear.weight[:, indexes])
        if linear.bias is not None:
            # Output features are untouched, so the bias is copied whole.
            new_linear.bias.copy_(linear.bias)
    return new_linear


class FM_to_MD_ViT_Util(FM_to_MD_Util):
    def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int) -> nn.Module:
        """Build a narrower copy of a ViT-like model by L1-norm channel selection.

        ``fm`` is deep-copied and left unmodified. For every entry of
        ``fm.blocks``:

        * ``attn.qkv`` and ``mlp.fc1`` keep the ``int(out_features // reducing_width_ratio)``
          output rows with the largest L1 norm (bias sliced to match);
        * ``attn.proj`` and ``mlp.fc2`` keep the ``int(in_features // reducing_width_ratio)``
          input columns with the largest L1 norm (bias kept whole).

        The embedding dimension (patch embedding, tokens, LayerNorms, head) is
        not reduced. Replacement layers inherit the device and dtype of the
        layers they replace.

        NOTE(review): fc2's kept columns are chosen from fc2's own weights, not
        from the hidden units fc1 kept. Likewise, proj's columns are chosen
        independently of qkv's rows. The inherited weights are therefore not
        aligned across the pair. This is presumably acceptable because the
        model is trained afterwards — confirm.

        NOTE(review): assumes each block exposes ``attn.qkv``, ``attn.proj``,
        ``mlp.fc1`` and ``mlp.fc2`` as ``nn.Linear`` (timm-style ViT). The
        attention forward must also cope with the narrower qkv/proj; that is
        not verified here.

        Args:
            fm: Source (foundation) model; not modified.
            reducing_width_ratio: Divisor applied to each slimmed layer's width.

        Returns:
            The slimmed deep copy of ``fm``.
        """
        fm_vit = deepcopy(fm)

        for block_i, block in enumerate(fm_vit.blocks):
            set_module(fm_vit, f'blocks.{block_i}.attn.qkv',
                       _slim_linear_out(block.attn.qkv, reducing_width_ratio))
            set_module(fm_vit, f'blocks.{block_i}.attn.proj',
                       _slim_linear_in(block.attn.proj, reducing_width_ratio))
            set_module(fm_vit, f'blocks.{block_i}.mlp.fc1',
                       _slim_linear_out(block.mlp.fc1, reducing_width_ratio))
            set_module(fm_vit, f'blocks.{block_i}.mlp.fc2',
                       _slim_linear_in(block.mlp.fc2, reducing_width_ratio))

        return fm_vit