from typing import Optional
import torch
from copy import deepcopy
from torch import nn
from utils.common.others import get_cur_time_str
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, get_module, get_super_module, set_module
from utils.common.log import logger
from utils.third_party.nni_new.compression.pytorch.speedup import ModelSpeedup
import os

from .base import Abs, KTakesAll, Layer_WrappedWithFBS, ElasticDNNUtil
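

# Conv2d + BatchNorm2d pair wrapped with an FBS (feature boosting) gating branch:
# the branch predicts a per-channel attention vector from the layer input, KTakesAll
# zeroes all but the strongest channels, and the conv+BN output is rescaled channel-wise.
# `r` is the reduction ratio of the hidden layer in the attention branch.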
class Conv2d_WrappedWithFBS(Layer_WrappedWithFBS):
    def __init__(self, raw_conv2d: nn.Conv2d, raw_bn: nn.BatchNorm2d, r):
        super(Conv2d_WrappedWithFBS, self).__init__()

        self.fbs = nn.Sequential(
            Abs(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(raw_conv2d.in_channels, raw_conv2d.out_channels // r),
            nn.ReLU(),
            nn.Linear(raw_conv2d.out_channels // r, raw_conv2d.out_channels),
            nn.ReLU()
        )

        self.raw_conv2d = raw_conv2d
        self.raw_bn = raw_bn  # remember to clear the original BNs in the network

        # initialize the last FC layer of the attention branch
        nn.init.constant_(self.fbs[5].bias, 1.)
        nn.init.kaiming_normal_(self.fbs[5].weight)
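
    # Compute BN(conv(x)) and rescale it channel-wise. Normally the attention is computed
    # from the input and cached; when use_cached_channel_attention is set, the cached
    # attention is reused so the gating pattern stays fixed.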
    def forward(self, x):
        raw_x = self.raw_bn(self.raw_conv2d(x))

        if self.use_cached_channel_attention and self.cached_channel_attention is not None:
            channel_attention = self.cached_channel_attention
        else:
            self.cached_raw_channel_attention = self.fbs(x)
            self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)
            channel_attention = self.cached_channel_attention

        return raw_x * channel_attention.unsqueeze(2).unsqueeze(3)
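

# Static replacement for the FBS branch in the extracted surrogate DNN: multiplies the
# feature map by a fixed per-channel attention vector stored as a frozen parameter.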
class StaticFBS(nn.Module):
    def __init__(self, channel_attention: torch.Tensor):
        super(StaticFBS, self).__init__()
        assert channel_attention.dim() == 1
        self.channel_attention = nn.Parameter(channel_attention.unsqueeze(0).unsqueeze(2).unsqueeze(3), requires_grad=False)

    def forward(self, x):
        return x * self.channel_attention

    def __str__(self) -> str:
        return f'StaticFBS({self.channel_attention.size(1)})'
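

# CNN-specific implementation of ElasticDNNUtil. convert_raw_dnn_to_master_dnn wraps each
# Conv2d (together with its following BatchNorm2d) in Conv2d_WrappedWithFBS and replaces
# the original BN modules with nn.Identity, so normalization only happens inside the wrapper.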
class ElasticCNNUtil(ElasticDNNUtil):
    def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
        model = deepcopy(raw_dnn)

        # clear original BNs
        num_original_bns = 0
        last_conv_name = None
        conv_bn_map = {}
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                last_conv_name = name
            if isinstance(module, nn.BatchNorm2d) and (ignore_layers is not None and last_conv_name not in ignore_layers):
                num_original_bns += 1
                conv_bn_map[last_conv_name] = name

        num_conv = 0
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d) and (ignore_layers is not None and name not in ignore_layers):
                set_module(model, name, Conv2d_WrappedWithFBS(module, get_module(model, conv_bn_map[name]), r))
                num_conv += 1
        assert num_conv == num_original_bns

        for bn_layer in conv_bn_map.values():
            set_module(model, bn_layer, nn.Identity())

        return model
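
    # Pick one representative sample used to decide which channels the surrogate keeps.
    # Currently this simply takes the first sample of the batch.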
    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        return samples[0].unsqueeze(0)
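
    # Derive a smaller surrogate DNN from the master DNN for the given samples:
    # 1. run the master once so every FBS wrapper caches its channel attention;
    # 2. turn the zero attention entries into NNI pruning masks;
    # 3. replace each wrapper with Sequential(conv, bn, Identity) and let NNI's
    #    ModelSpeedup physically shrink the convolutions;
    # 4. put the surviving attention values back as StaticFBS modules and verify the
    #    surrogate reproduces the master's output on the selected sample.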
    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor):
        sample = self.select_most_rep_sample(master_dnn, samples)
        assert sample.dim() == 4 and sample.size(0) == 1

        master_dnn.eval()
        with torch.no_grad():
            master_dnn_output = master_dnn(sample)
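
        # build 0/1 filter masks from the cached channel attention of each wrapper;
        # the mask key `<layer_name>.0` targets the conv at index 0 of the Sequential
        # that replaces the wrapper below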
        pruning_info = {}
        pruning_masks = {}

        for layer_name, layer in master_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue

            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)

            w = get_module(master_dnn, layer_name).cached_channel_attention.squeeze(0)
            unpruned_filters_index = w.nonzero(as_tuple=True)[0]
            pruning_info[layer_name] = w

            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
            pruning_masks[layer_name + '.0'] = cur_pruning_mask
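
        # replace each FBS wrapper with (conv, bn, Identity placeholder); the Identity at
        # index 2 is swapped for a StaticFBS module after pruning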
        surrogate_dnn = deepcopy(master_dnn)
        for name, layer in surrogate_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue
            set_module(surrogate_dnn, name, nn.Sequential(layer.raw_conv2d, layer.raw_bn, nn.Identity()))

        # fixed_pruning_masks = fix_mask_conflict(pruning_masks, fbs_model, sample.size(), None, True, True, True)
        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)
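
        # NNI ModelSpeedup reads the masks from the temporary file, removes the masked
        # output filters, and shrinks dependent layers accordingly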
        surrogate_dnn.eval()
        model_speedup = ModelSpeedup(surrogate_dnn, sample, tmp_mask_path, sample.device)
        model_speedup.speedup_model()
        os.remove(tmp_mask_path)

        # add feature boosting module
        for layer_name, feature_boosting_w in pruning_info.items():
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            set_module(surrogate_dnn, layer_name + '.2', StaticFBS(feature_boosting_w))
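
        # sanity check: the surrogate should reproduce the master's output on this sample
        # (the master already zeroes the channels that the surrogate removed)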
        surrogate_dnn.eval()
        with torch.no_grad():
            surrogate_dnn_output = surrogate_dnn(sample)
            output_diff = ((surrogate_dnn_output - master_dnn_output) ** 2).sum()
            assert output_diff < 1e-4, output_diff
            logger.info(f'output diff of master and surrogate DNN: {output_diff}')

        return surrogate_dnn