LINC-BIT's picture
Upload 1912 files
b84549f verified
import torch
from torch import nn
from copy import deepcopy
from .base import FM_to_MD_Util
from utils.common.log import logger
from utils.dl.common.model import set_module, get_module, get_super_module
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size
from utils.common.log import logger
from transformers.models.bert.modeling_bert import BertSelfAttention
from transformers import BertConfig
from typing import Optional, Tuple
import math
class BertSelfAttentionPrunable(BertSelfAttention):
    """Self-attention layer that tolerates query/key/value projections of reduced width.

    Two places avoid depending on the configured hidden size, so the
    ``query``/``key``/``value`` linear layers can be swapped for narrower ones
    after construction:

    * ``transpose_for_scores`` infers the per-head width instead of reading
      ``self.attention_head_size``;
    * ``forward`` merges the heads back to ``self.query.out_features``.

    Note that the score scaling in ``forward`` still uses
    ``self.attention_head_size`` as stored on the instance.
    """

    # Non-module attributes read by ``forward`` / ``transpose_for_scores`` that
    # must follow the source layer rather than the default config.
    _INHERITED_ATTRS = (
        'num_attention_heads',
        'attention_head_size',
        'all_head_size',
        'position_embedding_type',
        'max_position_embeddings',
        'is_decoder',
    )

    def __init__(self, config: Optional[BertConfig] = None):
        """Build the layer from ``config``.

        Args:
            config: configuration passed to the parent constructor. When
                omitted, the ``bert-base-multilingual-cased`` configuration is
                loaded via ``BertConfig.from_pretrained`` (the historical
                behaviour of this class).
        """
        if config is None:
            config = BertConfig.from_pretrained('bert-base-multilingual-cased')
        super(BertSelfAttentionPrunable, self).__init__(config)

    def transpose_for_scores(self, x):
        """Split the last dimension of a 3-D tensor into heads.

        The last dimension is viewed as ``(num_attention_heads, -1)`` so the
        per-head width is inferred from the tensor itself; it must therefore be
        divisible by ``self.num_attention_heads``.

        Returns:
            ``x`` reshaped to 4-D and permuted with ``(0, 2, 1, 3)``, i.e. the
            head axis is moved in front of the second input axis.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Compute scaled dot-product attention over ``hidden_states``.

        Args:
            hidden_states: input projected by ``self.query`` (and by
                ``self.key``/``self.value`` unless cross-attention is used).
            attention_mask: optional additive mask added to the raw scores.
            head_mask: optional multiplicative mask applied to the attention
                probabilities.
            encoder_hidden_states: when given, keys and values are projected
                from it (cross-attention) and ``encoder_attention_mask``
                replaces ``attention_mask``.
            encoder_attention_mask: mask used in the cross-attention case.
            past_key_value: optional cached ``(key, value)`` pair; reused as-is
                for cross-attention, otherwise concatenated in front of the
                newly projected keys/values along dim 2.
            output_attentions: when true, the attention probabilities are
                appended to the result.

        Returns:
            ``(context_layer,)``, followed by ``attention_probs`` if
            ``output_attentions`` is set, followed by the ``(key, value)`` pair
            if ``self.is_decoder`` is set.
        """
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Must be evaluated before `past_key_value` is overwritten below.
        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the (additive) attention mask supplied by the caller.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # NOTE: modified -- merge heads to the current projection width
        # instead of a width fixed at construction time.
        new_context_layer_shape = context_layer.size()[:-2] + (self.query.out_features,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

    @staticmethod
    def init_from_exist_self_attn(attn: BertSelfAttention):
        """Create a prunable layer that reuses the sub-modules of ``attn``.

        Every child module of ``attn`` is attached to the new layer by
        reference (no copy), so both layers share parameters. The scalar
        attributes listed in ``_INHERITED_ATTRS`` are taken from ``attn`` as
        well; otherwise the new layer would keep the values of the default
        configuration used to construct it, which need not match ``attn``.

        Args:
            attn: the existing self-attention layer to mirror.

        Returns:
            A ``BertSelfAttentionPrunable`` equivalent to ``attn``.
        """
        res = BertSelfAttentionPrunable()

        # Share the existing sub-modules; this also registers children that
        # the default-config instance does not have.
        for name, child in attn.named_children():
            setattr(res, name, child)

        # Carry over the hyper-parameters that drive the forward pass.
        for name in BertSelfAttentionPrunable._INHERITED_ATTRS:
            if hasattr(attn, name):
                setattr(res, name, getattr(attn, name))

        return res
class FM_to_MD_Bert_Util(FM_to_MD_Util):
    """Build a narrower "master DNN" from a BERT-style foundation model.

    The model is expected to expose its transformer blocks as
    ``fm.bert.encoder.layer``; each block's ``attention.self.{query,key,value}``,
    ``attention.output.dense``, ``intermediate.dense`` and ``output.dense``
    linear layers are replaced by narrower ones.
    """

    def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int) -> nn.Module:
        """Return a deep copy of ``fm`` whose inner widths are divided by ``reducing_width_ratio``.

        For every block, the output features of query/key/value and of
        ``intermediate.dense`` and the input features of
        ``attention.output.dense`` and ``output.dense`` are reduced to
        ``n // reducing_width_ratio``, keeping the channels with the largest
        L1 norm (in their original order).

        Args:
            fm: the foundation model; it is not modified.
            reducing_width_ratio: positive divisor applied to the widths.

        Returns:
            The narrowed copy of ``fm``.

        Raises:
            ValueError: if ``reducing_width_ratio`` is not positive.
        """
        if reducing_width_ratio <= 0:
            raise ValueError(f'reducing_width_ratio must be positive, got {reducing_width_ratio}')

        md = deepcopy(fm)

        # Swap in the attention implementation that tolerates narrower projections.
        for block in md.bert.encoder.layer:
            set_module(block, 'attention.self', BertSelfAttentionPrunable.init_from_exist_self_attn(block.attention.self))

        def _f(n):
            return int(n // reducing_width_ratio)

        def l1_max_indexes(p: torch.Tensor, dim=0):
            """Sorted indexes of the ``n // ratio`` slices of ``p`` along ``dim`` with largest L1 norm."""
            assert dim in [0, 1]
            assert p.dim() in [1, 2, 4]

            if dim == 1:
                p = p.T

            p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1)
            n = p.size(0)
            return p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio)].sort()[0]

        def _shrink_out_features(linear: nn.Linear) -> nn.Linear:
            """New linear layer keeping the strongest output rows (and matching bias entries)."""
            shrunk = nn.Linear(linear.in_features, _f(linear.out_features),
                               bias=linear.bias is not None,
                               device=linear.weight.device, dtype=linear.weight.dtype)
            kept = l1_max_indexes(linear.weight.data, 0)
            shrunk.weight.data.copy_(linear.weight.data[kept])
            if linear.bias is not None:
                shrunk.bias.data.copy_(linear.bias.data[kept])
            return shrunk

        def _shrink_in_features(linear: nn.Linear) -> nn.Linear:
            """New linear layer keeping the strongest input columns; the bias is copied unchanged."""
            shrunk = nn.Linear(_f(linear.in_features), linear.out_features,
                               bias=linear.bias is not None,
                               device=linear.weight.device, dtype=linear.weight.dtype)
            shrunk.weight.data.copy_(linear.weight.data[:, l1_max_indexes(linear.weight.data, 1)])
            if linear.bias is not None:
                shrunk.bias.data.copy_(linear.bias.data)
            return shrunk

        # NOTE(review): channels are ranked independently for each layer, so the
        # rows kept in one layer are not guaranteed to correspond to the columns
        # kept in the layer that consumes its output -- presumably acceptable
        # because the result is only an initialization; confirm.
        # The reduced query/key/value width must stay divisible by the number of
        # attention heads, since the attention layer views it as (heads, -1).
        for block in md.bert.encoder.layer:
            for k in ['query', 'key', 'value']:
                name = f'attention.self.{k}'
                set_module(block, name, _shrink_out_features(get_module(block, name)))

            set_module(block, 'attention.output.dense',
                       _shrink_in_features(get_module(block, 'attention.output.dense')))
            set_module(block, 'intermediate.dense',
                       _shrink_out_features(get_module(block, 'intermediate.dense')))
            set_module(block, 'output.dense',
                       _shrink_in_features(get_module(block, 'output.dense')))

        return md

    def init_md_from_fm_by_reducing_width_with_perf_test(self, fm: nn.Module, reducing_width_ratio: int,
                                                         samples: torch.Tensor) -> nn.Module:
        """Narrow ``fm`` and log size/latency of the original and the narrowed model.

        Args:
            fm: the foundation model.
            reducing_width_ratio: divisor forwarded to
                ``init_md_from_fm_by_reducing_width``.
            samples: model input forwarded to ``_get_model_latency`` for both
                measurements.

        Returns:
            The narrowed model.
        """
        fm_size = get_model_size(fm, True)
        fm_latency = self._get_model_latency(fm, samples, 20,
                                             get_model_device(fm), 20, False)

        master_dnn = self.init_md_from_fm_by_reducing_width(fm, reducing_width_ratio)
        master_dnn_size = get_model_size(master_dnn, True)
        logger.debug(f'inited master DNN: {master_dnn}')
        master_dnn_latency = self._get_model_latency(master_dnn, samples, 20,
                                                     get_model_device(master_dnn), 20, False)

        logger.info(f'init master DNN (w/o FBS yet) by reducing foundation model\'s width (by {reducing_width_ratio:d}x)')
        logger.info(f'foundation model ({fm_size:.3f}MB, {fm_latency:.4f}s/sample) -> '
                    f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample)\n'
                    f'(model size: ↓ {(fm_size / master_dnn_size):.2f}x, '
                    f'latency: ↓ {(fm_latency / master_dnn_latency):.2f}x)')

        return master_dnn

    def _get_model_latency(self, model: torch.nn.Module, model_input_size, sample_num: int,
                           device: str, warmup_sample_num: int, return_detail=False):
        """Measure the average forward latency of ``model`` in seconds.

        Args:
            model: the model to time; it is moved to ``device`` and put in
                eval mode.
            model_input_size: either a tuple shape, for which a random tensor
                is created on ``device`` and passed positionally, or a
                ready-made input. A tensor input is passed positionally; any
                other input is unpacked as keyword arguments.
            sample_num: number of timed forward passes (must be positive).
            device: target device; CUDA events are used for timing when its
                string form contains ``'cuda'``.
            warmup_sample_num: number of untimed forward passes run first.
            return_detail: when true, also return the per-pass timings.

        Returns:
            The mean latency, or ``(mean, list_of_latencies)`` when
            ``return_detail`` is set.

        Raises:
            ValueError: if ``sample_num`` is not positive.
        """
        import time

        if sample_num <= 0:
            raise ValueError(f'sample_num must be positive, got {sample_num}')

        if isinstance(model_input_size, tuple):
            dummy_input = torch.rand(model_input_size).to(device)
        else:
            dummy_input = model_input_size

        def _run():
            # A tensor cannot be **-unpacked; everything else keeps the
            # keyword-argument calling convention.
            if isinstance(dummy_input, torch.Tensor):
                return model(dummy_input)
            return model(**dummy_input)

        model = model.to(device)
        model.eval()

        # warm up
        with torch.no_grad():
            for _ in range(warmup_sample_num):
                _run()

        infer_time_list = []

        if 'cuda' in str(device):
            with torch.no_grad():
                for _ in range(sample_num):
                    s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
                    s.record()
                    _run()
                    e.record()
                    # Both events must have completed before reading the elapsed time.
                    torch.cuda.synchronize()
                    # elapsed_time() is divided by 1000 to report seconds.
                    infer_time_list.append(s.elapsed_time(e) / 1000.)
        else:
            with torch.no_grad():
                for _ in range(sample_num):
                    # perf_counter is monotonic, unlike wall-clock time.time().
                    start = time.perf_counter()
                    _run()
                    infer_time_list.append(time.perf_counter() - start)

        avg_infer_time = sum(infer_time_list) / sample_num

        if return_detail:
            return avg_infer_time, infer_time_list
        return avg_infer_time