import torch
from torch import nn
from copy import deepcopy

from .base import FM_to_MD_Util
from utils.common.log import logger
from utils.dl.common.model import set_module, get_module, get_super_module
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size

from transformers.models.vilt.modeling_vilt import ViltSelfAttention
from transformers import ViltConfig
from typing import Optional, Tuple
import math


class ViltSelfAttentionPrunable(ViltSelfAttention):
    """``ViltSelfAttention`` whose forward pass tolerates pruned Q/K/V widths.

    The per-head width and the flattened context width are inferred with
    ``-1`` when reshaping instead of being taken from fixed attributes, so
    the ``query``/``key``/``value`` projections may be replaced by narrower
    ``nn.Linear`` layers after construction.
    """

    # Checkpoint whose config is used when the caller does not supply one.
    DEFAULT_PRETRAINED_NAME = 'dandelin/vilt-b32-mlm-itm'

    def __init__(self, config: Optional[ViltConfig] = None):
        """Build the attention module.

        Args:
            config: Configuration passed to ``ViltSelfAttention.__init__``.
                When ``None`` (the default), the config of
                ``DEFAULT_PRETRAINED_NAME`` is loaded via
                ``ViltConfig.from_pretrained``.
        """
        if config is None:
            config = ViltConfig.from_pretrained(self.DEFAULT_PRETRAINED_NAME)
        super().__init__(config)

    def transpose_for_scores(self, x):
        """Split the last dim of ``x`` into heads and move heads to dim 1.

        The last dimension is reshaped to ``(num_attention_heads, -1)`` so the
        per-head width follows whatever width ``x`` actually has; the result
        is then permuted with ``(0, 2, 1, 3)``.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        """Compute scaled dot-product self-attention over ``hidden_states``.

        Args:
            hidden_states: Input fed to the query/key/value projections.
            attention_mask: Optional additive mask added to the raw scores.
            head_mask: Optional multiplicative mask applied to the
                attention probabilities.
            output_attentions: When true, the attention probabilities are
                returned as a second tuple element.

        Returns:
            ``(context_layer,)`` or ``(context_layer, attention_probs)``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Raw attention scores: dot product between "query" and "key".
        # NOTE: the scale uses the ``attention_head_size`` attribute, not the
        # per-head width inferred in ``transpose_for_scores``.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive, so it is applied before the softmax.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # Merge the head dims with -1 (instead of ``self.all_head_size``) so a
        # pruned value projection still reshapes correctly.
        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return (context_layer, attention_probs) if output_attentions else (context_layer,)

    @staticmethod
    def init_from_exist_self_attn(attn: ViltSelfAttention, config: Optional[ViltConfig] = None):
        """Create a prunable attention that reuses the sub-modules of ``attn``.

        Args:
            attn: Existing attention module whose direct child modules are
                re-attached (not copied) to the new instance.
            config: Optional config forwarded to the constructor; ``None``
                keeps the default checkpoint config.

        Returns:
            A ``ViltSelfAttentionPrunable`` sharing ``attn``'s child modules.
        """
        res = ViltSelfAttentionPrunable(config)

        # Carry over the head geometry read by ``forward`` so the new module
        # matches ``attn`` even when ``attn`` was not built from the default
        # config. This is a no-op when both already agree.
        for name in ('num_attention_heads', 'attention_head_size', 'all_head_size'):
            if hasattr(attn, name):
                setattr(res, name, getattr(attn, name))

        # Re-attach every direct child module under the same attribute name.
        for name, child in attn.named_children():
            setattr(res, name, child)

        return res


