import copy
import inspect
from typing import List, Union

import torch
import torch.nn as nn
import lightning
from mmengine.config import Config, ConfigDict
from mmengine.device import is_npu_available

from mmpl.registry import LOGGERS, OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS


def register_pl_loggers() -> List[str]:
    """Register loggers in ``lightning.pytorch.loggers`` to the ``LOGGERS``
    registry.

    Returns:
        List[str]: A list of registered loggers' names.
    """
    pl_loggers = []
    for module_name in dir(lightning.pytorch.loggers):
        if module_name.startswith('__'):
            continue
        _logger = getattr(lightning.pytorch.loggers, module_name)
        if inspect.isclass(_logger) and issubclass(
                _logger, lightning.pytorch.loggers.logger.Logger):
            LOGGERS.register_module(module=_logger)
            pl_loggers.append(module_name)
    return pl_loggers


PL_LOGGERS = register_pl_loggers()


def register_dadaptation_optimizers() -> List[str]:
    """Register optimizers in ``dadaptation`` to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizers' names.
    """
    dadaptation_optimizers = []
    try:
        import dadaptation
    except ImportError:
        pass
    else:
        for module_name in ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD']:
            _optim = getattr(dadaptation, module_name)
            if inspect.isclass(_optim) and issubclass(
                    _optim, torch.optim.Optimizer):
                OPTIMIZERS.register_module(module=_optim)
                dadaptation_optimizers.append(module_name)
    return dadaptation_optimizers


# DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers()


def register_lion_optimizers() -> List[str]:
    """Register the Lion optimizer to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizers' names.
    """
    optimizers = []
    try:
        from lion_pytorch import Lion
    except ImportError:
        pass
    else:
        OPTIMIZERS.register_module(module=Lion)
        optimizers.append('Lion')
    return optimizers


# LION_OPTIMIZERS = register_lion_optimizers()


def build_optim_wrapper(model: nn.Module,
                        cfg: Union[dict, Config, ConfigDict]):
    """Build function of OptimWrapper.

    If ``constructor`` is set in the ``cfg``, this method will build an
    optimizer wrapper constructor, and use the optimizer wrapper constructor
    to build the optimizer wrapper. If ``constructor`` is not set, the
    ``DefaultOptimWrapperConstructor`` will be used by default.

    Args:
        model (nn.Module): Model to be optimized.
        cfg (dict): Config of the optimizer wrapper, optimizer constructor
            and optimizer.

    Returns:
        OptimWrapper: The built optimizer wrapper.
    """
    optim_wrapper_cfg = copy.deepcopy(cfg)
    constructor_type = optim_wrapper_cfg.pop('constructor',
                                             'DefaultOptimWrapperConstructor')
    paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)

    # Since the current generation of NPUs (Ascend 910) only supports
    # mixed-precision training, turn on mixed precision by default on the
    # NPU so that training runs normally.
    if is_npu_available():
        optim_wrapper_cfg['type'] = 'AmpOptimWrapper'

    optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
        dict(
            type=constructor_type,
            optim_wrapper_cfg=optim_wrapper_cfg,
            paramwise_cfg=paramwise_cfg))
    optim_wrapper = optim_wrapper_constructor(model)
    return optim_wrapper
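

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's public API):
# shows how ``build_optim_wrapper`` might be called from training code. It
# assumes the mmengine defaults (``OptimWrapper``, ``AdamW`` and
# ``DefaultOptimWrapperConstructor``) are visible through the ``mmpl``
# registries, e.g. via registry parenting; the model and hyperparameters
# below are hypothetical placeholders.
if __name__ == '__main__':
    # Stand-in for a real model.
    model = nn.Linear(8, 2)

    optim_wrapper_cfg = dict(
        type='OptimWrapper',
        optimizer=dict(type='AdamW', lr=1e-4, weight_decay=0.05),
        # Per-parameter options, if any, go under ``paramwise_cfg`` and are
        # consumed by the optimizer wrapper constructor, e.g.:
        # paramwise_cfg=dict(norm_decay_mult=0.0),
    )
    optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)

    # ``PL_LOGGERS`` lists the Lightning logger classes registered at import
    # time; the optimizer registration helpers above are opt-in (commented
    # out) and only take effect when called explicitly.
    print(type(optim_wrapper).__name__)
    print(PL_LOGGERS)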