from copy import deepcopy
from typing import Optional, Union

import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import tqdm

from utils.dl.common.model import LayerActivation, get_model_device, get_model_size, set_module
from .base import Abs, KTakesAll, ElasticDNNUtil, Layer_WrappedWithFBS
from utils.common.log import logger


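# Elastic ViT utilities built on FBS-style dynamic channel gating: each wrapped layer carries a
# small "fbs" side branch that predicts a non-negative per-channel attention, KTakesAll (from
# .base) presumably keeps only the top-k scores and zeroes the rest, and the cached attention is
# later used by ElasticViTUtil below to physically prune the network into a smaller surrogate DNN.

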
class SqueezeLast(nn.Module):
    # squeeze the last (singleton) dimension, e.g. (B, C, 1) -> (B, C)
    def __init__(self):
        super(SqueezeLast, self).__init__()

    def forward(self, x):
        return x.squeeze(-1)


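# Wraps the ViT patch-embedding conv with an FBS gate: the side branch pools the input,
# predicts one score per output channel, k_takes_all zeroes all but the strongest scores,
# and the conv output is scaled channel-wise by the resulting attention. The last attention
# is cached so that the surrogate-extraction step below can prune by it.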
class ProjConv_WrappedWithFBS(Layer_WrappedWithFBS):
    def __init__(self, raw_conv2d: nn.Conv2d, r):
        super(ProjConv_WrappedWithFBS, self).__init__()

        # channel saliency predictor: global context of the input -> per-output-channel gate
        self.fbs = nn.Sequential(
            Abs(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(raw_conv2d.in_channels, raw_conv2d.out_channels // r),
            nn.ReLU(),
            nn.Linear(raw_conv2d.out_channels // r, raw_conv2d.out_channels),
            nn.ReLU()
        )

        self.raw_conv2d = raw_conv2d

        # initialize the last gate layer so that all channels start close to fully on
        nn.init.constant_(self.fbs[5].bias, 1.)
        nn.init.kaiming_normal_(self.fbs[5].weight)

    def forward(self, x):
        raw_x = self.raw_conv2d(x)

        if self.use_cached_channel_attention and self.cached_channel_attention is not None:
            channel_attention = self.cached_channel_attention
        else:
            self.cached_raw_channel_attention = self.fbs(x)
            self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)

            channel_attention = self.cached_channel_attention

        # gate the conv output channel-wise
        return raw_x * channel_attention.unsqueeze(2).unsqueeze(3)


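# Same idea for a linear layer operating on token sequences of shape (batch, tokens, features):
# the side branch averages over the token dimension, predicts one gate per output feature, and
# the gated (top-k) attention scales the linear layer's output feature-wise.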
class Linear_WrappedWithFBS(Layer_WrappedWithFBS):
    def __init__(self, linear: nn.Linear, r):
        super(Linear_WrappedWithFBS, self).__init__()

        self.linear = linear

        # feature saliency predictor: pool over tokens, then predict a gate per output feature
        self.fbs = nn.Sequential(
            Rearrange('b n d -> b d n'),
            Abs(),
            nn.AdaptiveAvgPool1d(1),
            SqueezeLast(),
            nn.Linear(linear.in_features, linear.out_features // r),
            nn.ReLU(),
            nn.Linear(linear.out_features // r, linear.out_features),
            nn.ReLU()
        )

        # initialize the last gate layer so that all features start close to fully on
        nn.init.constant_(self.fbs[6].bias, 1.)
        nn.init.kaiming_normal_(self.fbs[6].weight)

    def forward(self, x):
        if self.use_cached_channel_attention and self.cached_channel_attention is not None:
            channel_attention = self.cached_channel_attention
        else:
            self.cached_raw_channel_attention = self.fbs(x)
            self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)

            channel_attention = self.cached_channel_attention

        raw_res = self.linear(x)

        return channel_attention.unsqueeze(1) * raw_res


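# Once the channel attention has been frozen, the dynamic gate can be replaced by this static
# multiplier so the surrogate DNN no longer carries the fbs side branch.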
class LinearStaticFBS(nn.Module):
    def __init__(self, static_channel_attention):
        super(LinearStaticFBS, self).__init__()
        assert static_channel_attention.dim() == 2 and static_channel_attention.size(0) == 1
        self.static_channel_attention = nn.Parameter(static_channel_attention, requires_grad=False)  # (1, c)

    def forward(self, x):
        return x * self.static_channel_attention.unsqueeze(1)


from .cnn import StaticFBS as ConvStaticFBS


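# ElasticViTUtil turns a ViT exposing `patch_embed.proj`, `blocks[i].attn.{qkv,proj}`,
# `blocks[i].mlp.{fc1,fc2}`, `norm` and `head` (the attribute names match e.g. timm's
# VisionTransformer) into a "master" DNN whose patch embedding and MLP fc1 layers are
# FBS-gated, and later extracts a physically smaller "surrogate" DNN by removing the
# channels whose cached attention is zero.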
class ElasticViTUtil(ElasticDNNUtil):
    def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
        assert len(ignore_layers) == 0, 'not supported yet'

        raw_vit = deepcopy(raw_dnn)

        # wrap the patch embedding conv and every MLP fc1 with an FBS side branch
        set_module(raw_vit, 'patch_embed.proj', ProjConv_WrappedWithFBS(raw_vit.patch_embed.proj, r))

        for name, module in raw_vit.named_modules():
            if name.endswith('mlp'):
                set_module(module, 'fc1', Linear_WrappedWithFBS(module.fc1, r))

        return raw_vit

    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        # NOTE: simply takes the first sample as the "most representative" one
        return samples[0].unsqueeze(0)

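    # Extraction pipeline: run the master DNN once on one representative sample to cache each
    # gate's channel attention, then rebuild the network with the zero-attention channels
    # removed (patch-embed conv -> pos_embed/cls_token -> every block -> final norm/head ->
    # MLP hidden dims), and finally check that the surrogate reproduces the master's output.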
    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor):
        sample = self.select_most_rep_sample(master_dnn, samples)
        assert sample.dim() == 4 and sample.size(0) == 1

        print('WARN: for debug, pos_embed is zeroed out')
        master_dnn.pos_embed.data = torch.zeros_like(master_dnn.pos_embed.data)

        print('before')
        master_dnn.eval()
        self.clear_cached_channel_attention_in_master_dnn(master_dnn)

        # run the master DNN once to cache the channel attention of each FBS-wrapped layer
        hooks = {
            'blocks_input': LayerActivation(master_dnn.blocks, True, 'cuda')
        }

        with torch.no_grad():
            master_dnn_output = master_dnn(sample)

        for k, v in hooks.items():
            print(f'{k}: {v.input.size()}')

        print('after')

        boosted_vit = master_dnn

        def get_unpruned_indexes_from_channel_attn(channel_attn: torch.Tensor, k):
            assert channel_attn.size(0) == 1, 'use a single representative sample to generate channel attention'
            res = channel_attn[0].nonzero(as_tuple=True)[0]
            return res

        # 1. prune the output channels of the patch embedding conv and freeze its channel attention
        proj = boosted_vit.patch_embed.proj
        proj_unpruned_indexes = get_unpruned_indexes_from_channel_attn(
            proj.cached_channel_attention, proj.k_takes_all.k)

        proj_conv = proj.raw_conv2d
        new_proj = nn.Conv2d(proj_conv.in_channels, proj_unpruned_indexes.size(0), proj_conv.kernel_size, proj_conv.stride, proj_conv.padding,
                             proj_conv.dilation, proj_conv.groups, proj_conv.bias is not None, proj_conv.padding_mode, proj_conv.weight.device)
        new_proj.weight.data.copy_(proj_conv.weight.data[proj_unpruned_indexes])
        if new_proj.bias is not None:
            new_proj.bias.data.copy_(proj_conv.bias.data[proj_unpruned_indexes])
        set_module(boosted_vit.patch_embed, 'proj', nn.Sequential(new_proj, ConvStaticFBS(proj.cached_channel_attention[0][proj_unpruned_indexes])))

        # 2. the embedding dim is now reduced, so slice pos_embed and cls_token accordingly
        boosted_vit.pos_embed.data = boosted_vit.pos_embed.data[:, :, proj_unpruned_indexes]
        boosted_vit.cls_token.data = boosted_vit.cls_token.data[:, :, proj_unpruned_indexes]

        # helpers that rebuild a linear layer / layer norm with a reduced output or input dimension
        def reduce_linear_output(raw_linear: nn.Linear, layer_name, unpruned_indexes: torch.Tensor):
            new_linear = nn.Linear(raw_linear.in_features, unpruned_indexes.size(0), raw_linear.bias is not None)
            new_linear.weight.data.copy_(raw_linear.weight.data[unpruned_indexes])
            if raw_linear.bias is not None:
                new_linear.bias.data.copy_(raw_linear.bias.data[unpruned_indexes])
            set_module(boosted_vit, layer_name, new_linear)

        def reduce_linear_input(raw_linear: nn.Linear, layer_name, unpruned_indexes: torch.Tensor):
            new_linear = nn.Linear(unpruned_indexes.size(0), raw_linear.out_features, raw_linear.bias is not None)
            new_linear.weight.data.copy_(raw_linear.weight.data[:, unpruned_indexes])
            if raw_linear.bias is not None:
                new_linear.bias.data.copy_(raw_linear.bias.data)
            set_module(boosted_vit, layer_name, new_linear)

        def reduce_norm(raw_norm: nn.LayerNorm, layer_name, unpruned_indexes: torch.Tensor):
            new_norm = nn.LayerNorm(unpruned_indexes.size(0), raw_norm.eps, raw_norm.elementwise_affine)
            new_norm.weight.data.copy_(raw_norm.weight.data[unpruned_indexes])
            new_norm.bias.data.copy_(raw_norm.bias.data[unpruned_indexes])
            set_module(boosted_vit, layer_name, new_norm)

        # 3. propagate the reduced embedding dim through every transformer block, then the final norm and head
        for block_i, block in enumerate(boosted_vit.blocks):
            attn = block.attn
            ff = block.mlp

            reduce_norm(block.norm1, f'blocks.{block_i}.norm1', proj_unpruned_indexes)
            reduce_linear_input(attn.qkv, f'blocks.{block_i}.attn.qkv', proj_unpruned_indexes)
            reduce_linear_output(attn.proj, f'blocks.{block_i}.attn.proj', proj_unpruned_indexes)
            reduce_norm(block.norm2, f'blocks.{block_i}.norm2', proj_unpruned_indexes)
            reduce_linear_input(ff.fc1.linear, f'blocks.{block_i}.mlp.fc1.linear', proj_unpruned_indexes)
            reduce_linear_output(ff.fc2, f'blocks.{block_i}.mlp.fc2', proj_unpruned_indexes)

        reduce_norm(boosted_vit.norm, 'norm', proj_unpruned_indexes)
        reduce_linear_input(boosted_vit.head, 'head', proj_unpruned_indexes)

        # 4. prune the hidden dim of each MLP: shrink fc1's output, freeze its channel attention, and shrink fc2's input
        for block_i, block in enumerate(boosted_vit.blocks):
            attn = block.attn
            ff = block.mlp

            fc1 = ff.fc1
            fc1_unpruned_indexes = get_unpruned_indexes_from_channel_attn(fc1.cached_channel_attention, fc1.k_takes_all.k)
            fc1_linear = fc1.linear
            new_linear = nn.Linear(fc1_linear.in_features, fc1_unpruned_indexes.size(0), fc1_linear.bias is not None)
            new_linear.weight.data.copy_(fc1_linear.weight.data[fc1_unpruned_indexes])
            if fc1_linear.bias is not None:
                new_linear.bias.data.copy_(fc1_linear.bias.data[fc1_unpruned_indexes])
            set_module(boosted_vit, f'blocks.{block_i}.mlp.fc1', nn.Sequential(new_linear, LinearStaticFBS(fc1.cached_channel_attention[:, fc1_unpruned_indexes])))

            reduce_linear_input(ff.fc2, f'blocks.{block_i}.mlp.fc2', fc1_unpruned_indexes)

        # 5. verify that the surrogate DNN reproduces the master DNN's output on the representative sample
        surrogate_dnn = boosted_vit
        surrogate_dnn.eval()
        surrogate_dnn = surrogate_dnn.to(get_model_device(master_dnn))
        print(surrogate_dnn)

        hooks = {
            'blocks_input': LayerActivation(surrogate_dnn.blocks, True, 'cuda')
        }

        with torch.no_grad():
            surrogate_dnn_output = surrogate_dnn(sample)

        for k, v in hooks.items():
            print(f'{k}: {v.input.size()}')

        output_diff = ((surrogate_dnn_output - master_dnn_output) ** 2).sum()
        assert output_diff < 1e-4, output_diff
        logger.info(f'output diff of master and surrogate DNN: {output_diff}')

        return boosted_vit


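# --- Usage sketch (illustrative only, not part of the module) ---
# Assumptions: `vit` is a ViT whose attribute names match those used above (e.g. timm's
# VisionTransformer), `samples` is a (N, 3, H, W) float tensor on the model's device, and the
# keep-ratio of each KTakesAll gate has been configured through the base-class API (not shown
# in this file) so that some channel attentions are exactly zero before extraction.
#
#   util = ElasticViTUtil()
#   master_vit = util.convert_raw_dnn_to_master_dnn(vit, r=16)  # r: reduction ratio of the fbs branch
#   surrogate_vit = util.extract_surrogate_dnn_via_samples(master_vit, samples)
#   logits = surrogate_vit(samples)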