English
XavierJiezou's picture
Add files using upload-large-folder tool
0467378 verified
raw
history blame
1.86 kB
import torch.nn as nn
from typing import List
from mmengine.logging import MMLogger
# One-shot logging switches. set_requires_grad() and set_train() emit their
# detailed selection logs only while the matching flag is True, then clear it,
# so the listing appears just the first time each function runs.
first_set_requires_grad = first_set_train = True
def set_requires_grad(model: nn.Module, keywords: List[str]):
    """Freeze all parameters of ``model`` except those whose name contains a keyword.

    Matching is a plain substring test. A parameter stays trainable when any
    entry of ``keywords`` occurs anywhere in its dotted name, as yielded by
    ``model.named_parameters()``. Every other parameter gets
    ``requires_grad = False``.

    Only while the module-level flag ``first_set_requires_grad`` is True (i.e.
    on the first call), the trainable parameter names and a trainable/total
    parameter-count summary are logged through ``MMLogger``. The flag is then
    cleared.

    Args:
        model: Module whose parameters are (un)frozen in place.
        keywords: Substrings selecting the parameters that remain trainable.
    """
    global first_set_requires_grad
    requires_grad_names = []
    num_params = 0
    num_trainable = 0
    for name, param in model.named_parameters():
        num_params += param.numel()
        trainable = any(key in name for key in keywords)
        param.requires_grad = trainable
        if trainable:
            requires_grad_names.append(name)
            num_trainable += param.numel()
    if first_set_requires_grad:
        logger = MMLogger.get_current_instance()
        for name in requires_grad_names:
            logger.info(f"set_requires_grad----{name}")
        # A model without parameters would otherwise divide by zero here.
        ratio = num_trainable * 100 / num_params if num_params else 0.0
        logger.info(
            f"Total trainable params--{num_trainable}, All params--{num_params}, Ratio--{ratio:.1f}%"
        )
        first_set_requires_grad = False
def _set_train(model: nn.Module, keywords: List[str], prefix: str = ""):
    """Recursively switch selected child modules of ``model`` to training mode.

    A child is selected when its own attribute name (not its full dotted path)
    starts with any entry of ``keywords``. ``child.train()`` is called on it,
    and its sub-tree is not searched further, since ``train()`` already
    recurses. Unselected children are searched recursively.

    Args:
        model: Module whose children are inspected.
        keywords: Name prefixes selecting the children to put in train mode.
        prefix: Dotted path of ``model`` relative to the root module, used
            only to build the returned names. Empty for the root.

    Returns:
        Dotted paths (relative to the root module) of the children that were
        switched to train mode, in traversal order.
    """
    train_names = []
    for name, child in model.named_children():
        # Only join with "." when there is a parent path. Otherwise direct
        # children of the root would get a spurious leading dot.
        fullname = f"{prefix}.{name}" if prefix else name
        if any(name.startswith(key) for key in keywords):
            train_names.append(fullname)
            child.train()
        else:
            train_names += _set_train(child, keywords, prefix=fullname)
    return train_names
def set_train(model: nn.Module, keywords: List[str]):
    """Put ``model`` in eval mode, then re-enable training on selected sub-modules.

    The whole model is first switched with ``model.train(False)``. Child modules
    whose own name starts with one of ``keywords`` (searched recursively by
    ``_set_train``) are then put back into training mode.

    Only while the module-level flag ``first_set_train`` is True are the
    selected module names logged through ``MMLogger``. The flag is then cleared.

    Args:
        model: Root module whose train/eval state is modified in place.
        keywords: Name prefixes of the sub-modules to keep in training mode.
    """
    global first_set_train
    model.train(False)
    selected = _set_train(model, keywords)
    if not first_set_train:
        return
    logger = MMLogger.get_current_instance()
    for module_name in selected:
        logger.info(f"set_train----{module_name}")
    first_set_train = False