KyanChen's picture
Upload 159 files
1c3eb47
raw
history blame
827 Bytes
import inspect
from typing import List, Union
import torch
import lightning
from mmpl.registry import MODEL_WRAPPERS
def register_pl_strategies(strategies_module=None, registry=None) -> List[str]:
    """Register Lightning strategy classes in the ``MODEL_WRAPPERS`` registry.

    Every class found in ``lightning.pytorch.strategies`` that is a subclass
    of ``lightning.pytorch.strategies.Strategy`` is registered. This includes
    the ``Strategy`` base class itself, because ``issubclass(C, C)`` is true.
    Dunder attributes and non-class attributes (submodules, functions,
    instances) are skipped.

    Args:
        strategies_module: Namespace to scan for strategy classes. It must
            expose the strategy base class as ``Strategy``. Defaults to
            ``lightning.pytorch.strategies``.
        registry: Object whose ``register_module(module=...)`` method is
            called once for each strategy class. Defaults to
            ``MODEL_WRAPPERS``.

    Returns:
        List[str]: The ``__name__`` of each registered strategy class, in the
        (sorted) order in which ``dir()`` lists the scanned namespace.
    """
    if strategies_module is None:
        # Import the subpackage explicitly rather than relying on
        # ``import lightning`` having loaded it as a side effect.
        from lightning.pytorch import strategies as strategies_module
    if registry is None:
        registry = MODEL_WRAPPERS
    base_cls = strategies_module.Strategy

    registered_names = []
    seen = set()
    for attr_name in dir(strategies_module):
        if attr_name.startswith('__'):
            continue
        candidate = getattr(strategies_module, attr_name)
        if not (inspect.isclass(candidate) and issubclass(candidate, base_cls)):
            continue
        # A class re-exported under a second attribute name is registered
        # only once. Registering it again would reuse the same class object
        # (and the same ``__name__``), which presumably collides with the
        # existing registry entry.
        if candidate in seen:
            continue
        seen.add(candidate)
        registry.register_module(module=candidate)
        registered_names.append(candidate.__name__)
    return registered_names
# Runs the registration at import time. Importing this module therefore
# populates MODEL_WRAPPERS and keeps the list of names the helper returned.
PL_MODEL_WRAPPERS: List[str] = register_pl_strategies()