import torch | |
from torch import nn | |
from copy import deepcopy | |
from .base import FM_to_MD_Util | |
from utils.common.log import logger | |
from utils.dl.common.model import set_module, get_module, get_super_module | |
class FM_to_MD_ViT_Util(FM_to_MD_Util):
    """Build a narrower ViT-style model from a foundation model (FM)."""

    def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int) -> nn.Module:
        """Return a deep copy of ``fm`` whose per-block linear layers are narrowed.

        For every entry of ``blocks`` in the copy, four ``nn.Linear`` layers
        are replaced by smaller ones:

        * ``attn.qkv`` and ``mlp.fc1`` keep ``out_features // reducing_width_ratio``
          output rows (weights and, if present, the matching bias entries);
        * ``attn.proj`` and ``mlp.fc2`` keep ``in_features // reducing_width_ratio``
          input columns (the bias, if present, is copied unchanged).

        The rows/columns that survive are the ones with the largest L1 norm,
        kept in their original relative order. Nothing else in the copy is
        modified by this method, and ``fm`` itself is left untouched.

        NOTE(review): the columns kept in ``attn.proj`` / ``mlp.fc2`` are ranked
        on those layers' own weights, independently of the rows kept in
        ``attn.qkv`` / ``mlp.fc1``, and the qkv rows are ranked over all output
        rows at once -- confirm this is the intended initialisation.

        Args:
            fm: model exposing ``blocks``; each block is accessed as
                ``block.attn.qkv``, ``block.attn.proj``, ``block.mlp.fc1`` and
                ``block.mlp.fc2``.
            reducing_width_ratio: divisor applied to the narrowed dimension.

        Returns:
            The narrowed deep copy of ``fm``.
        """
        slim_model = deepcopy(fm)

        def _kept(width):
            # Number of features that survive for a dimension of size ``width``.
            return int(width // reducing_width_ratio)

        def _top_l1(weight: torch.Tensor, dim=0):
            # Sorted indexes (along ``dim``) of the slices with the largest L1 norm.
            assert dim in [0, 1]
            assert weight.dim() in [1, 2, 4]
            if dim == 1:
                # Rank columns by treating them as rows of the transpose.
                weight = weight.T
            rows = weight.size(0)
            magnitude = weight.abs().contiguous().view(rows, -1)
            scores = magnitude.sum(dim=1)
            ranked = scores.argsort(descending=True)
            strongest = ranked[0: int(rows // reducing_width_ratio)]
            # Ascending order keeps the survivors in their original layout order.
            return strongest.sort()[0]

        def _narrow_outputs(path, layer):
            # Replace ``layer`` (at ``path``) by one with fewer output features.
            has_bias = layer.bias is not None
            slim = nn.Linear(layer.in_features, _kept(layer.out_features),
                             has_bias, layer.weight.device)
            rows = _top_l1(layer.weight.data, 0)
            slim.weight.data.copy_(layer.weight.data[rows])
            if has_bias:
                slim.bias.data.copy_(layer.bias.data[rows])
            # set_module is a project helper; presumably it installs ``slim``
            # at the dotted ``path`` inside ``slim_model``.
            set_module(slim_model, path, slim)

        def _narrow_inputs(path, layer):
            # Replace ``layer`` (at ``path``) by one with fewer input features.
            has_bias = layer.bias is not None
            slim = nn.Linear(_kept(layer.in_features), layer.out_features,
                             has_bias, layer.weight.device)
            cols = _top_l1(layer.weight.data, 1)
            slim.weight.data.copy_(layer.weight.data[:, cols])
            if has_bias:
                # Output size is unchanged, so the bias carries over as-is.
                slim.bias.data.copy_(layer.bias.data)
            set_module(slim_model, path, slim)

        for block_i, block in enumerate(slim_model.blocks):
            prefix = f'blocks.{block_i}'
            _narrow_outputs(f'{prefix}.attn.qkv', block.attn.qkv)
            _narrow_inputs(f'{prefix}.attn.proj', block.attn.proj)
            _narrow_outputs(f'{prefix}.mlp.fc1', block.mlp.fc1)
            _narrow_inputs(f'{prefix}.mlp.fc2', block.mlp.fc2)

        return slim_model