class FM_to_MD_Vilt_Util(FM_to_MD_Util):
    """Derives a narrower model from a ViLT-style model by pruning Linear layers."""

    def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int) -> nn.Module:
        """Return a deep copy of ``fm`` with narrower attention and MLP layers.

        For every block in ``fm.vilt.encoder.layer`` the query/key/value and
        ``intermediate.dense`` layers keep ``out_features // ratio`` output
        rows, and ``attention.output.dense`` / ``output.dense`` keep
        ``in_features // ratio`` input columns. The kept rows/columns are the
        ones with the largest L1 norm, in their original order.

        Args:
            fm: Source model; it is not modified (a deep copy is pruned).
            reducing_width_ratio: Divisor applied to the pruned widths.

        Returns:
            The pruned copy of ``fm``.
        """
        fm_vit = deepcopy(fm)

        # Load the default attention config once rather than once per block.
        attn_config = ViltConfig.from_pretrained(ViltSelfAttentionPrunable.DEFAULT_PRETRAINED_NAME)
        for block in fm_vit.vilt.encoder.layer:
            set_module(
                block,
                'attention.attention',
                ViltSelfAttentionPrunable.init_from_exist_self_attn(block.attention.attention, attn_config),
            )

        def _f(n):
            # Width that remains after the reduction.
            return int(n // reducing_width_ratio)

        def l1_max_indexes(p: torch.Tensor, dim=0):
            # Sorted indexes of the ``_f(n)`` slices along ``dim`` with the
            # largest L1 norm.
            assert dim in [0, 1]
            assert p.dim() in [1, 2, 4]

            if dim == 1:
                p = p.T

            p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1)
            n = p.size(0)
            return p_norm.argsort(descending=True)[0: _f(n)].sort()[0]

        def _prune_outputs(linear: nn.Linear) -> nn.Linear:
            # New Linear keeping only the highest-L1 output rows (and biases).
            keep = l1_max_indexes(linear.weight.data, 0)
            pruned = nn.Linear(linear.in_features, _f(linear.out_features),
                               linear.bias is not None, linear.weight.device)
            pruned.weight.data.copy_(linear.weight.data[keep])
            if linear.bias is not None:
                pruned.bias.data.copy_(linear.bias.data[keep])
            return pruned

        def _prune_inputs(linear: nn.Linear) -> nn.Linear:
            # New Linear keeping only the highest-L1 input columns; the bias
            # is per-output and therefore copied unchanged.
            keep = l1_max_indexes(linear.weight.data, 1)
            pruned = nn.Linear(_f(linear.in_features), linear.out_features,
                               linear.bias is not None, linear.weight.device)
            pruned.weight.data.copy_(linear.weight.data[:, keep])
            if linear.bias is not None:
                pruned.bias.data.copy_(linear.bias.data)
            return pruned

        # NOTE(review): input columns of the consumer layers are ranked on
        # their own weights, independently of which output rows the producer
        # layers kept -- confirm this is the intended pairing.
        for block in fm_vit.vilt.encoder.layer:
            for k in ['query', 'key', 'value']:
                path = f'attention.attention.{k}'
                set_module(block, path, _prune_outputs(get_module(block, path)))

            set_module(block, 'attention.output.dense',
                       _prune_inputs(get_module(block, 'attention.output.dense')))
            set_module(block, 'intermediate.dense',
                       _prune_outputs(get_module(block, 'intermediate.dense')))
            set_module(block, 'output.dense',
                       _prune_inputs(get_module(block, 'output.dense')))

        return fm_vit

    def init_md_from_fm_by_reducing_width_with_perf_test(self, fm: nn.Module, reducing_width_ratio: int,
                                                         samples: torch.Tensor) -> nn.Module:
        """Prune ``fm`` and log size/latency of the model before and after.

        Args:
            fm: Source model.
            reducing_width_ratio: Divisor forwarded to
                ``init_md_from_fm_by_reducing_width``.
            samples: Model input forwarded to ``_get_model_latency``; either
                a tensor or a mapping of keyword arguments for the model.

        Returns:
            The pruned model.
        """
        fm_size = get_model_size(fm, True)
        fm_latency = self._get_model_latency(fm, samples, 20,
                                             get_model_device(fm), 20, False)

        master_dnn = self.init_md_from_fm_by_reducing_width(fm, reducing_width_ratio)
        master_dnn_size = get_model_size(master_dnn, True)
        logger.debug(f'inited master DNN: {master_dnn}')
        master_dnn_latency = self._get_model_latency(master_dnn, samples, 20,
                                                     get_model_device(master_dnn), 20, False)

        logger.info(f'init master DNN (w/o FBS yet) by reducing foundation model\'s width (by {reducing_width_ratio:d}x)')
        logger.info(f'foundation model ({fm_size:.3f}MB, {fm_latency:.4f}s/sample) -> '
                    f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample)\n'
                    f'(model size: ↓ {(fm_size / master_dnn_size):.2f}x, '
                    f'latency: ↓ {(fm_latency / master_dnn_latency):.2f}x)')

        return master_dnn

    def _get_model_latency(self, model: torch.nn.Module, model_input_size, sample_num: int,
                           device: str, warmup_sample_num: int, return_detail=False):
        """Measure the average forward latency of ``model`` in seconds.

        Args:
            model: Model to time; it is moved to ``device`` and put in eval mode.
            model_input_size: Either a tuple (a random tensor of that shape is
                created on ``device``) or a ready-made input. Tensor inputs
                are passed positionally; any other input is unpacked as
                keyword arguments.
            sample_num: Number of timed forward passes.
            device: Target device; CUDA events are used for timing when its
                string form contains ``'cuda'``, wall-clock time otherwise.
            warmup_sample_num: Number of untimed forward passes run first.
            return_detail: When true, also return the per-pass timings.

        Returns:
            The average time per pass, or ``(average, timings)`` when
            ``return_detail`` is true.
        """
        import time

        if isinstance(model_input_size, tuple):
            dummy_input = torch.rand(model_input_size).to(device)
        else:
            dummy_input = model_input_size

        model = model.to(device)
        model.eval()

        def _forward():
            # A tensor cannot be unpacked with ``**``; pass it positionally.
            if isinstance(dummy_input, torch.Tensor):
                return model(dummy_input)
            return model(**dummy_input)

        infer_time_list = []
        with torch.no_grad():
            # Warm up.
            for _ in range(warmup_sample_num):
                _forward()

            if 'cuda' in str(device):
                for _ in range(sample_num):
                    s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
                    s.record()
                    _forward()
                    e.record()
                    # Wait for the recorded events before reading the timer.
                    torch.cuda.synchronize()
                    # The event timer reports milliseconds; convert to seconds.
                    infer_time_list.append(s.elapsed_time(e) / 1000.)
            else:
                for _ in range(sample_num):
                    start = time.time()
                    _forward()
                    infer_time_list.append(time.time() - start)

        avg_infer_time = sum(infer_time_list) / sample_num

        if return_detail:
            return avg_infer_time, infer_time_list
        return avg_infer_time