import re
import torch
from torch import nn

def freeze_whole_model(model):
    """Disable gradients for every parameter in the model."""
    for p in model.parameters():
        p.requires_grad = False

def unfreeze_parameters(model, config):
    """Re-enable gradients for parameters whose names contain a target substring."""
    # targets = '*.proj_*|*_proj*|*itm_head*|*queue*|*adapter*|*temp*|*.cls.*'
    targets = ['connector']  # lm_head
    if config.get('unfreeze_text_layer_norm', False):
        # LayerNorm parameters of the text encoder.
        targets = targets + ['self_attn_layer_norm', 'final_layer_norm']
    if config.get('unfreeze_vision_layer_norm', False):
        # LayerNorm parameters of the vision encoder.
        targets = targets + ['norm', 'norm1', 'norm2']
    print('unfreeze targets:', targets)
    for n, p in model.named_parameters():
        # Substring match: unfreeze a parameter if any target appears in its name.
        if any(t in n for t in targets):
            # if re.fullmatch(targets, n):
            p.requires_grad = True
            print(f"{n} is trainable...")
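
# A minimal sketch (an assumption about intent, not the code path used above) of
# how the glob-style pattern string from the commented-out lines could be matched:
# re.fullmatch would reject a pattern starting with '*' as an invalid regex, so
# fnmatch-style glob matching is the closer fit. The function name is hypothetical.
def unfreeze_parameters_by_glob(model, pattern='*.proj_*|*_proj*|*itm_head*|*queue*|*adapter*|*temp*|*.cls.*'):
    from fnmatch import fnmatchcase
    globs = pattern.split('|')
    for n, p in model.named_parameters():
        # fnmatchcase keeps matching case-sensitive on every platform.
        if any(fnmatchcase(n, g) for g in globs):
            p.requires_grad = True
            print(f"{n} is trainable...")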

def print_trainable_params_percentage(model):
    """Print and return the percentage of parameters that are currently trainable."""
    orig_param_size = sum(p.numel() for p in model.parameters())

    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    trainable_size = count_parameters(model)
    percentage = trainable_size / orig_param_size * 100
    print(f"Trainable param percentage: {percentage:.2f}% ({trainable_size}/{orig_param_size})")
    return percentage
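
# A minimal usage sketch. The model and config below are hypothetical stand-ins
# (assumed: config is a plain dict; any nn.Module whose parameter names contain
# the target substrings behaves the same). Typical flow: freeze everything,
# selectively unfreeze, then report the trainable share.
if __name__ == '__main__':
    model = nn.Sequential()
    model.add_module('encoder', nn.Linear(16, 16))    # stays frozen
    model.add_module('connector', nn.Linear(16, 16))  # matches the 'connector' target
    model.add_module('norm', nn.LayerNorm(16))        # matches 'norm' when enabled

    config = {'unfreeze_vision_layer_norm': True}

    freeze_whole_model(model)
    unfreeze_parameters(model, config)
    print_trainable_params_percentage(model)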