Spaces:
Runtime error
Runtime error
import copy | |
import inspect | |
from typing import List, Union | |
import torch | |
import torch.nn as nn | |
import lightning | |
from mmengine.config import Config, ConfigDict | |
from mmengine.device import is_npu_available | |
from mmpl.registry import LOGGERS | |
def register_pl_loggers() -> List[str]:
    """Register the logger classes of ``lightning.pytorch.loggers``.

    Every attribute of ``lightning.pytorch.loggers`` whose name does not start
    with a double underscore is inspected; each one that is a class deriving
    from ``lightning.pytorch.loggers.logger.Logger`` is added to the
    ``LOGGERS`` registry.

    Returns:
        List[str]: Attribute names of the logger classes that were registered.
    """
    loggers_pkg = lightning.pytorch.loggers
    registered = []
    for attr_name in dir(loggers_pkg):
        # Skip dunder attributes of the package.
        if attr_name.startswith('__'):
            continue
        candidate = getattr(loggers_pkg, attr_name)
        # Only classes can be tested with ``issubclass``.
        if not inspect.isclass(candidate):
            continue
        if not issubclass(candidate, loggers_pkg.logger.Logger):
            continue
        LOGGERS.register_module(module=candidate)
        registered.append(attr_name)
    return registered


# Registration happens once, as a side effect of importing this module.
PL_LOGGERS = register_pl_loggers()
def register_dadaptation_optimizers() -> List[str]:
    """Register optimizers in ``dadaptation`` to the ``OPTIMIZERS`` registry.

    ``dadaptation`` is an optional dependency: when it cannot be imported
    nothing is registered and an empty list is returned.

    Returns:
        List[str]: Names of the ``dadaptation`` optimizers that are available
        in the ``OPTIMIZERS`` registry after this call.
    """
    try:
        import dadaptation
    except ImportError:
        # Optional dependency is missing; there is nothing to register.
        return []

    # ``OPTIMIZERS`` is not imported at module level in this file, so it is
    # brought into scope here to avoid a ``NameError``.
    # NOTE(review): mmengine's root registry is assumed; confirm whether a
    # project-level (``mmpl``) ``OPTIMIZERS`` registry should be used instead.
    from mmengine.registry import OPTIMIZERS

    dadaptation_optimizers = []
    for module_name in ('DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD'):
        _optim = getattr(dadaptation, module_name)
        if not (inspect.isclass(_optim)
                and issubclass(_optim, torch.optim.Optimizer)):
            continue
        # Registering an existing name without ``force`` raises ``KeyError``;
        # the name may already have been registered elsewhere, so skip it.
        if _optim.__name__ not in OPTIMIZERS.module_dict:
            OPTIMIZERS.register_module(module=_optim)
        dadaptation_optimizers.append(module_name)
    return dadaptation_optimizers
# DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers() | |
def register_lion_optimizers() -> List[str]:
    """Register Lion optimizer to the ``OPTIMIZERS`` registry.

    ``lion_pytorch`` is an optional dependency: when it cannot be imported
    nothing is registered and an empty list is returned.

    Returns:
        List[str]: ``['Lion']`` when the optimizer is available in the
        ``OPTIMIZERS`` registry after this call, otherwise an empty list.
    """
    try:
        from lion_pytorch import Lion
    except ImportError:
        # Optional dependency is missing; there is nothing to register.
        return []

    # ``OPTIMIZERS`` is not imported at module level in this file, so it is
    # brought into scope here to avoid a ``NameError``.
    # NOTE(review): mmengine's root registry is assumed; confirm whether a
    # project-level (``mmpl``) ``OPTIMIZERS`` registry should be used instead.
    from mmengine.registry import OPTIMIZERS

    # Registering an existing name without ``force`` raises ``KeyError``;
    # the name may already have been registered elsewhere, so skip it.
    if Lion.__name__ not in OPTIMIZERS.module_dict:
        OPTIMIZERS.register_module(module=Lion)
    return ['Lion']
# LION_OPTIMIZERS = register_lion_optimizers() | |
def build_optim_wrapper(model: nn.Module,
                        cfg: Union[dict, Config, ConfigDict]):
    """Build function of OptimWrapper.

    If ``constructor`` is set in the ``cfg``, this method will build an
    optimizer wrapper constructor, and use optimizer wrapper constructor to
    build the optimizer wrapper. If ``constructor`` is not set, the
    ``DefaultOptimWrapperConstructor`` will be used by default.

    Args:
        model (nn.Module): Model to be optimized.
        cfg (dict): Config of optimizer wrapper, optimizer constructor and
            optimizer. It is deep-copied first, so the caller's object is
            not modified.

    Returns:
        OptimWrapper: The built optimizer wrapper.
    """
    # ``OPTIM_WRAPPER_CONSTRUCTORS`` is not imported at module level in this
    # file, so it is brought into scope here to avoid a ``NameError``.
    from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS

    # Work on a copy: the ``pop`` calls below would otherwise strip keys from
    # the caller's config.
    optim_wrapper_cfg = copy.deepcopy(cfg)
    constructor_type = optim_wrapper_cfg.pop('constructor',
                                             'DefaultOptimWrapperConstructor')
    paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)

    # Since the current generation of NPU(Ascend 910) only supports
    # mixed precision training, here we turn on mixed precision by default
    # on the NPU to make the training normal
    if is_npu_available():
        optim_wrapper_cfg['type'] = 'AmpOptimWrapper'

    optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
        dict(
            type=constructor_type,
            optim_wrapper_cfg=optim_wrapper_cfg,
            paramwise_cfg=paramwise_cfg))
    # The constructor is callable: applying it to the model yields the wrapper.
    optim_wrapper = optim_wrapper_constructor(model)
    return optim_wrapper