File size: 1,858 Bytes
0467378 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import torch.nn as nn
from typing import List
from mmengine.logging import MMLogger
# One-shot logging flags. set_requires_grad and set_train each log their
# selection only while their flag is True and then clear it, so the
# selection is reported on the first call only.
first_set_requires_grad = first_set_train = True
def set_requires_grad(model: nn.Module, keywords: List[str]):
    """Freeze every parameter of ``model`` except those whose name contains a keyword.

    A parameter stays trainable (``requires_grad=True``) when ANY string in
    ``keywords`` occurs as a substring anywhere in its fully-qualified name
    from ``model.named_parameters()``. Every other parameter gets
    ``requires_grad=False``. Matching is plain substring containment, so
    e.g. ``"head"`` also matches ``"backbone.head_proj.weight"``.

    While the module-level flag ``first_set_requires_grad`` is True (i.e. on
    the first call), the trainable parameter names and the trainable/total
    element counts are logged through ``MMLogger.get_current_instance()``.
    The flag is then cleared, so later calls log nothing.

    Args:
        model: Module whose parameters are updated in place.
        keywords: Substrings selecting the parameters to keep trainable.
            A single string is treated as one keyword. An empty list
            freezes every parameter.
    """
    global first_set_requires_grad

    # A bare string would otherwise be iterated character by character, and a
    # one-shot iterator would be exhausted after the first parameter.
    keywords = [keywords] if isinstance(keywords, str) else list(keywords)

    requires_grad_names = []
    num_params = 0
    num_trainable = 0
    for name, param in model.named_parameters():
        numel = param.numel()
        num_params += numel
        trainable = any(key in name for key in keywords)
        param.requires_grad = trainable
        if trainable:
            requires_grad_names.append(name)
            num_trainable += numel

    if first_set_requires_grad:
        logger = MMLogger.get_current_instance()
        for name in requires_grad_names:
            logger.info(f"set_requires_grad----{name}")
        # A parameter-less model would otherwise raise ZeroDivisionError here.
        ratio = num_trainable * 100 / num_params if num_params else 0.0
        logger.info(
            f"Total trainable params--{num_trainable}, All params--{num_params}, Ratio--{ratio:.1f}%"
        )
        first_set_requires_grad = False
def _set_train(model: nn.Module, keywords: List[str], prefix: str = ""):
    """Recursively switch matching child modules of ``model`` to training mode.

    Walks ``model.named_children()``. A child whose *local* attribute name
    starts with any string in ``keywords`` is switched with ``child.train()``
    (which also covers its whole subtree) and is not descended into.
    Every other child is searched recursively.

    Args:
        model: Module whose children are inspected.
        keywords: Prefixes matched against each child's local name.
            A single string is treated as one prefix.
        prefix: Dotted path of ``model`` inside the root module; ``""`` for
            the root itself.

    Returns:
        Dotted names of the children switched to training mode, in traversal
        order. Root-level names carry no leading dot, e.g. ``"backbone.head"``.
    """
    if isinstance(keywords, str):
        # A bare string would otherwise be iterated character by character.
        keywords = [keywords]
    # str.startswith accepts a tuple of prefixes; an empty tuple matches nothing.
    key_prefixes = tuple(keywords)

    train_names = []
    for name, child in model.named_children():
        # Only join with "." below the root, so root children get no leading dot.
        fullname = f"{prefix}.{name}" if prefix else name
        if name.startswith(key_prefixes):
            train_names.append(fullname)
            child.train()
        else:
            train_names += _set_train(child, key_prefixes, prefix=fullname)
    return train_names
def set_train(model: nn.Module, keywords: List[str]):
    """Put ``model`` into eval mode, then switch selected submodules back to training mode.

    Selection is delegated to ``_set_train``: a submodule is chosen when its
    local (attribute) name starts with one of ``keywords``, and the chosen
    submodule is switched to training mode together with its subtree.

    While the module-level flag ``first_set_train`` is True (i.e. on the
    first call), the selected submodule names are logged through
    ``MMLogger.get_current_instance()``. The flag is then cleared.

    Args:
        model: Root module whose train/eval mode is updated in place.
        keywords: Prefixes matched against submodule local names.
    """
    global first_set_train

    # Start from a clean slate; eval() is train(False).
    model.eval()
    switched = _set_train(model, keywords)

    if not first_set_train:
        return
    logger = MMLogger.get_current_instance()
    for module_name in switched:
        logger.info(f"set_train----{module_name}")
    first_set_train = False