Spaces:
Runtime error
Runtime error
File size: 3,632 Bytes
1c3eb47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import copy
import inspect
from typing import List, Union
import torch
import torch.nn as nn
import lightning
from mmengine.config import Config, ConfigDict
from mmengine.device import is_npu_available
from mmpl.registry import LOGGERS
def register_pl_loggers() -> List[str]:
    """Register loggers in ``lightning.pytorch.loggers`` to the ``LOGGERS`` registry.

    Every public (non-dunder) attribute of ``lightning.pytorch.loggers`` that
    is a class deriving from ``lightning.pytorch.loggers.logger.Logger`` is
    registered. The base ``Logger`` class itself also passes this check and
    is registered too.

    Returns:
        List[str]: The attribute names (in ``dir()`` order) of the logger
        classes that were registered.
    """
    loggers_module = lightning.pytorch.loggers
    base_logger_cls = loggers_module.logger.Logger

    pl_loggers = []
    for module_name in dir(loggers_module):
        # Skip dunder attributes such as ``__name__`` / ``__file__``.
        if module_name.startswith('__'):
            continue
        _logger = getattr(loggers_module, module_name)
        if inspect.isclass(_logger) and issubclass(_logger, base_logger_cls):
            # Registered under the class's own ``__name__`` by the registry.
            LOGGERS.register_module(module=_logger)
            pl_loggers.append(module_name)
    return pl_loggers
# Runs at import time: registers every Lightning logger class into ``LOGGERS``
# and keeps the list of attribute names that were registered.
PL_LOGGERS: List[str] = register_pl_loggers()
def register_dadaptation_optimizers() -> List[str]:
    """Register optimizers in ``dadaptation`` to the ``OPTIMIZERS`` registry.

    ``dadaptation`` is optional. If it cannot be imported, nothing is
    registered and an empty list is returned. Names from the candidate list
    that the installed ``dadaptation`` release does not export are skipped
    rather than raising ``AttributeError``.

    Returns:
        List[str]: Names of the optimizers that were registered.
    """
    dadaptation_optimizers = []
    try:
        import dadaptation
    except ImportError:
        return dadaptation_optimizers

    # This module never imports ``OPTIMIZERS`` at top level, so the original
    # reference raised NameError. Import it here, only when needed.
    # NOTE(review): mmengine's registry is used; confirm it is the registry
    # the training pipeline builds optimizers from.
    from mmengine.registry import OPTIMIZERS

    for module_name in ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD']:
        # Default to None so a name missing from this dadaptation version is
        # skipped (inspect.isclass(None) is False).
        _optim = getattr(dadaptation, module_name, None)
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            OPTIMIZERS.register_module(module=_optim)
            dadaptation_optimizers.append(module_name)
    return dadaptation_optimizers
# DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers()
def register_lion_optimizers() -> List[str]:
    """Register the Lion optimizer to the ``OPTIMIZERS`` registry.

    ``lion_pytorch`` is optional. If it cannot be imported, nothing is
    registered and an empty list is returned.

    Returns:
        List[str]: ``['Lion']`` if the optimizer was registered, else ``[]``.
    """
    optimizers = []
    try:
        from lion_pytorch import Lion
    except ImportError:
        return optimizers

    # This module never imports ``OPTIMIZERS`` at top level, so the original
    # reference raised NameError. Import it here, only when needed.
    # NOTE(review): mmengine's registry is used; confirm it is the registry
    # the training pipeline builds optimizers from.
    from mmengine.registry import OPTIMIZERS

    OPTIMIZERS.register_module(module=Lion)
    optimizers.append('Lion')
    return optimizers
# LION_OPTIMIZERS = register_lion_optimizers()
def build_optim_wrapper(model: nn.Module,
                        cfg: Union[dict, Config, ConfigDict]):
    """Build function of OptimWrapper.

    If ``constructor`` is set in the ``cfg``, this method will build an
    optimizer wrapper constructor, and use optimizer wrapper constructor to
    build the optimizer wrapper. If ``constructor`` is not set, the
    ``DefaultOptimWrapperConstructor`` will be used by default.

    ``cfg`` is deep-copied first, so the caller's config is not mutated. The
    ``constructor`` and ``paramwise_cfg`` keys are popped from the copy. The
    remaining keys are passed to the constructor as ``optim_wrapper_cfg``.

    Args:
        model (nn.Module): Model to be optimized.
        cfg (dict): Config of optimizer wrapper, optimizer constructor and
            optimizer.

    Returns:
        OptimWrapper: The built optimizer wrapper.
    """
    # ``OPTIM_WRAPPER_CONSTRUCTORS`` is not imported at module level in this
    # file, so the original reference raised NameError on every call.
    from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS

    optim_wrapper_cfg = copy.deepcopy(cfg)
    constructor_type = optim_wrapper_cfg.pop('constructor',
                                             'DefaultOptimWrapperConstructor')
    paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)

    # Since the current generation of NPU(Ascend 910) only supports
    # mixed precision training, here we turn on mixed precision by default
    # on the NPU to make the training normal. Note this overrides any
    # ``type`` given in ``cfg``.
    if is_npu_available():
        optim_wrapper_cfg['type'] = 'AmpOptimWrapper'

    optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
        dict(
            type=constructor_type,
            optim_wrapper_cfg=optim_wrapper_cfg,
            paramwise_cfg=paramwise_cfg))
    optim_wrapper = optim_wrapper_constructor(model)
    return optim_wrapper
|