import torch
from torch import nn 
from copy import deepcopy

from .base import FM_to_MD_Util
from utils.common.log import logger
from utils.dl.common.model import set_module, get_module, get_super_module


class FM_to_MD_ViT_Util(FM_to_MD_Util):
    def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int) -> nn.Module:
        """Build an MD model from the foundation model (FM) ViT by narrowing every
        transformer block: qkv and fc1 keep the 1/reducing_width_ratio output channels
        with the largest L1 norm; proj and fc2 keep the same fraction of input channels,
        selected by their own L1 norms. The embedding dimension is left unchanged."""
        fm_vit = deepcopy(fm)
        
        # number of channels kept after width reduction
        def _f(n):
            return int(n // reducing_width_ratio)
        
        def l1_max_indexes(p: torch.Tensor, dim=0):
            """Return the indexes (in ascending order) of the 1/reducing_width_ratio
            slices of `p` along `dim` that have the largest L1 norm."""
            assert dim in [0, 1]
            assert p.dim() in [1, 2, 4]
            
            # rank along columns by transposing, so rows are always what gets ranked
            if dim == 1:
                p = p.T
            
            # L1 norm of each row (all remaining dims are flattened)
            p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1)
            n = p.size(0)
            # keep the top n // ratio rows; the final sort() restores ascending index order
            return p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio)].sort()[0]
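
        # For example (hypothetical shapes): with reducing_width_ratio=4 and a
        # ViT-B qkv weight of shape [2304, 768], l1_max_indexes(w, 0) returns the
        # 576 row indexes with the largest L1 norm, sorted ascending so the kept
        # rows preserve their original relative order.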
        
        for block_i, block in enumerate(fm_vit.blocks):
            # attention qkv: keep the 1/ratio output rows with the largest L1 norm
            qkv = block.attn.qkv
            
            new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features), 
                                qkv.bias is not None, qkv.weight.device)
            indexes = l1_max_indexes(qkv.weight.data, 0)
            
            new_qkv.weight.data.copy_(qkv.weight.data[indexes])
            if qkv.bias is not None:
                new_qkv.bias.data.copy_(qkv.bias.data[indexes])
            set_module(fm_vit, f'blocks.{block_i}.attn.qkv', new_qkv)

            # attention output projection: keep the 1/ratio input columns with the
            # largest L1 norm (selected independently of the qkv row selection)
            proj = block.attn.proj
            new_proj = nn.Linear(_f(proj.in_features), proj.out_features, 
                                proj.bias is not None, proj.weight.device)
            new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)])
            if proj.bias is not None:
                new_proj.bias.data.copy_(proj.bias.data)
            set_module(fm_vit, f'blocks.{block_i}.attn.proj', new_proj)
            
            # MLP fc1: keep the 1/ratio output rows with the largest L1 norm
            fc1 = block.mlp.fc1
            new_fc1 = nn.Linear(fc1.in_features, _f(fc1.out_features), 
                                fc1.bias is not None, fc1.weight.device)
            indexes = l1_max_indexes(fc1.weight.data, 0)
            new_fc1.weight.data.copy_(fc1.weight.data[indexes])
            if fc1.bias is not None:
                new_fc1.bias.data.copy_(fc1.bias.data[indexes])
            set_module(fm_vit, f'blocks.{block_i}.mlp.fc1', new_fc1)

            # MLP fc2: keep the 1/ratio input columns with the largest L1 norm
            fc2 = block.mlp.fc2
            new_fc2 = nn.Linear(_f(fc2.in_features), fc2.out_features, 
                                fc2.bias is not None, fc2.weight.device)
            new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes(fc2.weight.data, 1)])
            if fc2.bias is not None:
                new_fc2.bias.data.copy_(fc2.bias.data)
            set_module(fm_vit, f'blocks.{block_i}.mlp.fc2', new_fc2)
            # NOTE: embedding-dimension reduction is intentionally disabled here.
            # It would additionally shrink patch_embed.proj (Conv2d), the head Linear,
            # and the block/final LayerNorms, and slice cls_token / pos_embed to the
            # kept channels; this utility only narrows the per-block
            # qkv / proj / fc1 / fc2 layers.
        
        return fm_vit
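

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the utility): assumes a timm-style
# ViT whose blocks expose `attn.qkv`, `attn.proj`, `mlp.fc1` and `mlp.fc2`,
# e.g. timm's `vit_base_patch16_224`. Only the per-block linear layers are
# narrowed; actually running a forward pass through the reduced attention may
# also require adapting num_heads / head_dim, which this sketch does not cover.
# Run as a module (e.g. `python -m <package>.vit`) so the relative imports resolve.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import timm  # assumed to be installed

    fm = timm.create_model('vit_base_patch16_224', pretrained=False)
    md = FM_to_MD_ViT_Util().init_md_from_fm_by_reducing_width(fm, reducing_width_ratio=4)

    n_fm = sum(p.numel() for p in fm.parameters())
    n_md = sum(p.numel() for p in md.parameters())
    print(f'FM params: {n_fm / 1e6:.1f}M -> MD params: {n_md / 1e6:.1f}M')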