# RSPrompter / mmpl / models / pler / base_pler.py
# KyanChen's picture
# Upload 159 files
# 1c3eb47
import torch
import torch.nn as nn
from lightning.pytorch.utilities import grad_norm
from mmengine import OPTIM_WRAPPERS
from mmengine.optim import build_optim_wrapper, _ParamScheduler
import copy
from torchmetrics import MetricCollection
from mmpl.registry import MODELS, METRICS
import lightning.pytorch as pl
from mmengine.registry import OPTIMIZERS, PARAM_SCHEDULERS
from mmengine.model import BaseModel
@MODELS.register_module()
class BasePLer(pl.LightningModule, BaseModel):
    """Base Lightning module driven by an mmengine-style ``hyperparameters`` config.

    On construction it optionally builds a data preprocessor and one
    ``MetricCollection`` per entry of ``hyperparameters['evaluator']``.
    It also provides helpers to freeze/unfreeze parameters by name, to build
    the optimizer and parameter schedulers from the config, and to log
    gradient norms and evaluation metrics.

    Args:
        hyperparameters: Mapping read with ``.get``; the keys used by this
            class are ``'evaluator'``, ``'optimizer'`` and
            ``'param_scheduler'``.
        data_preprocessor: ``None``, an ``nn.Module`` used as is, or a config
            ``dict`` built through the ``MODELS`` registry.
        train_cfg: Stored on the instance unchanged.
        test_cfg: Stored on the instance unchanged.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Raises:
        TypeError: If ``data_preprocessor`` or an evaluator entry has an
            unsupported type.
    """

    def __init__(
            self,
            hyperparameters,
            data_preprocessor=None,
            train_cfg=None,
            test_cfg=None,
            *args,
            **kwargs
    ):
        super().__init__()
        self.hyperparameters = hyperparameters
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if data_preprocessor is not None:
            if isinstance(data_preprocessor, nn.Module):
                self.data_preprocessor = data_preprocessor
            elif isinstance(data_preprocessor, dict):
                self.data_preprocessor = MODELS.build(data_preprocessor)
            else:
                raise TypeError('data_preprocessor should be a `dict` or '
                                f'`nn.Module` instance, but got '
                                f'{type(data_preprocessor)}')

        # Work on a copy so building metrics never touches the caller's config.
        evaluator_cfg = copy.deepcopy(self.hyperparameters.get('evaluator', None))
        if evaluator_cfg is not None:
            for k, v in evaluator_cfg.items():
                # A single metric config is treated as a one-element list.
                if isinstance(v, dict):
                    v = [v]
                if not isinstance(v, list):
                    # Report the offending entry, not the enclosing mapping.
                    raise TypeError('evaluator should be a `dict` or '
                                    f'`list` instance, but got '
                                    f'{type(v)}')
                metrics = [METRICS.build(metric_cfg) for metric_cfg in v]
                # Stored under the config key (e.g. ``val_evaluator``); the
                # part before the first underscore becomes the metric prefix.
                setattr(self, k, MetricCollection(metrics, prefix=k.split('_')[0]))

    def _set_grad(self, need_train_names=None, noneed_train_names=None):
        """Enable or disable gradients on parameters by name substring.

        A parameter is trainable only if its name contains at least one entry
        of ``need_train_names`` and no entry of ``noneed_train_names``; every
        other parameter is frozen. On local rank 0 a summary is printed,
        including the top-level names of parameters matched by neither list.

        Args:
            need_train_names: Substrings selecting parameters to train.
                Defaults to an empty list.
            noneed_train_names: Substrings selecting parameters to freeze;
                these take precedence. Defaults to an empty list.
        """
        need_train_names = [] if need_train_names is None else need_train_names
        noneed_train_names = [] if noneed_train_names is None else noneed_train_names

        unmatched_names = []
        for name, param in self.named_parameters():
            wanted = any(key in name for key in need_train_names)
            blocked = any(key in name for key in noneed_train_names)
            param.requires_grad_(wanted and not blocked)
            if not (wanted or blocked):
                unmatched_names.append(name)

        if self.local_rank == 0:
            # Only the first dotted component is reported to keep output short.
            not_specific_names = {name.split('.')[0] for name in unmatched_names}
            print(f"Turning off gradients for names: {noneed_train_names}")
            print(f"Turning on gradients for names: {need_train_names}")
            print(f"Turning off gradients for not specific names: {not_specific_names}")

    def _set_train_module(self, mode=True, need_train_names=None):
        """Set train/eval mode per direct child module.

        Children whose name contains any entry of ``need_train_names`` get
        ``train(mode)``; all other children are put in eval mode regardless
        of ``mode``.

        Args:
            mode: Value assigned to ``self.training`` and passed to the
                selected children.
            need_train_names: Substrings selecting the children that follow
                ``mode``. Defaults to an empty list.

        Returns:
            ``self``.
        """
        need_train_names = [] if need_train_names is None else need_train_names
        self.training = mode
        for name, module in self.named_children():
            if any(key in name for key in need_train_names):
                module.train(mode)
            else:
                module.eval()
        return self

    def configure_optimizers(self):
        """Build the optimizer and parameter schedulers from ``hyperparameters``.

        ``hyperparameters['optimizer']`` may contain a ``sub_model`` entry (a
        name, a list of names, or a ``{name: options}`` dict). Each name is a
        dotted attribute path on this module and becomes its own parameter
        group with ``lr = base_lr * lr_mult`` and, when a base weight decay is
        configured, ``weight_decay = base_wd * decay_mult``. Without
        ``sub_model`` all trainable parameters form a single group.

        Returns:
            ``[optimizer]`` when no ``param_scheduler`` is configured,
            otherwise ``([optimizer], param_schedulers)``.

        Raises:
            ModuleNotFoundError: If a ``sub_model`` path cannot be resolved.
            TypeError: If a scheduler entry is neither a ``_ParamScheduler``
                nor a dict.
        """
        optimizer_cfg = copy.deepcopy(self.hyperparameters.get('optimizer'))
        base_lr = optimizer_cfg.get('lr')
        base_wd = optimizer_cfg.get('weight_decay', None)

        sub_models = optimizer_cfg.pop('sub_model', None)
        if sub_models is None:
            optimizer_cfg['params'] = [p for p in self.parameters() if p.requires_grad]
        else:
            # Normalise the accepted shorthands to ``{name: options}``.
            if isinstance(sub_models, str):
                sub_models = {sub_models: {}}
            if isinstance(sub_models, list):
                sub_models = {x: {} for x in sub_models}
            assert isinstance(sub_models, dict), f'sub_models should be a dict, but got {type(sub_models)}'

            # Turn every entry into a parameter-group dict (params, lr, wd).
            for sub_model_name, value in sub_models.items():
                target = self
                try:
                    for sub_attr in sub_model_name.split('.'):
                        target = getattr(target, sub_attr)
                except AttributeError as err:
                    raise ModuleNotFoundError(f'{sub_model_name} not in model') from err

                # The path may point at a single parameter or at a module.
                if isinstance(target, torch.nn.Parameter):
                    candidates = [target]
                else:
                    candidates = target.parameters()
                value['params'] = [p for p in candidates if p.requires_grad]

                lr_mult = value.pop('lr_mult', 1.)
                value['lr'] = base_lr * lr_mult
                # Always pop so the multiplier never leaks into the group.
                decay_mult = value.pop('decay_mult', 1.)
                if base_wd is not None:
                    value['weight_decay'] = base_wd * decay_mult

            if self.local_rank == 0:
                print('All sub models:')
                for name, _ in self.named_children():
                    print(name, end=', ')
                print()
                print('Needed train models:')
                for name in sub_models:
                    print(f'{name}', end=', ')
                print()

            optimizer_cfg['params'] = list(sub_models.values())

        optimizer = OPTIMIZERS.build(optimizer_cfg)
        if self.local_rank == 0:
            print('查看优化器参数')
            for param_group in optimizer.param_groups:
                print([value.shape for value in param_group['params']], '学习率: ', param_group['lr'])

        schedulers = copy.deepcopy(self.hyperparameters.get('param_scheduler', None))
        if schedulers is None:
            return [optimizer]

        param_schedulers = []
        # NOTE(review): the value is unused, but reading this trainer property
        # presumably makes Lightning set up the train dataloader so that
        # ``num_training_batches`` below is populated - confirm before removing.
        total_step = self.trainer.estimated_stepping_batches
        for scheduler in schedulers:
            if isinstance(scheduler, _ParamScheduler):
                param_schedulers.append(scheduler)
            elif isinstance(scheduler, dict):
                _scheduler = copy.deepcopy(scheduler)
                param_schedulers.append(
                    PARAM_SCHEDULERS.build(
                        _scheduler,
                        default_args=dict(
                            optimizer=optimizer,
                            epoch_length=self.trainer.num_training_batches,
                        )
                    )
                )
            else:
                raise TypeError(
                    'scheduler should be a _ParamScheduler object or dict, '
                    f'but got {scheduler}')

        return [optimizer], param_schedulers

    def lr_scheduler_step(self, scheduler, metric):
        """Do nothing when asked to step a scheduler.

        NOTE(review): presumably the schedulers are stepped elsewhere (e.g. by
        a hook/callback) - confirm against the rest of the project.
        """
        pass

    def log_grad(self, module=None) -> None:
        """Log the largest and smallest gradient 2-norm reported for ``module``.

        Args:
            module: Module whose gradients are inspected; defaults to ``self``.
        """
        # Compute the 2-norm for each layer
        # If using mixed precision, the gradients are already unscaled here
        if module is None:
            module = self
        norms = grad_norm(module, norm_type=2)
        if not norms:
            # No gradients available: max()/min() would raise on empty input.
            return
        # NOTE(review): Lightning's ``grad_norm`` appears to also include a
        # total-norm entry, in which case ``max_grad`` equals that total -
        # confirm whether this is intended.
        grad_values = list(norms.values())
        self.log_dict(
            {'max_grad': max(grad_values), 'min_grad': min(grad_values)},
            prog_bar=True,
            logger=True
        )

    def setup(self, stage: str) -> None:
        """Copy dataset ``metainfo`` onto metrics that expose ``dataset_meta``.

        For each of ``train``/``val``/``test`` this looks for
        ``<split>_evaluator`` on the module and ``<split>_dataset`` on the
        trainer's datamodule; when both exist and the dataset has
        ``metainfo``, every metric with a ``dataset_meta`` attribute gets it.

        Args:
            stage: Stage name passed by the trainer; not used here.
        """
        datamodule = self.trainer.datamodule
        for split in ('train', 'val', 'test'):
            if not hasattr(self, f'{split}_evaluator'):
                continue
            if not hasattr(datamodule, f'{split}_dataset'):
                continue
            dataset = getattr(datamodule, f'{split}_dataset')
            if not hasattr(dataset, 'metainfo'):
                continue
            for metric in getattr(self, f'{split}_evaluator').values():
                if hasattr(metric, 'dataset_meta'):
                    metric.dataset_meta = dataset.metainfo

    def on_before_optimizer_step(self, optimizer) -> None:
        """Log gradient norms before the optimizer step."""
        self.log_grad()

    def on_validation_epoch_end(self) -> None:
        """Compute, log and reset the validation metrics."""
        self._log_eval_metrics('val')

    def on_test_epoch_end(self) -> None:
        """Compute, log and reset the test metrics."""
        self._log_eval_metrics('test')

    def on_train_epoch_end(self) -> None:
        """Compute, log and reset the training metrics."""
        self._log_eval_metrics('train')

    def _log_eval_metrics(self, stage):
        """Compute, log and reset the ``<stage>_evaluator`` metrics, if present.

        Each metric tensor is flattened and every element is logged as
        ``<lowercased name>_<index>``. If the trainer has a checkpoint
        callback whose ``monitor`` key was not among the logged keys, that
        key is logged with value ``0.``.

        Args:
            stage: One of ``'train'``, ``'val'`` or ``'test'``.
        """
        assert stage in ['train', 'val', 'test']
        if not hasattr(self, f'{stage}_evaluator'):
            return

        evaluator = getattr(self, f'{stage}_evaluator')
        metrics = {k.lower(): v for k, v in evaluator.compute().items()}

        keys = []
        for k, v in metrics.items():
            for i, data in enumerate(v.view(-1)):
                key = f'{k}_{i}'
                keys.append(key)
                self.log(key, data, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        evaluator.reset()

        # The attribute can exist and still be None when no checkpoint
        # callback is configured, so resolve both steps defensively.
        checkpoint_callback = getattr(self.trainer, 'checkpoint_callback', None)
        monitor = getattr(checkpoint_callback, 'monitor', None)
        if (monitor is not None) and (monitor not in keys):
            data = torch.tensor(0., device=self.device)
            self.log(f'{monitor}', data, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)