from typing import Optional
import torch
from copy import deepcopy
from torch import nn
from utils.common.others import get_cur_time_str
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, get_module, get_super_module, set_module
from utils.common.log import logger
from utils.third_party.nni_new.compression.pytorch.speedup import ModelSpeedup
import os

from .base import Abs, KTakesAll, Layer_WrappedWithFBS, ElasticDNNUtil


class Conv2d_WrappedWithFBS(Layer_WrappedWithFBS):
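    """A Conv2d and its following BatchNorm2d wrapped with a Feature Boosting and
    Suppression (FBS) branch: a small saliency predictor (|x| -> global average
    pooling -> bottleneck MLP) produces one non-negative gain per output channel,
    which is sparsified by `self.k_takes_all` and used to scale the conv's output."""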
    def __init__(self, raw_conv2d: nn.Conv2d, raw_bn: nn.BatchNorm2d, r):
        super(Conv2d_WrappedWithFBS, self).__init__()
        
        self.fbs = nn.Sequential(
            Abs(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(raw_conv2d.in_channels, raw_conv2d.out_channels // r),
            nn.ReLU(),
            nn.Linear(raw_conv2d.out_channels // r, raw_conv2d.out_channels),
            nn.ReLU()
        )
        
        self.raw_conv2d = raw_conv2d
        self.raw_bn = raw_bn  # remember to clear the original BNs in the network (they are applied inside this wrapper instead)
        
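        # initialize the gate's last FC layer with bias 1 so the predicted channel gains
        # start around 1, i.e. no channel is suppressed at the beginning of training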
        nn.init.constant_(self.fbs[5].bias, 1.)
        nn.init.kaiming_normal_(self.fbs[5].weight)

    def forward(self, x):
        raw_x = self.raw_bn(self.raw_conv2d(x))
        
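        # reuse the cached channel attention when asked to (e.g. during surrogate extraction);
        # otherwise predict fresh saliencies and keep only the top-k channels via k_takes_all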
        if self.use_cached_channel_attention and self.cached_channel_attention is not None:
            channel_attention = self.cached_channel_attention
        else:
            self.cached_raw_channel_attention = self.fbs(x)
            self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)
            
            channel_attention = self.cached_channel_attention
        
        return raw_x * channel_attention.unsqueeze(2).unsqueeze(3)
    
    
class StaticFBS(nn.Module):
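    """Frozen per-channel scaling used in the extracted surrogate DNN in place of the
    dynamic FBS branch."""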
    def __init__(self, channel_attention: torch.Tensor):
        super(StaticFBS, self).__init__()
        assert channel_attention.dim() == 1
        self.channel_attention = nn.Parameter(channel_attention.unsqueeze(0).unsqueeze(2).unsqueeze(3), requires_grad=False)
        
    def forward(self, x):
        return x * self.channel_attention
    
    def __str__(self) -> str:
        return f'StaticFBS({self.channel_attention.size(1)})'
    
    
class ElasticCNNUtil(ElasticDNNUtil):
    def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
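        """Wrap every (non-ignored) Conv2d together with its following BatchNorm2d into a
        Conv2d_WrappedWithFBS, and replace the now-redundant standalone BNs with
        nn.Identity so they are not applied twice."""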
        model = deepcopy(raw_dnn)

        # map each BatchNorm2d to the Conv2d right before it (these standalone BNs are replaced with Identity below)
        num_original_bns = 0
        last_conv_name = None
        conv_bn_map = {}
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                last_conv_name = name
            if isinstance(module, nn.BatchNorm2d) and (ignore_layers is not None and last_conv_name not in ignore_layers):
                num_original_bns += 1
                conv_bn_map[last_conv_name] = name
        
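        # wrap each (non-ignored) conv together with its paired BN into an FBS layer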
        num_conv = 0
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d) and (ignore_layers is not None and name not in ignore_layers):
                set_module(model, name, Conv2d_WrappedWithFBS(module, get_module(model, conv_bn_map[name]), r))
                num_conv += 1
                
        assert num_conv == num_original_bns
        
        for bn_layer in conv_bn_map.values():
            set_module(model, bn_layer, nn.Identity())
            
        return model
    
    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        return samples[0].unsqueeze(0)
    
    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor):
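        """Run the master DNN on one representative sample, turn each layer's cached
        channel attention into an NNI pruning mask (filters with zero attention are
        pruned), physically shrink the network with ModelSpeedup, and re-attach the
        surviving attentions as StaticFBS scales."""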
        sample = self.select_most_rep_sample(master_dnn, samples)
        assert sample.dim() == 4 and sample.size(0) == 1
        
        master_dnn.eval()
        with torch.no_grad():
            master_dnn_output = master_dnn(sample)
        
        pruning_info = {}
        pruning_masks = {}
        
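        # build NNI-style masks from the cached channel attentions:
        # filters whose attention is zero are marked for pruning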
        for layer_name, layer in master_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue
            
            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)
            
            w = get_module(master_dnn, layer_name).cached_channel_attention.squeeze(0)
            unpruned_filters_index = w.nonzero(as_tuple=True)[0]
            pruning_info[layer_name] = w
            
            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
            pruning_masks[layer_name + '.0'] = cur_pruning_mask
        
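        # unwrap each FBS layer into Sequential(conv, bn, Identity) so ModelSpeedup can
        # prune it; '.0' in the mask keys addresses the conv, and the Identity at '.2'
        # is a placeholder that is later replaced by a StaticFBS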
        surrogate_dnn = deepcopy(master_dnn)
        for name, layer in surrogate_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue
            set_module(surrogate_dnn, name, nn.Sequential(layer.raw_conv2d, layer.raw_bn, nn.Identity()))
            
        # fixed_pruning_masks = fix_mask_conflict(pruning_masks, fbs_model, sample.size(), None, True, True, True)
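        # ModelSpeedup is given a path to the saved masks, so serialize them to a
        # temporary file and delete it once the model has been shrunk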
        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)
        surrogate_dnn.eval()
        model_speedup = ModelSpeedup(surrogate_dnn, sample, tmp_mask_path, sample.device)
        model_speedup.speedup_model()
        os.remove(tmp_mask_path)
        
        # add feature boosting module
        for layer_name, feature_boosting_w in pruning_info.items():
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            set_module(surrogate_dnn, layer_name + '.2', StaticFBS(feature_boosting_w))
            
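        # sanity check: the pruned surrogate must reproduce the master DNN's output
        # on the representative sample (up to numerical error)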
        surrogate_dnn.eval()
        with torch.no_grad():
            surrogate_dnn_output = surrogate_dnn(sample)
        output_diff = ((surrogate_dnn_output - master_dnn_output) ** 2).sum()
        assert output_diff < 1e-4, output_diff
        logger.info(f'output diff of master and surrogate DNN: {output_diff}')
        
        return surrogate_dnn
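

# Minimal usage sketch (assumptions: a plain feed-forward CNN without skip connections,
# 32x32 RGB inputs, and a bottleneck ratio r picked by the caller; layers that must keep
# all of their channels can be listed in `ignore_layers`):
#
#   util = ElasticCNNUtil()
#   master_dnn = util.convert_raw_dnn_to_master_dnn(raw_cnn, r=16, ignore_layers=[])
#   # ... train master_dnn so the FBS branches learn per-sample channel saliencies ...
#   samples = torch.rand(4, 3, 32, 32)
#   surrogate_dnn = util.extract_surrogate_dnn_via_samples(master_dnn, samples)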