from typing import List

import torch.nn as nn
from mmengine.logging import MMLogger

# Module-level flags so the (potentially long) per-parameter / per-module
# listings are logged only on the first call of each function.
first_set_requires_grad = True
first_set_train = True


def set_requires_grad(model: nn.Module, keywords: List[str]):
    """Enable gradients only for parameters whose name contains a keyword.

    A parameter is made trainable when ANY entry of ``keywords`` is a
    substring of its fully-qualified name as yielded by
    ``model.named_parameters()``; every other parameter gets
    ``requires_grad = False``.

    On the first call only (tracked by the module flag
    ``first_set_requires_grad``), the trainable parameter names and a
    trainable / total parameter-count summary are logged through
    ``MMLogger.get_current_instance()``.

    Args:
        model: Module whose parameters are modified in place.
        keywords: Substrings to search for in parameter names.
    """
    requires_grad_names = []
    num_params = 0
    num_trainable = 0
    for name, param in model.named_parameters():
        num_params += param.numel()
        if any(key in name for key in keywords):
            param.requires_grad = True
            requires_grad_names.append(name)
            num_trainable += param.numel()
        else:
            param.requires_grad = False

    global first_set_requires_grad
    if first_set_requires_grad:
        logger = MMLogger.get_current_instance()
        for name in requires_grad_names:
            logger.info(f"set_requires_grad----{name}")
        # Guard against a parameter-less model, which would otherwise raise
        # ZeroDivisionError while formatting the summary line.
        ratio = num_trainable * 100 / num_params if num_params else 0.0
        logger.info(
            f"Total trainable params--{num_trainable}, All params--{num_params}, Ratio--{ratio:.1f}%"
        )
        first_set_requires_grad = False


def _set_train(model: nn.Module, keywords: List[str], prefix: str = ""):
    """Recursively switch matching child modules to training mode.

    Each direct child whose *local* attribute name (not its full dotted path)
    starts with any entry of ``keywords`` gets ``child.train()``. That call
    also covers the child's own submodules, so the recursion does not descend
    into it. Non-matching children are searched recursively.

    Args:
        model: Module whose children are inspected.
        keywords: Prefixes matched against each child's local name.
        prefix: Dotted path of ``model`` relative to the root; ``""`` at the
            top level.

    Returns:
        Fully-qualified dotted names of the modules that were set to train.
    """
    train_names = []
    for name, child in model.named_children():
        # Avoid a leading "." at the top level ("backbone", not ".backbone").
        fullname = f"{prefix}.{name}" if prefix else name
        if any(name.startswith(key) for key in keywords):
            train_names.append(fullname)
            child.train()
        else:
            train_names += _set_train(child, keywords, prefix=fullname)
    return train_names


def set_train(model: nn.Module, keywords: List[str]):
    """Put ``model`` in eval mode, then re-enable training on matching submodules.

    A submodule is switched back to training mode when its own local name
    starts with one of ``keywords``. This is a ``startswith`` test on the
    child's attribute name, unlike the substring test on full names used by
    ``set_requires_grad``. The root module itself is never matched.

    On the first call only (tracked by the module flag ``first_set_train``),
    the names of the modules put in training mode are logged through
    ``MMLogger.get_current_instance()``.

    Args:
        model: Module whose train/eval state is modified in place.
        keywords: Prefixes matched against each submodule's local name.
    """
    model.train(False)
    train_names = _set_train(model, keywords)
    global first_set_train
    if first_set_train:
        logger = MMLogger.get_current_instance()
        for train_name in train_names:
            logger.info(f"set_train----{train_name}")
        first_set_train = False