import torch
from torch import nn
from copy import deepcopy

from .base import FM_to_MD_Util
from utils.common.log import logger
from utils.dl.common.model import set_module, get_module, get_super_module
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size

from transformers.models.bert.modeling_bert import BertSelfAttention
from transformers import BertConfig

from typing import Optional, Tuple
import math


class BertSelfAttentionPrunable(BertSelfAttention):
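    """A BertSelfAttention whose query/key/value projections can be narrowed.

    The per-head size and the output width are derived from the (possibly pruned)
    linear layers instead of the static config values.
    """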

    def __init__(self):
        # only the default hyperparameters (number of heads, dropout, ...) are taken from
        # the config; the actual weights are copied in by init_from_exist_self_attn()
        config = BertConfig.from_pretrained('bert-base-multilingual-cased')
        super(BertSelfAttentionPrunable, self).__init__(config)

    def transpose_for_scores(self, x):
        # infer the per-head size from the input width (-1) instead of using the static
        # self.attention_head_size, so pruned query/key/value projections still reshape correctly
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
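        # same flow as transformers' BertSelfAttention.forward; the differences are that
        # transpose_for_scores() infers the head size and the context reshape below uses
        # self.query.out_features, so pruned projections are handled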
        mixed_query_layer = self.query(hidden_states)

        # If this is used as a cross-attention module, the keys and values come from an
        # encoder, and the attention mask must hide the encoder's padding tokens.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse cached cross-attention key/value states
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # cache the key/value states so that later decoding steps (or further
            # cross-attention calls) can reuse them
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the model's forward()).
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # reshape using the (possibly pruned) query projection's width instead of the
        # static all_head_size from the config
        new_context_layer_shape = context_layer.size()[:-2] + (self.query.out_features,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

    @staticmethod
    def init_from_exist_self_attn(attn: BertSelfAttention):
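        """Create a BertSelfAttentionPrunable and copy every sub-module
        (query/key/value projections, dropout, ...) from an existing
        BertSelfAttention into it."""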
        res = BertSelfAttentionPrunable()

        # copy every sub-module of the existing attention layer into the prunable replacement
        for attr in dir(attn):
            if isinstance(getattr(attn, attr), nn.Module):
                try:
                    setattr(res, attr, getattr(attn, attr))
                except Exception as e:
                    logger.debug(f'skip copying attribute {attr}: {e}')

        return res


class FM_to_MD_Bert_Util(FM_to_MD_Util):
    def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int) -> nn.Module:
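        """Derive a narrower "master DNN" from the foundation model: in every encoder
        block, keep only the 1/reducing_width_ratio highest-L1-norm output rows of the
        q/k/v and intermediate linear layers, and the highest-L1-norm input columns of
        the attention-output and output linear layers."""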
        fm_bert = deepcopy(fm)

        # swap every self-attention for the prunable variant so that its q/k/v
        # projections can be narrowed below
        for block in fm_bert.bert.encoder.layer:
            set_module(block, 'attention.self', BertSelfAttentionPrunable.init_from_exist_self_attn(block.attention.self))

        def _f(n):
            # target width after reduction
            return int(n // reducing_width_ratio)
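
        # indices (ascending) of the top n // reducing_width_ratio rows (dim=0) or
        # columns (dim=1) of p, ranked by L1 norm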
        def l1_max_indexes(p: torch.Tensor, dim=0):
            assert dim in [0, 1]
            assert p.dim() in [1, 2, 4]

            if dim == 1:
                p = p.T

            p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1)
            n = p.size(0)
            return p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio)].sort()[0]
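
        # prune every encoder block: keep the selected output rows of query/key/value
        # and intermediate.dense (their biases are sliced with the same indices), and
        # the selected input columns of attention.output.dense and output.dense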
        for block_i, block in enumerate(fm_bert.bert.encoder.layer):
            for k in ['query', 'key', 'value']:
                qkv = get_module(block, f'attention.self.{k}')

                new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features),
                                    qkv.bias is not None, qkv.weight.device)
                indexes = l1_max_indexes(qkv.weight.data, 0)

                new_qkv.weight.data.copy_(qkv.weight.data[indexes])
                if qkv.bias is not None:
                    new_qkv.bias.data.copy_(qkv.bias.data[indexes])
                set_module(block, f'attention.self.{k}', new_qkv)

            proj = get_module(block, f'attention.output.dense')
            new_proj = nn.Linear(_f(proj.in_features), proj.out_features,
                                 proj.bias is not None, proj.weight.device)
            new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)])
            if proj.bias is not None:
                new_proj.bias.data.copy_(proj.bias.data)
            set_module(block, f'attention.output.dense', new_proj)

            fc1 = get_module(block, f'intermediate.dense')
            new_fc1 = nn.Linear(fc1.in_features, _f(fc1.out_features),
                                fc1.bias is not None, fc1.weight.device)
            indexes = l1_max_indexes(fc1.weight.data, 0)
            new_fc1.weight.data.copy_(fc1.weight.data[indexes])
            if fc1.bias is not None:
                new_fc1.bias.data.copy_(fc1.bias.data[indexes])
            set_module(block, f'intermediate.dense', new_fc1)

            fc2 = get_module(block, f'output.dense')
            new_fc2 = nn.Linear(_f(fc2.in_features), fc2.out_features,
                                fc2.bias is not None, fc2.weight.device)
            new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes(fc2.weight.data, 1)])
            if fc2.bias is not None:
                new_fc2.bias.data.copy_(fc2.bias.data)
            set_module(block, f'output.dense', new_fc2)

        return fm_bert

    def init_md_from_fm_by_reducing_width_with_perf_test(self, fm: nn.Module, reducing_width_ratio: int,
                                                         samples: torch.Tensor) -> nn.Module:
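        """Same as init_md_from_fm_by_reducing_width(), but also measures and logs the
        model size and per-sample latency of the foundation model and the resulting
        master DNN."""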
        fm_size = get_model_size(fm, True)
        fm_latency = self._get_model_latency(fm, samples, 20,
                                             get_model_device(fm), 20, False)

        master_dnn = self.init_md_from_fm_by_reducing_width(fm, reducing_width_ratio)
        master_dnn_size = get_model_size(master_dnn, True)
        logger.debug(f'inited master DNN: {master_dnn}')
        master_dnn_latency = self._get_model_latency(master_dnn, samples, 20,
                                                     get_model_device(master_dnn), 20, False)

        logger.info(f'init master DNN (w/o FBS yet) by reducing foundation model\'s width (by {reducing_width_ratio:d}x)')
        logger.info(f'foundation model ({fm_size:.3f}MB, {fm_latency:.4f}s/sample) -> '
                    f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample)\n'
                    f'(model size: ↓ {(fm_size / master_dnn_size):.2f}x, '
                    f'latency: ↓ {(fm_latency / master_dnn_latency):.2f}x)')

        return master_dnn

    def _get_model_latency(self, model: torch.nn.Module, model_input_size, sample_num: int,
                           device: str, warmup_sample_num: int, return_detail=False):
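        """Average per-sample inference latency over `sample_num` runs after
        `warmup_sample_num` warm-up runs. The input is unpacked into the model with **,
        so `model_input_size` is expected to be dict-like keyword inputs (e.g. a
        tokenizer output); a plain shape tuple would produce a random tensor that the
        ** calls below cannot unpack."""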
        import time

        if isinstance(model_input_size, tuple):
            dummy_input = torch.rand(model_input_size).to(device)
        else:
            dummy_input = model_input_size

        model = model.to(device)
        model.eval()

        # warm up before timing
        with torch.no_grad():
            for _ in range(warmup_sample_num):
                model(**dummy_input)

        infer_time_list = []

        if 'cuda' in str(device):
            # use CUDA events for accurate GPU timing
            with torch.no_grad():
                for _ in range(sample_num):
                    s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
                    s.record()
                    model(**dummy_input)
                    e.record()
                    torch.cuda.synchronize()
                    cur_model_infer_time = s.elapsed_time(e) / 1000.
                    infer_time_list += [cur_model_infer_time]
        else:
            with torch.no_grad():
                for _ in range(sample_num):
                    start = time.time()
                    model(**dummy_input)
                    cur_model_infer_time = time.time() - start
                    infer_time_list += [cur_model_infer_time]

        avg_infer_time = sum(infer_time_list) / sample_num

        if return_detail:
            return avg_infer_time, infer_time_list
        return avg_infer_time