|
from transformers import BlipForQuestionAnswering, BlipConfig, BlipModel
|
import torch |
|
from torch import nn |
|
from abc import ABC, abstractmethod |
|
from copy import deepcopy |
|
from typing import Optional, Union |
|
from einops import rearrange, repeat |
|
from einops.layers.torch import Rearrange |
|
import tqdm |
|
|
|
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, set_module |
|
from utils.dl.common.model import set_module, get_module, get_super_module |
|
from utils.common.log import logger |
|
from new_impl.cv.elasticdnn.pipeline.offline.fm_lora.base import FMLoRA_Util, LoRA |
|
from transformers.models.blip.modeling_blip import BlipAttention |
|
from transformers.models.blip.modeling_blip_text import BlipTextSelfAttention,BlipTextAttention,BlipTextSelfOutput |
|
from new_impl.cv.elasticdnn.pipeline.offline.fm_to_md.base import FM_to_MD_Util |
|
from new_impl.cv.elasticdnn.model.base import Abs, KTakesAll, ElasticDNNUtil, Layer_WrappedWithFBS |
|
|
|
from typing import Optional, Tuple |
|
import math |
|
|
|
def blip(num_classes):
    # NOTE: num_classes is currently unused; BlipForQuestionAnswering generates
    # free-form answers instead of classifying over a fixed label set.
    model = BlipForQuestionAnswering.from_pretrained('new_impl/mm/Vis_bert/QuestionAnswering/VisBert_pretrained')
    return model
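
# A minimal usage sketch (hypothetical: assumes the local checkpoint directory is a
# standard BLIP VQA checkpoint whose BlipProcessor files live alongside the weights).
# It shows how the loaded model is queried: preprocess an image/question pair, then
# decode the generated answer ids.
#
# from transformers import BlipProcessor
# from PIL import Image
#
# processor = BlipProcessor.from_pretrained('new_impl/mm/Vis_bert/QuestionAnswering/VisBert_pretrained')
# model = blip(num_classes=None)
# inputs = processor(images=Image.new('RGB', (384, 384)), text='what is in the picture?', return_tensors='pt')
# with torch.no_grad():
#     answer_ids = model.generate(**inputs)
# print(processor.decode(answer_ids[0], skip_special_tokens=True))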
|
|
class ToQKV_WrappedWithLoRA(nn.Module): |
|
def __init__(self, fc: nn.Linear, ab_r: int): |
|
super(ToQKV_WrappedWithLoRA, self).__init__() |
|
|
|
self.fc = fc |
|
self.ab = self.create_ab_as_linear(fc.weight.data, ab_r) |
|
|
|
    def create_ab_as_linear(self, fc_weight: torch.Tensor, ab_r: int):
        # LoRA decomposition with rank = out_features // ab_r:
        # A maps in_features -> rank (Kaiming-initialized), B maps rank -> out_features
        # (zero-initialized, so the wrapped layer's output is unchanged at first).
        res = nn.Sequential(
            LoRA(fc_weight.size(1), fc_weight.size(0) // ab_r, bias=False),
            LoRA(fc_weight.size(0) // ab_r, fc_weight.size(0), bias=False)
        ).to(fc_weight.device)
        nn.init.kaiming_uniform_(res[0].weight, a=5 ** 0.5)
        nn.init.zeros_(res[1].weight)
        return res
|
|
|
def forward(self, x): |
|
x1 = self.fc(x) |
|
x2 = self.ab(x) |
|
return x1 + x2 |
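
# A quick sanity sketch (hypothetical shapes): because B is zero-initialized,
# wrapping a Linear with ToQKV_WrappedWithLoRA must not change its output before
# any fine-tuning happens.
#
# fc = nn.Linear(768, 768)
# wrapped = ToQKV_WrappedWithLoRA(fc, ab_r=8)
# x = torch.randn(2, 16, 768)
# assert torch.allclose(fc(x), wrapped(x))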
|
|
|
|
|
class FMLoRA_blip_Util(FMLoRA_Util): |
|
|
|
@torch.no_grad() |
|
def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples: dict): |
|
fm.eval() |
|
|
|
|
|
for k, v in samples.items(): |
|
if isinstance(v, torch.Tensor): |
|
samples[k] = v.to(get_model_device(fm)) |
|
|
|
o1 = fm.generate(**samples) |
|
|
|
        for name, module in fm.named_modules():
            # text branches expose separate query/key/value Linears
            if name.endswith(('query', 'key', 'value')):
                set_module(fm, name, ToQKV_WrappedWithLoRA(module, ab_r))
            # the vision branch uses a fused qkv Linear
            elif name.endswith('.qkv'):
                set_module(fm, name, ToQKV_WrappedWithLoRA(module, ab_r))
|
|
|
|
|
o2 = fm.generate(**samples) |
|
|
|
        if isinstance(o1, tuple):
            o1 = o1[-1]
            o2 = o2[-1]
        # injecting zero-initialized LoRA branches must leave the output unchanged
        output_diff = ((o1 - o2) ** 2).sum()
        assert output_diff < 1e-5, output_diff
|
return fm |
|
|
|
@torch.no_grad() |
|
def absorb_lora_and_recover_net_structure(self, fm: nn.Module, samples: dict): |
|
fm.eval() |
|
|
|
|
|
for k, v in samples.items(): |
|
if isinstance(v, torch.Tensor): |
|
samples[k] = v.to(get_model_device(fm)) |
|
|
|
o1 = fm.generate(**samples) |
|
|
|
for name, module in fm.named_modules(): |
|
if not isinstance(module, ToQKV_WrappedWithLoRA): |
|
continue |
|
|
|
            fc = module.fc
            ab = module.ab

            # merge LoRA into the base weight: W <- W + B @ A
            fc.weight.add_(ab[1].weight @ ab[0].weight)

            set_module(fm, name, fc)
|
|
|
|
|
o2 = fm.generate(**samples) |
|
|
|
if isinstance(o1, tuple): |
|
o1 = o1[-1] |
|
o2 = o2[-1] |
|
output_diff = ((o1 - o2) ** 2).sum() |
|
assert output_diff < 1e-6, output_diff |
|
|
|
return fm |
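
# Hypothetical round-trip sketch: inject LoRA, fine-tune, then absorb it back so the
# network recovers its original structure. `samples` must be a dict of generate()
# kwargs (e.g. pixel_values / input_ids taken from a real batch).
#
# util = FMLoRA_blip_Util()
# fm = blip(num_classes=None)
# fm = util.add_lora_ab_to_fm(fm, ab_r=8, samples=samples)
# ...fine-tune only the LoRA parameters...
# fm = util.absorb_lora_and_recover_net_structure(fm, samples)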
|
|
|
|
|
|
|
|
|
class blipTextAttentionPrunable(BlipTextSelfAttention):
    def __init__(self, is_cross_attention):
        config = BlipConfig.from_pretrained('new_impl/mm/Vis_bert/QuestionAnswering/VisBert_pretrained')
        super(blipTextAttentionPrunable, self).__init__(config.text_config, is_cross_attention)
|
|
|
def save_attn_gradients(self, attn_gradients): |
|
self.attn_gradients = attn_gradients |
|
|
|
def get_attn_gradients(self): |
|
return self.attn_gradients |
|
|
|
def save_attention_map(self, attention_map): |
|
self.attention_map = attention_map |
|
|
|
def get_attention_map(self): |
|
return self.attention_map |
|
|
|
    def transpose_for_scores(self, x):
        # infer the (possibly pruned) head size from the actual hidden width
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
head_mask: Optional[torch.FloatTensor] = None, |
|
encoder_hidden_states: Optional[torch.FloatTensor] = None, |
|
encoder_attention_mask: Optional[torch.FloatTensor] = None, |
|
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, |
|
output_attentions: Optional[bool] = False, |
|
) -> Tuple[torch.Tensor]: |
|
mixed_query_layer = self.query(hidden_states) |
|
|
|
|
|
|
|
|
|
is_cross_attention = encoder_hidden_states is not None |
|
|
|
if is_cross_attention: |
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) |
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) |
|
attention_mask = encoder_attention_mask |
|
elif past_key_value is not None: |
|
key_layer = self.transpose_for_scores(self.key(hidden_states)) |
|
value_layer = self.transpose_for_scores(self.value(hidden_states)) |
|
key_layer = torch.cat([past_key_value[0], key_layer], dim=2) |
|
value_layer = torch.cat([past_key_value[1], value_layer], dim=2) |
|
else: |
|
key_layer = self.transpose_for_scores(self.key(hidden_states)) |
|
value_layer = self.transpose_for_scores(self.value(hidden_states)) |
|
|
|
query_layer = self.transpose_for_scores(mixed_query_layer) |
|
|
|
past_key_value = (key_layer, value_layer) |
|
|
|
|
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) |
|
|
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": |
|
seq_length = hidden_states.size()[1] |
|
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) |
|
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) |
|
distance = position_ids_l - position_ids_r |
|
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) |
|
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) |
|
|
|
if self.position_embedding_type == "relative_key": |
|
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) |
|
attention_scores = attention_scores + relative_position_scores |
|
elif self.position_embedding_type == "relative_key_query": |
|
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) |
|
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) |
|
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key |
|
|
|
attention_scores = attention_scores / math.sqrt(self.attention_head_size) |
|
if attention_mask is not None: |
|
|
|
attention_scores = attention_scores + attention_mask.to(attention_scores.device) |
|
|
|
|
|
attention_probs = nn.Softmax(dim=-1)(attention_scores) |
|
|
|
|
|
|
|
attention_probs_dropped = self.dropout(attention_probs) |
|
|
|
|
|
if head_mask is not None: |
|
attention_probs_dropped = attention_probs_dropped * head_mask |
|
|
|
context_layer = torch.matmul(attention_probs_dropped, value_layer) |
|
|
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() |
|
new_context_layer_shape = context_layer.size()[:-2] + (-1,) |
|
context_layer = context_layer.view(*new_context_layer_shape) |
|
|
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) |
|
|
|
outputs = outputs + (past_key_value,) |
|
return outputs |
|
    @staticmethod
    def init_from_exist_self_attn(attn: BlipTextSelfAttention, is_cross_attention):
        res = blipTextAttentionPrunable(is_cross_attention)

        # copy every sub-module (query/key/value Linears, dropout, ...) from the
        # source attention so the prunable variant starts with identical weights
        for attr in dir(attn):
            if isinstance(getattr(attn, attr), nn.Module):
                try:
                    setattr(res, attr, getattr(attn, attr))
                except Exception as e:
                    logger.warning(f'failed to copy attr {attr}: {e}')

        return res
|
|
class blipSelfAttentionPrunable(BlipAttention): |
|
def __init__(self): |
|
config = BlipConfig.from_pretrained('new_impl/mm/Vis_bert/QuestionAnswering/VisBert_pretrained') |
|
super(blipSelfAttentionPrunable, self).__init__(config.vision_config) |
|
|
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): |
|
return tensor.view(bsz, seq_len, self.num_heads, -1).transpose(1, 2).contiguous() |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
head_mask: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = False, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
"""Input shape: Batch x Time x Channel""" |
|
|
|
bsz, tgt_len, embed_dim = hidden_states.size() |
|
|
|
        # -1 lets the per-head size follow the (possibly pruned) qkv width
        mixed_qkv = (
            self.qkv(hidden_states)
            .reshape(bsz, tgt_len, 3, self.num_heads, -1)
            .permute(2, 0, 3, 1, 4)
        )
        query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
|
|
|
|
|
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) |
|
|
|
attention_scores = attention_scores * self.scale |
|
|
|
|
|
attention_probs = nn.functional.softmax(attention_scores, dim=-1) |
|
|
|
|
|
|
|
attention_probs = self.dropout(attention_probs) |
|
|
|
|
|
if head_mask is not None: |
|
attention_probs = attention_probs * head_mask |
|
|
|
context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3) |
|
|
|
new_context_layer_shape = context_layer.size()[:-2] + (-1,) |
|
context_layer = context_layer.reshape(new_context_layer_shape) |
|
|
|
output = self.projection(context_layer) |
|
|
|
outputs = (output, attention_probs) if output_attentions else (output, None) |
|
|
|
return outputs |
|
|
|
    @staticmethod
    def init_from_exist_self_attn(attn: BlipAttention):
        res = blipSelfAttentionPrunable()

        # copy every sub-module (qkv, projection, dropout, ...) from the source attention
        for attr in dir(attn):
            if isinstance(getattr(attn, attr), nn.Module):
                try:
                    setattr(res, attr, getattr(attn, attr))
                except Exception as e:
                    logger.warning(f'failed to copy attr {attr}: {e}')

        return res
|
|
|
|
|
class FM_to_MD_blip_Util(FM_to_MD_Util): |
|
def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int) -> nn.Module: |
|
fm_vis = deepcopy(fm) |
|
config = BlipConfig.from_pretrained('new_impl/mm/Vis_bert/QuestionAnswering/VisBert_pretrained') |
|
|
|
|
|
|
|
|
|
|
|
        # replace the text decoder's self- and cross-attention with prunable variants
        for block in fm_vis.text_decoder.bert.encoder.layer:
            set_module(block, 'attention.self', blipTextAttentionPrunable.init_from_exist_self_attn(block.attention.self, False))
            set_module(block, 'crossattention.self', blipTextAttentionPrunable.init_from_exist_self_attn(block.crossattention.self, True))
|
|
|
|
|
        def _f(n):
            # target width after reduction
            return int(n // reducing_width_ratio)

        def l1_max_indexes(p: torch.Tensor, dim=0):
            # indexes (sorted ascending) of the 1/reducing_width_ratio rows (dim=0)
            # or columns (dim=1) of p with the largest L1 norm
            assert dim in [0, 1]
            assert p.dim() in [1, 2, 4]

            if dim == 1:
                p = p.T

            p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1)
            n = p.size(0)
            return p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio)].sort()[0]
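
        # e.g. (hypothetical 4-row weight, reducing_width_ratio=2): rows with L1
        # norms [1, 7, 3, 5] -> the two largest are rows 1 and 3 -> indexes [1, 3]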
|
|
|
        # prune the self-attention q/k/v outputs, the attention output projection's
        # inputs, and the FFN hidden width of every text-decoder block
        for block_i, block in enumerate(fm_vis.text_decoder.bert.encoder.layer):
            for k in ['query', 'key', 'value']:
|
qkv = get_module(block, f'attention.self.{k}') |
|
|
|
new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features), |
|
qkv.bias is not None, qkv.weight.device) |
|
indexes = l1_max_indexes(qkv.weight.data, 0) |
|
|
|
new_qkv.weight.data.copy_(qkv.weight.data[indexes]) |
|
if qkv.bias is not None: |
|
new_qkv.bias.data.copy_(qkv.bias.data[indexes]) |
|
set_module(block, f'attention.self.{k}', new_qkv) |
|
|
|
proj = get_module(block, f'attention.output.dense') |
|
new_proj = nn.Linear(_f(proj.in_features), proj.out_features, |
|
proj.bias is not None, proj.weight.device) |
|
new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)]) |
|
if proj.bias is not None: |
|
new_proj.bias.data.copy_(proj.bias.data) |
|
set_module(block, f'attention.output.dense', new_proj) |
|
|
|
fc1 = get_module(block, f'intermediate.dense') |
|
new_fc1 = nn.Linear(fc1.in_features, _f(fc1.out_features), |
|
fc1.bias is not None, fc1.weight.device) |
|
indexes = l1_max_indexes(fc1.weight.data, 0) |
|
new_fc1.weight.data.copy_(fc1.weight.data[indexes]) |
|
if fc1.bias is not None: |
|
new_fc1.bias.data.copy_(fc1.bias.data[indexes]) |
|
set_module(block, f'intermediate.dense', new_fc1) |
|
|
|
fc2 = get_module(block, f'output.dense') |
|
new_fc2 = nn.Linear(_f(fc2.in_features), fc2.out_features, |
|
fc2.bias is not None, fc2.weight.device) |
|
new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes(fc2.weight.data, 1)]) |
|
if fc2.bias is not None: |
|
new_fc2.bias.data.copy_(fc2.bias.data) |
|
set_module(block, f'output.dense', new_fc2) |
|
|
|
|
|
        # prune the cross-attention q/k/v outputs and output projection likewise
        for block_i, block in enumerate(fm_vis.text_decoder.bert.encoder.layer):
            for k in ['query', 'key', 'value']:
|
qkv = get_module(block, f'crossattention.self.{k}') |
|
|
|
new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features), |
|
qkv.bias is not None, qkv.weight.device) |
|
indexes = l1_max_indexes(qkv.weight.data, 0) |
|
|
|
new_qkv.weight.data.copy_(qkv.weight.data[indexes]) |
|
if qkv.bias is not None: |
|
new_qkv.bias.data.copy_(qkv.bias.data[indexes]) |
|
set_module(block, f'crossattention.self.{k}', new_qkv) |
|
|
|
proj = get_module(block, f'crossattention.output.dense') |
|
new_proj = nn.Linear(_f(proj.in_features), proj.out_features, |
|
proj.bias is not None, proj.weight.device) |
|
new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)]) |
|
if proj.bias is not None: |
|
new_proj.bias.data.copy_(proj.bias.data) |
|
set_module(block, f'crossattention.output.dense', new_proj) |
|
|
return fm_vis |
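
    # Hypothetical sketch: derive a half-width master DNN from the pretrained BLIP model.
    #
    # util = FM_to_MD_blip_Util()
    # md = util.init_md_from_fm_by_reducing_width(blip(num_classes=None), reducing_width_ratio=2)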
|
|
|
def init_md_from_fm_by_reducing_width_with_perf_test(self, fm: nn.Module, reducing_width_ratio: int, |
|
samples: torch.Tensor) -> nn.Module: |
|
fm_size = get_model_size(fm, True) |
|
fm_latency = self._get_model_latency(fm, samples, 20, |
|
get_model_device(fm), 20, False) |
|
|
|
master_dnn = self.init_md_from_fm_by_reducing_width(fm, reducing_width_ratio) |
|
master_dnn_size = get_model_size(master_dnn, True) |
|
        logger.debug(f'initialized master DNN: {master_dnn}')
|
master_dnn_latency = self._get_model_latency(master_dnn, samples, 20, |
|
get_model_device(master_dnn), 20, False) |
|
|
|
logger.info(f'init master DNN (w/o FBS yet) by reducing foundation model\'s width (by {reducing_width_ratio:d}x)') |
|
logger.info(f'foundation model ({fm_size:.3f}MB, {fm_latency:.4f}s/sample) -> ' |
|
f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample)\n' |
|
f'(model size: ↓ {(fm_size / master_dnn_size):.2f}x, ' |
|
f'latency: ↓ {(fm_latency / master_dnn_latency):.2f}x)') |
|
|
|
return master_dnn |
|
|
|
def _get_model_latency(self, model: torch.nn.Module, model_input_size, sample_num: int, |
|
device: str, warmup_sample_num: int, return_detail=False): |
|
import time |
|
|
|
if isinstance(model_input_size, tuple): |
|
dummy_input = torch.rand(model_input_size).to(device) |
|
else: |
|
dummy_input = model_input_size |
|
|
|
model = model.to(device) |
|
model.eval() |
|
|
|
|
|
with torch.no_grad(): |
|
for _ in range(warmup_sample_num): |
|
model(**dummy_input) |
|
|
|
infer_time_list = [] |
|
|
|
        if 'cuda' in str(device):
|
with torch.no_grad(): |
|
for _ in range(sample_num): |
|
s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) |
|
s.record() |
|
model(**dummy_input) |
|
e.record() |
|
torch.cuda.synchronize() |
|
cur_model_infer_time = s.elapsed_time(e) / 1000. |
|
infer_time_list += [cur_model_infer_time] |
|
|
|
else: |
|
with torch.no_grad(): |
|
for _ in range(sample_num): |
|
start = time.time() |
|
model(**dummy_input) |
|
cur_model_infer_time = time.time() - start |
|
infer_time_list += [cur_model_infer_time] |
|
|
|
avg_infer_time = sum(infer_time_list) / sample_num |
|
|
|
if return_detail: |
|
return avg_infer_time, infer_time_list |
|
return avg_infer_time |
|
|
|
|
|
|
|
|
|
class SqueezeLast(nn.Module): |
|
def __init__(self): |
|
super(SqueezeLast, self).__init__() |
|
|
|
def forward(self, x): |
|
return x.squeeze(-1) |
|
|
|
|
|
class ProjConv_WrappedWithFBS(Layer_WrappedWithFBS): |
|
def __init__(self, proj: nn.Conv2d, r): |
|
super(ProjConv_WrappedWithFBS, self).__init__() |
|
|
|
self.proj = proj |
|
|
|
|
|
|
|
        self.fbs = nn.Sequential(
            Abs(),
            nn.AdaptiveAvgPool1d(1),
            SqueezeLast(),
            nn.Linear(proj.in_channels, proj.out_channels // r),
            nn.ReLU(),
            nn.Linear(proj.out_channels // r, proj.out_channels),
            nn.ReLU()
        )

        # NOTE: the second Linear sits at index 5 here (there is no leading
        # Rearrange as in Linear_WrappedWithFBS, where it sits at index 6)
        nn.init.constant_(self.fbs[5].bias, 1.)
        nn.init.kaiming_normal_(self.fbs[5].weight)
|
|
|
|
|
def forward(self, x): |
|
if self.use_cached_channel_attention and self.cached_channel_attention is not None: |
|
channel_attention = self.cached_channel_attention |
|
else: |
|
self.cached_raw_channel_attention = self.fbs(x) |
|
self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention) |
|
|
|
channel_attention = self.cached_channel_attention |
|
|
|
raw_res = self.proj(x) |
|
|
|
return channel_attention.unsqueeze(1) * raw_res |
|
|
|
|
|
class Linear_WrappedWithFBS(Layer_WrappedWithFBS): |
|
def __init__(self, linear: nn.Linear, r): |
|
super(Linear_WrappedWithFBS, self).__init__() |
|
|
|
self.linear = linear |
|
|
|
|
|
|
|
self.fbs = nn.Sequential( |
|
Rearrange('b n d -> b d n'), |
|
Abs(), |
|
nn.AdaptiveAvgPool1d(1), |
|
SqueezeLast(), |
|
nn.Linear(linear.in_features, linear.out_features // r), |
|
nn.ReLU(), |
|
nn.Linear(linear.out_features // r, linear.out_features), |
|
nn.ReLU() |
|
) |
|
|
|
nn.init.constant_(self.fbs[6].bias, 1.) |
|
nn.init.kaiming_normal_(self.fbs[6].weight) |
|
|
|
|
|
def forward(self, x): |
|
if self.use_cached_channel_attention and self.cached_channel_attention is not None: |
|
channel_attention = self.cached_channel_attention |
|
else: |
|
self.cached_raw_channel_attention = self.fbs(x) |
|
self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention) |
|
|
|
channel_attention = self.cached_channel_attention |
|
|
|
raw_res = self.linear(x) |
|
|
|
return channel_attention.unsqueeze(1) * raw_res |
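
# A minimal FBS sketch (hypothetical sizes): the wrapper predicts a per-channel gate
# from the input, keeps the top-k channels via k_takes_all, and scales the linear
# layer's output accordingly.
#
# wrapped = Linear_WrappedWithFBS(nn.Linear(768, 3072), r=16)
# wrapped.use_cached_channel_attention = False
# y = wrapped(torch.randn(2, 197, 768))  # -> (2, 197, 3072), pruned channels zeroed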
|
|
|
|
|
class ToQKV_WrappedWithFBS(Layer_WrappedWithFBS): |
|
""" |
|
This regards to_q/to_k/to_v as a whole (in fact it consists of multiple heads) and prunes it. |
|
It seems different channels of different heads are pruned according to the input. |
|
This is different from "removing some head" or "removing the same channels in each head". |
|
""" |
|
def __init__(self, to_qkv: nn.Linear, r): |
|
super(ToQKV_WrappedWithFBS, self).__init__() |
|
|
|
|
|
|
|
self.to_qk = nn.Linear(to_qkv.in_features, to_qkv.out_features // 3 * 2, bias=to_qkv.bias is not None) |
|
self.to_v = nn.Linear(to_qkv.in_features, to_qkv.out_features // 3, bias=to_qkv.bias is not None) |
|
self.to_qk.weight.data.copy_(to_qkv.weight.data[0: to_qkv.out_features // 3 * 2]) |
|
if to_qkv.bias is not None: |
|
self.to_qk.bias.data.copy_(to_qkv.bias.data[0: to_qkv.out_features // 3 * 2]) |
|
self.to_v.weight.data.copy_(to_qkv.weight.data[to_qkv.out_features // 3 * 2: ]) |
|
if to_qkv.bias is not None: |
|
self.to_v.bias.data.copy_(to_qkv.bias.data[to_qkv.out_features // 3 * 2: ]) |
|
|
|
self.fbs = nn.Sequential( |
|
Rearrange('b n d -> b d n'), |
|
Abs(), |
|
nn.AdaptiveAvgPool1d(1), |
|
SqueezeLast(), |
|
nn.Linear(to_qkv.in_features, to_qkv.out_features // 3 // r), |
|
nn.ReLU(), |
|
|
|
nn.Linear(to_qkv.out_features // 3 // r, self.to_v.out_features), |
|
nn.ReLU() |
|
) |
|
|
|
nn.init.constant_(self.fbs[6].bias, 1.) |
|
nn.init.kaiming_normal_(self.fbs[6].weight) |
|
|
|
    def forward(self, x):
        if self.use_cached_channel_attention and self.cached_channel_attention is not None:
            channel_attention = self.cached_channel_attention
        else:
            self.cached_raw_channel_attention = self.fbs(x)
            self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)

            channel_attention = self.cached_channel_attention

        qk = self.to_qk(x)
        # only the value projection is gated; q/k pass through unpruned
        v = channel_attention.unsqueeze(1) * self.to_v(x)
        return torch.cat([qk, v], dim=-1)
|
|
class StaticFBS(nn.Module): |
|
def __init__(self, static_channel_attention): |
|
super(StaticFBS, self).__init__() |
|
assert static_channel_attention.dim() == 2 and static_channel_attention.size(0) == 1 |
|
self.static_channel_attention = nn.Parameter(static_channel_attention, requires_grad=False) |
|
|
|
def forward(self, x): |
|
|
|
return x * self.static_channel_attention.unsqueeze(1) |
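
# After surrogate extraction, each pruned Linear is followed by a StaticFBS that
# re-applies the now-constant channel attention, e.g. (hypothetical sizes):
#
# gate = torch.rand(1, 1536)
# layer = nn.Sequential(nn.Linear(768, 1536), StaticFBS(gate))
# y = layer(torch.randn(2, 197, 768))  # equals the linear output scaled by gate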
|
|
|
|
|
class ElasticblipUtil(ElasticDNNUtil): |
|
def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]): |
|
assert len(ignore_layers) == 0, 'not supported yet' |
|
|
|
raw_vit = deepcopy(raw_dnn) |
|
|
|
|
|
|
|
        for name, module in raw_vit.named_modules():
            # text branches: BertIntermediate.dense; vision branch: BlipMLP.fc1
            if name.endswith('intermediate'):
                set_module(module, 'dense', Linear_WrappedWithFBS(module.dense, r))
            elif name.endswith('mlp'):
                set_module(module, 'fc1', Linear_WrappedWithFBS(module.fc1, r))

        return raw_vit
|
|
|
    def set_master_dnn_sparsity(self, master_dnn: nn.Module, sparsity: float):
        return super().set_master_dnn_sparsity(master_dnn, sparsity)
|
|
|
    def select_most_rep_sample(self, master_dnn: nn.Module, samples: dict):
        # keep the first sample of each field in the batch dict
        res = {k: v[0: 1] for k, v in samples.items()}
        return res
|
|
|
def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False): |
|
sample = self.select_most_rep_sample(master_dnn, samples) |
|
|
|
|
|
|
|
master_dnn.eval() |
|
self.clear_cached_channel_attention_in_master_dnn(master_dnn) |
|
with torch.no_grad(): |
|
master_dnn_output = master_dnn(**sample) |
|
|
|
|
|
|
|
boosted_vit = deepcopy(master_dnn) |
|
|
|
        def get_unpruned_indexes_from_channel_attn(channel_attn: torch.Tensor, k):
            assert channel_attn.size(0) == 1, 'use A representative sample to generate channel attentions'
            # channels zeroed out by KTakesAll are pruned; keep the nonzero ones
            res = channel_attn[0].nonzero(as_tuple=True)[0]
            return res
|
|
|
unpruned_indexes_of_layers = {} |
|
|
|
|
|
|
|
for block_i, block in enumerate(boosted_vit.text_encoder.encoder.layer): |
|
|
|
|
|
|
|
ff_0 = get_module(block, f'intermediate.dense') |
|
|
|
ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0] |
|
ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes]) |
|
new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None) |
|
new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes]) |
|
if ff_0.linear.bias is not None: |
|
new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes]) |
|
set_module(block, 'intermediate.dense', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes]))) |
|
|
|
ff_1 = get_module(block, f'output.dense') |
|
new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None) |
|
new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes]) |
|
if ff_1.bias is not None: |
|
new_ff_1.bias.data.copy_(ff_1.bias.data) |
|
set_module(block, 'output.dense', new_ff_1) |
|
|
|
unpruned_indexes_of_layers[f'text_encoder.encoder.layer.{block_i}.intermediate.dense.0.weight'] = ff_0_unpruned_indexes |
|
for block_i,block in enumerate(boosted_vit.vision_model.encoder.layers): |
|
|
|
attn = block.self_attn |
|
ff = block.mlp |
|
ff_0 = ff.fc1 |
|
|
|
ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0] |
|
ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes]) |
|
new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None) |
|
new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes]) |
|
if ff_0.linear.bias is not None: |
|
new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes]) |
|
set_module(ff, 'fc1', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes]))) |
|
|
|
ff_1 = ff.fc2 |
|
new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None) |
|
new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes]) |
|
if ff_1.bias is not None: |
|
new_ff_1.bias.data.copy_(ff_1.bias.data) |
|
set_module(ff, 'fc2', new_ff_1) |
|
|
|
unpruned_indexes_of_layers[f'vision_model.encoder.layers.{block_i}.mlp.fc1.0.weight'] = ff_0_unpruned_indexes |
|
|
|
|
|
for block_i, block in enumerate(boosted_vit.text_decoder.bert.encoder.layer): |
|
|
|
|
|
|
|
ff_0 = get_module(block, f'intermediate.dense') |
|
|
|
ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0] |
|
ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes]) |
|
new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None) |
|
new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes]) |
|
if ff_0.linear.bias is not None: |
|
new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes]) |
|
set_module(block, 'intermediate.dense', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes]))) |
|
|
|
ff_1 = get_module(block, f'output.dense') |
|
new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None) |
|
new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes]) |
|
if ff_1.bias is not None: |
|
new_ff_1.bias.data.copy_(ff_1.bias.data) |
|
set_module(block, 'output.dense', new_ff_1) |
|
|
|
unpruned_indexes_of_layers[f'text_decoder.bert.encoder.layer.{block_i}.intermediate.dense.0.weight'] = ff_0_unpruned_indexes |
|
surrogate_dnn = boosted_vit |
|
surrogate_dnn.eval() |
|
surrogate_dnn = surrogate_dnn.to(get_model_device(master_dnn)) |
|
|
|
with torch.no_grad(): |
|
surrogate_dnn_output = surrogate_dnn(**sample) |
|
|
|
output_diff = ((surrogate_dnn_output.logits - master_dnn_output.logits) ** 2).sum() |
|
|
|
logger.info(f'output diff of master and surrogate DNN: {output_diff}') |
|
|
|
|
|
|
|
|
|
if return_detail: |
|
return boosted_vit, unpruned_indexes_of_layers |
|
|
|
return boosted_vit |
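
    # Hypothetical sketch of the elastic workflow: build a master DNN with FBS, pick
    # a sparsity, then extract a statically pruned surrogate from one representative
    # batch (`sample`: a dict of tensors from the dataloader).
    #
    # util = ElasticblipUtil()
    # master = util.convert_raw_dnn_to_master_dnn(blip(num_classes=None), r=16)
    # util.set_master_dnn_sparsity(master, 0.5)
    # surrogate = util.extract_surrogate_dnn_via_samples(master, sample)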
|
|
|
def extract_surrogate_dnn_via_samples_with_perf_test(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False): |
|
master_dnn_size = get_model_size(master_dnn, True) |
|
master_dnn_latency = self._get_model_latency(master_dnn, samples, 50, |
|
get_model_device(master_dnn), 50, False) |
|
|
|
res = self.extract_surrogate_dnn_via_samples(master_dnn, samples, return_detail) |
|
if not return_detail: |
|
surrogate_dnn = res |
|
else: |
|
surrogate_dnn, unpruned_indexes_of_layers = res |
|
surrogate_dnn_size = get_model_size(surrogate_dnn, True) |
|
        # measure the surrogate (not the master) so the comparison below is meaningful
        surrogate_dnn_latency = self._get_model_latency(surrogate_dnn, samples, 50,
                                                        get_model_device(surrogate_dnn), 50, False)
|
|
|
logger.info(f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample) -> ' |
|
f'surrogate DNN ({surrogate_dnn_size:.3f}MB, {surrogate_dnn_latency:.4f}s/sample)\n' |
|
f'(model size: ↓ {(master_dnn_size / surrogate_dnn_size):.2f}x, ' |
|
f'latency: ↓ {(master_dnn_latency / surrogate_dnn_latency):.2f}x)') |
|
|
|
return res |
|
|
|
def _get_model_latency(self, model: torch.nn.Module, model_input_size, sample_num: int, |
|
device: str, warmup_sample_num: int, return_detail=False): |
|
import time |
|
|
|
if isinstance(model_input_size, tuple): |
|
dummy_input = torch.rand(model_input_size).to(device) |
|
else: |
|
dummy_input = model_input_size |
|
|
|
model = model.to(device) |
|
model.eval() |
|
|
|
|
|
with torch.no_grad(): |
|
for _ in range(warmup_sample_num): |
|
model(**dummy_input) |
|
|
|
infer_time_list = [] |
|
|
|
        if 'cuda' in str(device):
|
with torch.no_grad(): |
|
for _ in range(sample_num): |
|
s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) |
|
s.record() |
|
model(**dummy_input) |
|
e.record() |
|
torch.cuda.synchronize() |
|
cur_model_infer_time = s.elapsed_time(e) / 1000. |
|
infer_time_list += [cur_model_infer_time] |
|
|
|
else: |
|
with torch.no_grad(): |
|
for _ in range(sample_num): |
|
start = time.time() |
|
model(**dummy_input) |
|
cur_model_infer_time = time.time() - start |
|
infer_time_list += [cur_model_infer_time] |
|
|
|
avg_infer_time = sum(infer_time_list) / sample_num |
|
|
|
if return_detail: |
|
return avg_infer_time, infer_time_list |
|
return avg_infer_time |
|
|
|
|
|
|
|
|
|
|
|
from typing import List |
|
from data.dataloader import build_dataloader |
|
|
|
from new_impl.cv.elasticdnn.api.online_model_v2 import ElasticDNN_OnlineModel |
|
|
|
import torch |
|
import sys |
|
from torch import nn |
|
from methods.elasticdnn.api.model import ElasticDNN_OfflineSegFMModel, ElasticDNN_OfflineSegMDModel |
|
from methods.elasticdnn.api.algs.md_pretraining_wo_fbs import ElasticDNN_MDPretrainingWoFBSAlg |
|
from methods.elasticdnn.model.base import ElasticDNNUtil |
|
from methods.elasticdnn.pipeline.offline.fm_to_md.base import FM_to_MD_Util |
|
from methods.elasticdnn.pipeline.offline.fm_to_md.vit import FM_to_MD_ViT_Util |
|
from methods.elasticdnn.pipeline.offline.fm_lora.base import FMLoRA_Util |
|
from methods.elasticdnn.pipeline.offline.fm_lora.vit import FMLoRA_ViT_Util |
|
from methods.elasticdnn.model.vilt import ElasticViltUtil |
|
from utils.common.file import ensure_dir |
|
from utils.dl.common.model import LayerActivation, get_module, get_parameter, set_module |
|
from utils.common.exp import save_models_dict_for_init, get_res_save_dir |
|
from data import build_scenario |
|
from utils.dl.common.loss import CrossEntropyLossSoft |
|
import torch.nn.functional as F |
|
from utils.dl.common.env import create_tbwriter |
|
import os |
|
from utils.common.log import logger |
|
from utils.common.data_record import write_json |
|
|
|
from methods.ewc.ewc_elasticfm import OnlineEWCModel |
|
import tqdm |
|
|
|
from copy import deepcopy |
|
|
|
|
|
class ElasticDNN_VQAOnlineModel(ElasticDNN_OnlineModel): |
|
@torch.no_grad() |
|
def sd_feedback_to_md(self, after_da_sd, unpruned_indexes_of_layers): |
|
self.models_dict['sd'] = after_da_sd |
|
self.before_da_md = deepcopy(self.models_dict['md']) |
|
|
|
logger.info('\n\nsurrogate DNN feedback to master DNN...\n\n') |
|
|
|
|
|
cur_unpruned_indexes = None |
|
cur_unpruned_indexes_name = None |
|
|
|
for p_name, p in self.models_dict['sd'].named_parameters(): |
|
matched_md_param = self.get_md_matched_param_of_sd_param(p_name) |
|
logger.debug(f'if feedback: {p_name}') |
|
if matched_md_param is None: |
|
continue |
|
logger.debug(f'start feedback: {p_name}, {p.size()} -> {matched_md_param.size()}') |
|
|
|
|
|
|
|
if p_name in unpruned_indexes_of_layers.keys(): |
|
cur_unpruned_indexes = unpruned_indexes_of_layers[p_name] |
|
cur_unpruned_indexes_name = p_name |
|
|
|
if p.size() != matched_md_param.size(): |
|
logger.debug(f'cur unpruned indexes: {cur_unpruned_indexes_name}, {cur_unpruned_indexes.size()}') |
|
|
|
if p.dim() == 1: |
|
new_p = deepcopy(matched_md_param) |
|
new_p[cur_unpruned_indexes] = p |
|
elif p.dim() == 2: |
|
if p.size(0) < matched_md_param.size(0): |
|
new_p = deepcopy(matched_md_param) |
|
new_p[cur_unpruned_indexes] = p |
|
else: |
|
new_p = deepcopy(matched_md_param) |
|
new_p[:, cur_unpruned_indexes] = p |
|
p = new_p |
|
|
|
assert p.size() == matched_md_param.size(), f'{p.size()}, {matched_md_param.size()}' |
|
|
|
|
|
            # disabled branch: feedback specific to a continually-learned cls head
            if False:
                continue

                assert hasattr(self, 'last_trained_cls_indexes')
                print(self.last_trained_cls_indexes)

                diff = self._compute_diff(matched_md_param, p)

                matched_md_param.copy_(p)
                logger.debug(f'SPECIFIC FOR CL HEAD | end feedback: {p_name}, diff: {diff:.6f}')
            else:
                # default: average the surrogate and master parameters
                diff = self._compute_diff(matched_md_param, (matched_md_param + p) / 2.)
                matched_md_param.copy_((matched_md_param + p) / 2.)
                logger.debug(f'end feedback: {p_name}, diff: {diff:.6f}')
|
|
|
def add_cls_in_head(self, num_cls): |
|
head: nn.Linear = get_module(self.models_dict['md'], 'cls') |
|
|
|
new_head = nn.Linear(head.in_features, head.out_features + num_cls, head.bias is not None, device=self.device) |
|
|
|
|
|
|
|
|
|
new_head.weight.data[0: head.out_features] = deepcopy(head.weight.data) |
|
new_head.bias.data[0: head.out_features] = deepcopy(head.bias.data) |
|
set_module(self.models_dict['md'], 'cls', new_head) |
|
set_module(self.models_dict['fm'], 'cls', new_head) |
|
|
|
def get_accuracy(self, test_loader, *args, **kwargs): |
|
acc = 0 |
|
sample_num = 0 |
|
|
|
from methods.elasticdnn.api.model import VQAScore |
|
vqa_score = VQAScore() |
|
|
|
self.to_eval_mode() |
|
with torch.no_grad(): |
|
pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False) |
|
for batch_index, (x, y) in pbar: |
|
for k, v in x.items(): |
|
if isinstance(v, torch.Tensor): |
|
x[k] = v.to(self.device) |
|
y = y.to(self.device) |
|
output = self.infer(x) |
|
|
|
vqa_score.update(output, y) |
|
|
|
pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_acc: {vqa_score.compute():.4f}') |
|
|
|
return float(vqa_score.compute()) |
|
|
|
def get_elastic_dnn_util(self) -> ElasticDNNUtil: |
|
return ElasticblipUtil() |
|
|
|
def get_fm_matched_param_of_md_param(self, md_param_name): |
|
|
|
self_param_name = md_param_name |
|
fm = self.models_dict['fm'] |
|
if any([k in self_param_name for k in ['fbs', 'ab', 'embeddings']]): |
|
return None |
|
|
|
p = get_parameter(self.models_dict['md'], self_param_name) |
|
if p.dim() == 0: |
|
return None |
|
elif p.dim() == 1 and ('LayerNorm' in self_param_name or 'layernorm' in self_param_name) and 'weight' in self_param_name: |
|
return get_parameter(fm, self_param_name) |
|
|
|
|
|
        if ('query' in self_param_name or 'key' in self_param_name or \
            'value' in self_param_name) and ('weight' in self_param_name):

            ss = self_param_name.split('.')

            fm_qkv_name = '.'.join(ss[0: -1]) + '.fc'
            fm_qkv = get_module(fm, fm_qkv_name)

            fm_abs_name = '.'.join(ss[0: -1]) + '.ab'
            fm_abs = get_module(fm, fm_abs_name)

            # cache the merged LoRA weight (B @ A) on first access
            if not hasattr(fm_abs, '_mul_lora_weight'):
                logger.debug(f'set _mul_lora_weight in {fm_abs_name}')
                setattr(fm_abs, '_mul_lora_weight',
                    nn.Parameter(fm_abs[1].weight @ fm_abs[0].weight))

            # the FM-side view of a q/k/v weight is the base weight stacked with the
            # merged LoRA weight; update_fm_param splits it back apart
            return torch.cat([
                fm_qkv.weight.data,
                fm_abs._mul_lora_weight.data
            ], dim=0)

        elif 'dense' in self_param_name and 'weight' in self_param_name:
            fm_param_name = self_param_name.replace('.linear', '')
            return get_parameter(fm, fm_param_name)

        else:
            return None
|
|
|
def update_fm_param(self, md_param_name, cal_new_fm_param_by_md_param): |
|
if not ('query' in md_param_name or 'key' in md_param_name or 'value' in md_param_name): |
|
matched_fm_param_ref = self.get_fm_matched_param_of_md_param(md_param_name) |
|
matched_fm_param_ref.copy_(cal_new_fm_param_by_md_param) |
|
else: |
|
new_fm_attn_weight, new_fm_lora_weight = torch.chunk(cal_new_fm_param_by_md_param, 2, 0) |
|
|
|
ss = md_param_name.split('.') |
|
fm = self.models_dict['fm'] |
|
|
|
|
|
fm_qkv_name = '.'.join(ss[0: -1]) + '.fc' |
|
fm_qkv = get_module(fm, fm_qkv_name) |
|
fm_qkv.weight.data.copy_(new_fm_attn_weight) |
|
|
|
|
|
fm_abs_name = '.'.join(ss[0: -1]) + '.ab' |
|
fm_abs = get_module(fm, fm_abs_name) |
|
fm_abs._mul_lora_weight.data.copy_(new_fm_lora_weight) |
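
    # The q/k/v path above is the inverse of get_fm_matched_param_of_md_param: the
    # incoming update is the concatenation [base_weight; merged_lora_weight], so it
    # is chunked back into the .fc weight and the cached _mul_lora_weight.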
|
|
|
def get_md_matched_param_of_fm_param(self, fm_param_name): |
|
return super().get_md_matched_param_of_fm_param(fm_param_name) |
|
|
|
def get_md_matched_param_of_sd_param(self, sd_param_name): |
|
|
|
|
|
|
|
self_param_name = sd_param_name |
|
md = self.models_dict['md'] |
|
if any([k in self_param_name for k in ['fbs', 'ab', 'embeddings']]): |
|
return None |
|
|
|
p = get_parameter(self.models_dict['sd'], self_param_name) |
|
if p.dim() == 0: |
|
return None |
|
elif p.dim() == 1 and ('LayerNorm' in self_param_name or 'layernorm' in self_param_name) and 'weight' in self_param_name: |
|
return get_parameter(md, self_param_name) |
|
|
|
|
|
        if ('query' in self_param_name or 'key' in self_param_name or \
            'value' in self_param_name) and ('weight' in self_param_name):
            return get_parameter(md, self_param_name)

        elif 'intermediate.dense.0.weight' in self_param_name:
            fm_param_name = '.'.join(self_param_name.split('.')[0: -2]) + '.linear.weight'
            return get_parameter(md, fm_param_name)

        elif 'output.dense' in self_param_name and 'weight' in self_param_name:
            fm_param_name = self_param_name
            return get_parameter(md, fm_param_name)

        else:
            return None
|
|
|
def get_task_head_params(self): |
|
head = get_module(self.models_dict['sd'], 'cls') |
|
return list(head.parameters()) |
|
|
|
|
|
|
|
|
|
from typing import List, Tuple |
|
from data.dataloader import build_dataloader |
|
|
|
from methods.elasticdnn.api.online_model_v2 import ElasticDNN_OnlineModel |
|
|
|
import torch |
|
import sys |
|
from torch import nn |
|
from methods.elasticdnn.api.model import ElasticDNN_OfflineSegFMModel, ElasticDNN_OfflineSegMDModel |
|
from methods.elasticdnn.api.algs.md_pretraining_wo_fbs import ElasticDNN_MDPretrainingWoFBSAlg |
|
from methods.elasticdnn.model.base import ElasticDNNUtil |
|
from methods.elasticdnn.pipeline.offline.fm_to_md.base import FM_to_MD_Util |
|
from methods.elasticdnn.pipeline.offline.fm_to_md.vit import FM_to_MD_ViT_Util |
|
from methods.elasticdnn.pipeline.offline.fm_lora.base import FMLoRA_Util |
|
from methods.elasticdnn.pipeline.offline.fm_lora.vit import FMLoRA_ViT_Util |
|
from methods.elasticdnn.model.vit import ElasticViTUtil |
|
from utils.common.file import ensure_dir |
|
from utils.dl.common.model import LayerActivation, LayerActivation2, get_module, get_parameter, set_module |
|
from utils.common.exp import save_models_dict_for_init, get_res_save_dir |
|
from data import build_scenario |
|
from utils.dl.common.loss import CrossEntropyLossSoft |
|
import torch.nn.functional as F |
|
from utils.dl.common.env import create_tbwriter |
|
import os |
|
from utils.common.log import logger |
|
from utils.common.data_record import write_json |
|
|
|
from methods.feat_align.main import OnlineFeatAlignModel |
|
import tqdm |
|
from methods.feat_align.mmd import mmd_rbf |
|
from copy import deepcopy |
|
|
|
|
|
class VQAOnlineFeatAlignModel(OnlineFeatAlignModel): |
|
def get_trained_params(self): |
|
qkv_and_norm_params = [p for n, p in self.models_dict['main'].named_parameters() if 'query' in n or 'key' in n or 'value' in n or 'dense' in n or 'LayerNorm' in n] |
|
return qkv_and_norm_params |
|
|
|
def get_feature_hook(self): |
|
return LayerActivation(get_module(self.models_dict['main'], 'cls'), False, self.device) |
|
|
|
    def forward_to_get_task_loss(self, x, y):
        self.to_train_mode()
        o = self.infer(x)
        # soft multi-label VQA loss; scaling BCE by the number of answer classes
        # follows common VQA fine-tuning practice
        return F.binary_cross_entropy_with_logits(o, y) * y.shape[1]
|
|
|
|
|
|
|
def get_mmd_loss(self, f1, f2): |
|
return mmd_rbf(f1, f2) |
|
|
|
def infer(self, x, *args, **kwargs): |
|
return self.models_dict['main'](**x) |
|
|
|
def get_accuracy(self, test_loader, *args, **kwargs): |
|
acc = 0 |
|
sample_num = 0 |
|
|
|
from methods.elasticdnn.api.model import VQAScore |
|
vqa_score = VQAScore() |
|
|
|
self.to_eval_mode() |
|
with torch.no_grad(): |
|
pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False) |
|
for batch_index, (x, y) in pbar: |
|
for k, v in x.items(): |
|
if isinstance(v, torch.Tensor): |
|
x[k] = v.to(self.device) |
|
y = y.to(self.device) |
|
output = self.infer(x) |
|
|
|
vqa_score.update(output, y) |
|
|
|
pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_acc: {vqa_score.compute():.4f}') |
|
|
|
return float(vqa_score.compute()) |