Spaces:
Runtime error
Runtime error
import copy | |
import inspect | |
from typing import List, Union | |
import torch | |
import torch.nn as nn | |
import lightning | |
from mmengine.config import Config, ConfigDict | |
from mmengine.device import is_npu_available | |
from mmpl.registry import HOOKS | |
def register_pl_hooks() -> List[str]:
    """Expose Lightning's callback classes through the ``HOOKS`` registry.

    Walks the names returned by ``dir(lightning.pytorch.callbacks)`` in the
    order ``dir`` yields them, skipping dunder names. Every attribute that is
    a class deriving from ``lightning.pytorch.callbacks.Callback`` (the base
    class ``Callback`` itself included, since a class is its own subclass) is
    passed to ``HOOKS.register_module`` without an explicit ``name``.

    NOTE(review): the registry key is whatever ``register_module`` derives
    when no ``name`` is given (in mmengine this is the class ``__name__``),
    which may differ from the attribute name collected below if a class is
    ever re-exported under an alias. Confirm against the registry in use.

    Returns:
        List[str]: The attribute names under which registered callback
        classes were found, in registration order.
    """
    callbacks_mod = lightning.pytorch.callbacks
    base_cls = callbacks_mod.Callback

    def _is_callback_class(obj) -> bool:
        # Submodules and helper functions also show up in dir(); keep classes only.
        return inspect.isclass(obj) and issubclass(obj, base_cls)

    registered = []
    for attr_name in dir(callbacks_mod):
        if attr_name.startswith('__'):
            continue
        candidate = getattr(callbacks_mod, attr_name)
        if not _is_callback_class(candidate):
            continue
        HOOKS.register_module(module=candidate)
        registered.append(attr_name)
    return registered


# Executed at import time: importing this module fills HOOKS as a side effect.
PL_HOOKS = register_pl_hooks()