import torch
from torch import nn
from abc import ABC, abstractmethod

from utils.dl.common.model import get_model_device, get_model_latency, get_model_size
from utils.common.log import logger


class KTakesAll(nn.Module):
    # k is the sparsity ratio: the fraction of entries zeroed out
    # (the larger k is, the sparser the attention and the smaller the extracted model)
    def __init__(self, k):
        super(KTakesAll, self).__init__()
        self.k = k
        self.cached_i = None
        
    def forward(self, g: torch.Tensor):
        # zero out the k-fraction of smallest entries along the last dimension,
        # keeping only the top (1 - k) fraction of channel attention values
        k = int(g.size(-1) * self.k)
        i = (-g).topk(k, -1)[1]
        self.cached_i = i
        t = g.scatter(-1, i, 0)
        
        return t
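
# A minimal usage sketch (an illustration added for clarity, not part of the original
# API; the tensor values are made up): KTakesAll zeroes out the smallest `k` fraction
# of entries along the last dimension and remembers their indices in `cached_i`.
#
#   gate = KTakesAll(k=0.5)
#   g = torch.tensor([[0.1, 0.9, 0.3, 0.7]])
#   gate(g)          # -> tensor([[0.0000, 0.9000, 0.0000, 0.7000]])
#   gate.cached_i    # -> tensor([[0, 2]]), indices of the zeroed (smallest) entries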
    
    
class Abs(nn.Module):
    # element-wise absolute value wrapped as a module (e.g. usable inside nn.Sequential)
    def __init__(self):
        super(Abs, self).__init__()
        
    def forward(self, x):
        return x.abs()


class Layer_WrappedWithFBS(nn.Module):
    def __init__(self):
        super(Layer_WrappedWithFBS, self).__init__()
        
        init_sparsity = 0.5
        self.k_takes_all = KTakesAll(init_sparsity)
        
        # channel attention before / after KTakesAll sparsification,
        # cached by the concrete wrapped layer during its forward pass
        self.cached_raw_channel_attention = None
        self.cached_channel_attention = None
        # if True, reuse the cached sparsified attention instead of recomputing it per input
        self.use_cached_channel_attention = False
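
    # A hedged sketch of the contract a concrete wrapped layer is assumed to follow
    # (the actual attention computation lives in framework-specific subclasses, not here):
    #
    #   def forward(self, x):
    #       if self.use_cached_channel_attention and self.cached_channel_attention is not None:
    #           attention = self.cached_channel_attention
    #       else:
    #           raw_attention = ...   # hypothetical: per-channel attention computed from x
    #           attention = self.k_takes_all(raw_attention)
    #           self.cached_raw_channel_attention = raw_attention
    #           self.cached_channel_attention = attention
    #       return ...                # hypothetical: wrapped layer's output scaled by attention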


class ElasticDNNUtil(ABC):
    @abstractmethod
    def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
        raise NotImplementedError
    
    def convert_raw_dnn_to_master_dnn_with_perf_test(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
        raw_dnn_size = get_model_size(raw_dnn, True)
        master_dnn = self.convert_raw_dnn_to_master_dnn(raw_dnn, r, ignore_layers)
        master_dnn_size = get_model_size(master_dnn, True)
        
        logger.info(f'master DNN w/o FBS ({raw_dnn_size:.3f}MB) -> master DNN w/ FBS ({master_dnn_size:.3f}MB) '
                    f'(↑ {(((master_dnn_size - raw_dnn_size) / raw_dnn_size) * 100.):.2f}%)')
        return master_dnn
    
    def set_master_dnn_inference_via_cached_channel_attention(self, master_dnn: nn.Module):
        # static inference: every FBS-wrapped layer reuses its cached channel attention
        # (the cache must have been filled by a previous forward pass)
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                assert module.cached_channel_attention is not None
                module.use_cached_channel_attention = True
    
    def set_master_dnn_dynamic_inference(self, master_dnn: nn.Module):
        # dynamic inference: the channel attention is recomputed for each input
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                module.cached_channel_attention = None
                module.use_cached_channel_attention = False
    
    def train_only_fbs_of_master_dnn(self, master_dnn: nn.Module):
        # freeze the original weights and train only the FBS gating parameters,
        # which are identified by the '.fbs' substring in their parameter names
        fbs_params = []
        for n, p in master_dnn.named_parameters():
            if '.fbs' in n:
                fbs_params += [p]
                p.requires_grad = True
            else:
                p.requires_grad = False
        return fbs_params
    
    def get_accu_l1_reg_of_raw_channel_attention_in_master_dnn(self, master_dnn: nn.Module):
        # accumulate the L1 norm of the cached raw channel attentions over all FBS-wrapped layers
        # (typically used as a sparsity-encouraging regularization term during training)
        res = 0.
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                res += module.cached_raw_channel_attention.norm(1)
        return res
    
    def get_raw_channel_attention_in_master_dnn(self, master_dnn: nn.Module):
        res = {}
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                res[name] = module.cached_raw_channel_attention
        return res

    def set_master_dnn_sparsity(self, master_dnn: nn.Module, sparsity: float):
        assert 0 <= sparsity <= 1., sparsity
        for name, module in master_dnn.named_modules():
            if isinstance(module, KTakesAll):
                module.k = sparsity
        logger.debug(f'set master DNN sparsity to {sparsity}')
        
    def clear_cached_channel_attention_in_master_dnn(self, master_dnn: nn.Module):
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                module.cached_raw_channel_attention = None
                module.cached_channel_attention = None
                
    @abstractmethod
    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        raise NotImplementedError
    
    @abstractmethod
    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False):
        raise NotImplementedError
    
    def extract_surrogate_dnn_via_samples_with_perf_test(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False):
        master_dnn_size = get_model_size(master_dnn, True)
        master_dnn_latency = get_model_latency(master_dnn, (1, *list(samples.size())[1:]), 50, 
                                               get_model_device(master_dnn), 50, False)
        
        res = self.extract_surrogate_dnn_via_samples(master_dnn, samples, return_detail)
        if not return_detail:
            surrogate_dnn = res
        else:
            surrogate_dnn, unpruned_indexes_of_layers = res
        surrogate_dnn_size = get_model_size(surrogate_dnn, True)
        surrogate_dnn_latency = get_model_latency(surrogate_dnn, (1, *list(samples.size())[1:]), 50, 
                                                  get_model_device(surrogate_dnn), 50, False)

        logger.info(f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample) -> '
                    f'surrogate DNN ({surrogate_dnn_size:.3f}MB, {surrogate_dnn_latency:.4f}s/sample)\n'
                    f'(model size: ↓ {(master_dnn_size / surrogate_dnn_size):.2f}x, '
                    f'latency: ↓ {(master_dnn_latency / surrogate_dnn_latency):.2f}x)')
        
        return res
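
# A hedged end-to-end sketch of how this utility is typically used. `MyElasticDNNUtil`,
# `raw_dnn`, `r` and `samples` are hypothetical placeholders for a concrete subclass
# and user-provided inputs; they are not defined in this module.
#
#   util = MyElasticDNNUtil()  # implements the abstract convert/select/extract methods
#   master_dnn = util.convert_raw_dnn_to_master_dnn_with_perf_test(raw_dnn, r)
#   fbs_params = util.train_only_fbs_of_master_dnn(master_dnn)   # e.g. hand these to an optimizer
#   util.set_master_dnn_sparsity(master_dnn, 0.5)
#   rep_samples = util.select_most_rep_sample(master_dnn, samples)
#   surrogate_dnn = util.extract_surrogate_dnn_via_samples_with_perf_test(master_dnn, rep_samples)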