# Page-status residue captured with this file (not part of the module):
# "Spaces: Runtime error"
import inspect | |
from typing import List, Union | |
import torch | |
import lightning | |
from mmpl.registry import MODEL_WRAPPERS | |
def register_pl_strategies() -> List[str]:
    """Register Lightning strategy classes in the ``MODEL_WRAPPERS`` registry.

    Walks every attribute of ``lightning.pytorch.strategies`` whose name does
    not start with a double underscore and registers each one that is a class
    deriving from ``lightning.pytorch.strategies.Strategy``. Because
    ``issubclass(Strategy, Strategy)`` is true, the ``Strategy`` base class
    itself is registered as well.

    Returns:
        List[str]: Attribute names of the registered strategy classes, in the
        order produced by ``dir()``.
    """
    # Resolve the package and base class once instead of on every iteration.
    strategies_pkg = lightning.pytorch.strategies
    strategy_base = strategies_pkg.Strategy

    pl_strategies = []
    for module_name in dir(strategies_pkg):
        # Skip dunder attributes such as ``__name__`` and ``__file__``.
        if module_name.startswith('__'):
            continue
        _strategy = getattr(strategies_pkg, module_name)
        # ``inspect.isclass`` guards ``issubclass``, which raises ``TypeError``
        # when its first argument is not a class (functions, submodules, ...).
        if inspect.isclass(_strategy) and issubclass(_strategy, strategy_base):
            MODEL_WRAPPERS.register_module(module=_strategy)
            pl_strategies.append(module_name)
    return pl_strategies
# Runs at import time: registers the Lightning strategies as a side effect of
# importing this module and keeps the list of registered names.
PL_MODEL_WRAPPERS = register_pl_strategies()