Spaces:
Running
Running
File size: 715 Bytes
80ebcb3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from enum import Enum
from typing import Type, Union

from .accelerate import AccelerateParallelBackend
from .ptd import PytorchDTensorParallelBackend
from .utils import apply_ddp_ptd, apply_fsdp2_ptd, dist_max, dist_mean
# Type alias: a value that is either of the two parallel backend classes
# imported above (AccelerateParallelBackend or PytorchDTensorParallelBackend).
ParallelBackendType = Union[AccelerateParallelBackend, PytorchDTensorParallelBackend]
class ParallelBackendEnum(str, Enum):
    """String-valued identifiers for the available parallel backends.

    The ``str`` mixin makes every member compare equal to its raw string
    value, e.g. ``ParallelBackendEnum.PTD == "ptd"``.
    """

    # Identifier for the backend implemented by AccelerateParallelBackend.
    ACCELERATE = "accelerate"

    # Identifier for the backend implemented by PytorchDTensorParallelBackend.
    PTD = "ptd"
def get_parallel_backend_cls(backend: ParallelBackendEnum) -> Type[ParallelBackendType]:
    """Return the parallel backend class selected by ``backend``.

    The class itself is returned, not an instance; the caller is responsible
    for instantiating it.

    Args:
        backend: Which backend to select. Because ``ParallelBackendEnum``
            subclasses ``str``, the plain strings ``"accelerate"`` and
            ``"ptd"`` compare equal to the members and are accepted too.

    Returns:
        ``AccelerateParallelBackend`` for ``ParallelBackendEnum.ACCELERATE``,
        ``PytorchDTensorParallelBackend`` for ``ParallelBackendEnum.PTD``.

    Raises:
        ValueError: If ``backend`` matches none of the known backends.
    """
    # Equality (not a dict lookup) is used on purpose: enum members hash by
    # member name, so a dict keyed by members would reject raw string values.
    if backend == ParallelBackendEnum.ACCELERATE:
        return AccelerateParallelBackend
    if backend == ParallelBackendEnum.PTD:
        return PytorchDTensorParallelBackend
    raise ValueError(f"Unknown parallel backend: {backend}")
|