from typing import Optional
import torch
from copy import deepcopy
from torch import nn
from utils.common.others import get_cur_time_str
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, get_module, get_super_module, set_module
from utils.common.log import logger
from utils.third_party.nni_new.compression.pytorch.speedup import ModelSpeedup
import os

from .base import Abs, KTakesAll, Layer_WrappedWithFBS, ElasticDNNUtil


class Conv2d_WrappedWithFBS(Layer_WrappedWithFBS):
    def __init__(self, raw_conv2d: nn.Conv2d, raw_bn: nn.BatchNorm2d, r):
        super(Conv2d_WrappedWithFBS, self).__init__()

        # channel saliency predictor: global average pooling over |x|, followed by a
        # two-layer bottleneck (reduction ratio r) that scores each output channel
        # of the wrapped conv
        self.fbs = nn.Sequential(
            Abs(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(raw_conv2d.in_channels, raw_conv2d.out_channels // r),
            nn.ReLU(),
            nn.Linear(raw_conv2d.out_channels // r, raw_conv2d.out_channels),
            nn.ReLU()
        )

        self.raw_conv2d = raw_conv2d
        self.raw_bn = raw_bn  # remember to clear the original BNs in the network

        # bias the last FC layer towards 1 so that every channel is initially kept
        nn.init.constant_(self.fbs[5].bias, 1.)
        nn.init.kaiming_normal_(self.fbs[5].weight)

    def forward(self, x):
        raw_x = self.raw_bn(self.raw_conv2d(x))

        if self.use_cached_channel_attention and self.cached_channel_attention is not None:
            channel_attention = self.cached_channel_attention
        else:
            self.cached_raw_channel_attention = self.fbs(x)
            # sparsify the channel scores (only the largest k fraction stay non-zero)
            self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)
            channel_attention = self.cached_channel_attention

        return raw_x * channel_attention.unsqueeze(2).unsqueeze(3)


# fixed per-channel scaling baked into the extracted surrogate DNN
class StaticFBS(nn.Module):
    def __init__(self, channel_attention: torch.Tensor):
        super(StaticFBS, self).__init__()
        assert channel_attention.dim() == 1
        self.channel_attention = nn.Parameter(
            channel_attention.unsqueeze(0).unsqueeze(2).unsqueeze(3), requires_grad=False)

    def forward(self, x):
        return x * self.channel_attention

    def __str__(self) -> str:
        return f'StaticFBS({self.channel_attention.size(1)})'


class ElasticCNNUtil(ElasticDNNUtil):
    def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
        model = deepcopy(raw_dnn)

        # find the BN that follows each conv; these original BNs are folded into the
        # FBS wrappers and cleared (replaced by Identity) below
        num_original_bns = 0
        last_conv_name = None
        conv_bn_map = {}
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                last_conv_name = name
            if isinstance(module, nn.BatchNorm2d) and (ignore_layers is not None and last_conv_name not in ignore_layers):
                num_original_bns += 1
                conv_bn_map[last_conv_name] = name

        num_conv = 0
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d) and (ignore_layers is not None and name not in ignore_layers):
                set_module(model, name, Conv2d_WrappedWithFBS(module, get_module(model, conv_bn_map[name]), r))
                num_conv += 1
        assert num_conv == num_original_bns

        for bn_layer in conv_bn_map.values():
            set_module(model, bn_layer, nn.Identity())

        return model

    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        return samples[0].unsqueeze(0)

    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor):
        sample = self.select_most_rep_sample(master_dnn, samples)
        assert sample.dim() == 4 and sample.size(0) == 1

        master_dnn.eval()
        with torch.no_grad():
            master_dnn_output = master_dnn(sample)

        pruning_info = {}   # layer name -> cached channel attention (non-zero entries = kept filters)
        pruning_masks = {}  # masks consumed by ModelSpeedup, keyed by the conv's name in the surrogate DNN

        for layer_name, layer in master_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue

            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)

            # filters with non-zero attention are kept; the rest are pruned away
            w = get_module(master_dnn, layer_name).cached_channel_attention.squeeze(0)
            unpruned_filters_index = w.nonzero(as_tuple=True)[0]

            pruning_info[layer_name] = w

            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.

            # key targets the conv (index 0) inside the Sequential that replaces this wrapper below
            pruning_masks[layer_name + '.0'] = cur_pruning_mask

        # replace each FBS wrapper with (conv, bn, identity); the Identity is a
        # placeholder that later becomes a StaticFBS module
        surrogate_dnn = deepcopy(master_dnn)
        for name, layer in surrogate_dnn.named_modules():
            if not isinstance(layer, Conv2d_WrappedWithFBS):
                continue
            set_module(surrogate_dnn, name, nn.Sequential(layer.raw_conv2d, layer.raw_bn, nn.Identity()))

        # fixed_pruning_masks = fix_mask_conflict(pruning_masks, fbs_model, sample.size(), None, True, True, True)
        # physically remove the zeroed filters via ModelSpeedup using the masks saved above
        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)

        surrogate_dnn.eval()
        model_speedup = ModelSpeedup(surrogate_dnn, sample, tmp_mask_path, sample.device)
        model_speedup.speedup_model()
        os.remove(tmp_mask_path)

        # add feature boosting module: rescale the kept channels by their attention weights
        for layer_name, feature_boosting_w in pruning_info.items():
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            set_module(surrogate_dnn, layer_name + '.2', StaticFBS(feature_boosting_w))

        # sanity check: the surrogate must reproduce the master DNN's output on the sample
        surrogate_dnn.eval()
        with torch.no_grad():
            surrogate_dnn_output = surrogate_dnn(sample)
        output_diff = ((surrogate_dnn_output - master_dnn_output) ** 2).sum()
        assert output_diff < 1e-4, output_diff
        logger.info(f'output diff of master and surrogate DNN: {output_diff}')

        return surrogate_dnn
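

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The tiny
# CNN, the reduction ratio r=4, the input size, and the no-argument
# ElasticCNNUtil() construction below are assumptions; in a real run the master
# DNN would also be trained so the FBS modules learn useful channel saliencies
# before a surrogate is extracted.
#
# util = ElasticCNNUtil()
# raw_dnn = nn.Sequential(
#     nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
#     nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
#     nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, 10)
# )
# master_dnn = util.convert_raw_dnn_to_master_dnn(raw_dnn, r=4)
# # ... train / fine-tune master_dnn here ...
# samples = torch.rand(8, 3, 32, 32)
# surrogate_dnn = util.extract_surrogate_dnn_via_samples(master_dnn, samples)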