KyanChen's picture
Upload 159 files
1c3eb47
raw
history blame
883 Bytes
import copy
import inspect
from typing import List, Union
import torch
import torch.nn as nn
import lightning
from mmengine.config import Config, ConfigDict
from mmengine.device import is_npu_available
from mmpl.registry import HOOKS
def register_pl_hooks() -> List[str]:
    """Add every Lightning callback class to the ``HOOKS`` registry.

    Walks the attributes of ``lightning.pytorch.callbacks``, in the order
    ``dir()`` reports them, skipping names that start with ``'__'``. Each
    attribute that is a class deriving from
    ``lightning.pytorch.callbacks.Callback`` is registered through
    ``HOOKS.register_module``. Because ``issubclass`` is reflexive, this
    includes ``Callback`` itself.

    NOTE(review): a second call may fail if ``HOOKS`` rejects names that are
    already registered. Confirm against the registry implementation.

    Returns:
        List[str]: The attribute names under which the registered callback
        classes were found.
    """
    callbacks_pkg = lightning.pytorch.callbacks
    registered: List[str] = []
    for attr_name in dir(callbacks_pkg):
        # Dunder attributes (``__name__``, ``__path__``, ...) are never callbacks.
        if attr_name.startswith('__'):
            continue
        candidate = getattr(callbacks_pkg, attr_name)
        # The isclass check must come first: issubclass raises on non-classes
        # such as the submodules that dir() also lists.
        is_callback_cls = inspect.isclass(candidate) and issubclass(
            candidate, callbacks_pkg.Callback)
        if not is_callback_cls:
            continue
        HOOKS.register_module(module=candidate)
        registered.append(attr_name)
    return registered
# Runs once, at import time of this module: populates ``HOOKS`` with the
# Lightning callback classes and keeps the list of their attribute names.
PL_HOOKS: List[str] = register_pl_hooks()