Spaces:
Runtime error
Runtime error
File size: 883 Bytes
1c3eb47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import copy
import inspect
from typing import List, Union
import torch
import torch.nn as nn
import lightning
from mmengine.config import Config, ConfigDict
from mmengine.device import is_npu_available
from mmpl.registry import HOOKS
def register_pl_hooks() -> List[str]:
    """Register Lightning callback classes in the ``HOOKS`` registry.

    Every public attribute of ``lightning.pytorch.callbacks`` that is a class
    deriving from ``lightning.pytorch.callbacks.Callback`` is passed to
    ``HOOKS.register_module``. The registry chooses the registered key itself,
    since no explicit ``name`` is given; presumably this is the class
    ``__name__`` (mmengine's default), to be confirmed against ``mmpl.registry``.

    Returns:
        List[str]: The attribute names, in ``dir()`` order, under which the
        registered callback classes were found in
        ``lightning.pytorch.callbacks``.
    """
    callbacks_module = lightning.pytorch.callbacks
    registered_names: List[str] = []

    for attr_name in dir(callbacks_module):
        # Only dunder names are skipped; single-underscore names are still
        # inspected.
        if attr_name.startswith('__'):
            continue

        candidate = getattr(callbacks_module, attr_name)

        # Submodules, functions and other non-class attributes are ignored.
        if not inspect.isclass(candidate):
            continue
        # ``issubclass(X, X)`` is True, so the ``Callback`` base class itself
        # is registered too.
        if not issubclass(candidate, callbacks_module.Callback):
            continue

        HOOKS.register_module(module=candidate)
        registered_names.append(attr_name)

    return registered_names
# Registers the Lightning callback classes in ``HOOKS`` as a side effect of
# importing this module. Keeps the attribute names returned by
# ``register_pl_hooks``.
PL_HOOKS: List[str] = register_pl_hooks()
|