File size: 883 Bytes
1c3eb47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import copy
import inspect
from typing import List, Union

import torch
import torch.nn as nn
import lightning

from mmengine.config import Config, ConfigDict
from mmengine.device import is_npu_available
from mmpl.registry import HOOKS


def register_pl_hooks() -> List[str]:
    """Register callbacks in ``lightning.pytorch.callbacks`` to the ``HOOKS`` registry.

    Returns:
        List[str]: A list of registered callbacks' name.
    """
    pl_hooks = []
    for module_name in dir(lightning.pytorch.callbacks):
        if module_name.startswith('__'):
            continue
        _hook = getattr(lightning.pytorch.callbacks, module_name)
        if inspect.isclass(_hook) and issubclass(_hook, lightning.pytorch.callbacks.Callback):
            HOOKS.register_module(module=_hook)
            pl_hooks.append(module_name)
    return pl_hooks


PL_HOOKS = register_pl_hooks()