|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import warnings |
|
from contextlib import contextmanager |
|
from functools import partial |
|
from typing import Any, Callable, Optional |
|
|
|
import torch |
|
|
|
from .utils import ( |
|
DistributedType, |
|
DynamoBackend, |
|
GradientAccumulationPlugin, |
|
get_ccl_version, |
|
get_int_from_env, |
|
is_ccl_available, |
|
is_deepspeed_available, |
|
is_fp8_available, |
|
is_mps_available, |
|
is_tpu_available, |
|
parse_choice_from_env, |
|
parse_flag_from_env, |
|
) |
|
from .utils.dataclasses import SageMakerDistributedType |
|
|
|
|
|
# `xm` (torch_xla's XLA model helpers) is only bound when `is_tpu_available(check_device=False)` is true; the TPU
# branches in `PartialState.__init__` and `PartialState.wait_for_everyone` are the code paths that use it.
if is_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
|
|
|
|
|
def is_initialized() -> bool:
    """
    Checks if the `AcceleratorState` has been initialized from `Accelerator`. Same as `AcceleratorState.initialized`,
    but works as a module method.

    `AcceleratorState` copies `PartialState`'s shared attributes into its own shared dict before adding its own, so a
    non-empty dict alone does not mean it was fully initialized; like `AcceleratorState.initialized`, this compares the
    two shared dicts instead.
    """
    return AcceleratorState._shared_state != PartialState._shared_state
|
|
|
|
|
|
|
def do_nothing(*args, **kwargs):
    """Accept any arguments, ignore them all and return `None`.

    Returned by the `on_*_process` decorators of `PartialState` in place of the decorated function on processes that
    should not run it.
    """
    del args, kwargs
|
|
|
|
|
|
|
class PartialState:
    """
    Singleton class that has information about the current training environment and functions to help with process
    control. Designed to be used when only process control and device execution states are needed. Does *not* need to
    be initialized from `Accelerator`.

    All instances share one attribute dictionary (`_shared_state`), so the environment is only probed the first time
    the class is instantiated; later instances see the already-populated state.

    **Available attributes:**

    - **device** (`torch.device`) -- The device to use.
    - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
      in use.
    - **local_process_index** (`int`) -- The index of the current process on the current server.
    - **num_processes** (`int`) -- The number of processes currently launched in parallel.
    - **process_index** (`int`) -- The index of the current process.
    - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
    - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
    - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
    """

    _shared_state = {}

    def __init__(self, cpu: bool = False, **kwargs):
        # Every instance aliases the same dict, which is what makes this class behave as a singleton.
        self.__dict__ = self._shared_state
        if not self.initialized:
            self._cpu = cpu
            self.backend = None
            env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
            self.device = torch.device(env_device) if env_device is not None else None
            if (
                os.environ.get("ACCELERATE_USE_SAGEMAKER", "false") == "true"
                and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO
                and not cpu
            ):
                if os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") == SageMakerDistributedType.DATA_PARALLEL:
                    self.distributed_type = DistributedType.MULTI_GPU
                    # Imported only for its side effect; presumably it makes the "smddp" backend available.
                    import smdistributed.dataparallel.torch.torch_smddp

                    if not torch.distributed.is_initialized():
                        torch.distributed.init_process_group(backend="smddp")
                    self.backend = "smddp"
                    self.num_processes = torch.distributed.get_world_size()
                    self.process_index = torch.distributed.get_rank()
                    self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
                    if self.device is None:
                        self.device = torch.device("cuda", self.local_process_index)
                    torch.cuda.set_device(self.device)
            elif is_tpu_available() and not cpu:
                self.distributed_type = DistributedType.TPU
                self.num_processes = xm.xrt_world_size()
                self.process_index = xm.get_ordinal()
                self.local_process_index = xm.get_local_ordinal()
                self.device = xm.xla_device()
            elif (
                os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true"
                and int(os.environ.get("LOCAL_RANK", -1)) != -1
                and not cpu
            ):
                assert (
                    is_deepspeed_available()
                ), "DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source"
                self.distributed_type = DistributedType.DEEPSPEED
                if not torch.distributed.is_initialized():
                    from deepspeed import comm as dist

                    # The backend is fixed to nccl here, so any caller-supplied `backend` is discarded.
                    kwargs.pop("backend", None)
                    self.backend = "nccl"
                    dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)

                self.num_processes = torch.distributed.get_world_size()
                self.process_index = torch.distributed.get_rank()
                self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
                if self.device is None:
                    self.device = torch.device("cuda", self.local_process_index)
                torch.cuda.set_device(self.device)
                self._mixed_precision = "no"
            elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
                self.distributed_type = DistributedType.MULTI_GPU
                if not torch.distributed.is_initialized():
                    self.backend = kwargs.pop("backend", "nccl")
                    # An explicit `backend=None` from the caller still falls back to nccl.
                    if self.backend is None:
                        self.backend = "nccl"
                    torch.distributed.init_process_group(backend=self.backend, **kwargs)
                self.num_processes = torch.distributed.get_world_size()
                self.process_index = torch.distributed.get_rank()
                self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
                if self.device is None:
                    self.device = torch.device("cuda", self.local_process_index)
                torch.cuda.set_device(self.device)
            elif get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1:
                self.distributed_type = DistributedType.MULTI_CPU
                if is_ccl_available() and get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0:
                    # Compare numerically: a plain string comparison would rank e.g. "1.9" above "1.12".
                    if self._version_at_least(get_ccl_version(), "1.12"):
                        import oneccl_bindings_for_pytorch
                    else:
                        import torch_ccl
                    backend = "ccl"
                elif torch.distributed.is_mpi_available():
                    backend = "mpi"
                else:
                    backend = "gloo"

                # Rank/size information may come from torchrun, Intel MPI (PMI_*), Open MPI or MVAPICH2.
                rank = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0)
                size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1)
                local_rank = get_int_from_env(
                    ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0
                )
                local_size = get_int_from_env(
                    ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1
                )
                self.local_process_index = local_rank
                os.environ["RANK"] = str(rank)
                os.environ["WORLD_SIZE"] = str(size)
                os.environ["LOCAL_RANK"] = str(local_rank)
                if not os.environ.get("MASTER_PORT", None):
                    os.environ["MASTER_PORT"] = "29500"
                if not os.environ.get("MASTER_ADDR", None):
                    # Processes spread over several nodes need an explicit rendezvous address.
                    if local_size != size and backend != "mpi":
                        raise ValueError(
                            "Looks like distributed multinode run but MASTER_ADDR env not set, "
                            "please try exporting rank 0's hostname as MASTER_ADDR"
                        )
                if not torch.distributed.is_initialized():
                    # The backend was selected above and is passed positionally; drop caller-supplied duplicates
                    # that would otherwise make `init_process_group` fail with "multiple values for 'backend'".
                    kwargs.pop("nccl_backend", None)
                    kwargs.pop("backend", None)
                    self.backend = backend
                    torch.distributed.init_process_group(self.backend, rank=rank, world_size=size, **kwargs)
                self.num_processes = torch.distributed.get_world_size()
                self.process_index = torch.distributed.get_rank()
                self.local_process_index = local_rank
                if self.device is None:
                    self.device = torch.device("cpu")
            else:
                self.distributed_type = DistributedType.NO
                self.num_processes = 1
                self.process_index = self.local_process_index = 0

            if self.device is None:
                self.device = torch.device("cpu") if cpu else self.default_device
            self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)

    @staticmethod
    def _version_at_least(version, minimum) -> bool:
        """
        Returns whether the dotted `version` is numerically >= `minimum`.

        Only the leading numeric part of each dot-separated component is considered (so `"1.13.0+cpu"` reads as
        `(1, 13, 0)`). Parsing stops at the first component with no leading digits.
        """

        def _release(value):
            numbers = []
            for piece in str(value).split("."):
                digits = ""
                for char in piece:
                    if not char.isdigit():
                        break
                    digits += char
                if not digits:
                    break
                numbers.append(int(digits))
            return tuple(numbers)

        return _release(version) >= _release(minimum)

    def __repr__(self) -> str:
        return (
            f"Distributed environment: {self.distributed_type}{(' Backend: ' + self.backend) if self.backend else ''}\n"
            f"Num processes: {self.num_processes}\n"
            f"Process index: {self.process_index}\n"
            f"Local process index: {self.local_process_index}\n"
            f"Device: {self.device}\n"
        )

    @staticmethod
    def _reset_state():
        "Resets `_shared_state`, is used internally and should not be called"
        PartialState._shared_state = {}

    @property
    def initialized(self) -> bool:
        "Returns whether the `PartialState` has been initialized"
        return self._shared_state != {}

    @property
    def use_distributed(self):
        """
        Whether the Accelerator is configured for distributed training
        """
        return self.distributed_type != DistributedType.NO and self.num_processes > 1

    @property
    def is_last_process(self) -> bool:
        "Returns whether the current process is the last one"
        return self.process_index == self.num_processes - 1

    @property
    def is_main_process(self) -> bool:
        "Returns whether the current process is the main process"
        # Under Megatron-LM the last process plays the role of the main process.
        return (
            self.process_index == 0 if self.distributed_type != DistributedType.MEGATRON_LM else self.is_last_process
        )

    @property
    def is_local_main_process(self) -> bool:
        "Returns whether the current process is the main process on the local node"
        return (
            self.local_process_index == 0
            if self.distributed_type != DistributedType.MEGATRON_LM
            else self.is_last_process
        )

    def wait_for_everyone(self):
        """
        Will stop the execution of the current process until every other process has reached that point (so this does
        nothing when the script is only run in one process). Useful to do before saving a model.

        Example:

        ```python
        >>> # Assuming two GPU processes
        >>> import time
        >>> from accelerate.state import PartialState

        >>> state = PartialState()
        >>> if state.is_main_process:
        ...     time.sleep(2)
        >>> else:
        ...     print("I'm waiting for the main process to finish its sleep...")
        >>> state.wait_for_everyone()
        >>> # Should print on every process at the same time
        >>> print("Everyone is here")
        ```
        """
        # Every type listed here runs on an initialized `torch.distributed` process group; Megatron-LM is promoted
        # from the multi-GPU setup, so it needs the barrier as well.
        if self.distributed_type in (
            DistributedType.MULTI_GPU,
            DistributedType.MULTI_CPU,
            DistributedType.DEEPSPEED,
            DistributedType.FSDP,
            DistributedType.MEGATRON_LM,
        ):
            torch.distributed.barrier()
        elif self.distributed_type == DistributedType.TPU:
            xm.rendezvous("accelerate.utils.wait_for_everyone")

    def _goes_first(self, is_main: bool):
        "Generator body shared by the `*_first` context managers: non-main processes wait before, main ones after."
        if not is_main:
            self.wait_for_everyone()

        yield

        if is_main:
            self.wait_for_everyone()

    @contextmanager
    def main_process_first(self):
        """
        Lets the main process go first inside a with block.

        The other processes will enter the with block after the main process exits.

        Example:

        ```python
        >>> from accelerate import Accelerator

        >>> accelerator = Accelerator()
        >>> with accelerator.main_process_first():
        ...     # This will be printed first by process 0 then in a seemingly
        ...     # random order by the other processes.
        ...     print(f"This will be printed by process {accelerator.process_index}")
        ```
        """
        yield from self._goes_first(self.is_main_process)

    @contextmanager
    def local_main_process_first(self):
        """
        Lets the local main process go first inside a with block.

        The other processes will enter the with block after the main process exits.

        Example:

        ```python
        >>> from accelerate.state import PartialState

        >>> state = PartialState()
        >>> with state.local_main_process_first():
        ...     # This will be printed first by local process 0 then in a seemingly
        ...     # random order by the other processes.
        ...     print(f"This will be printed by process {state.local_process_index}")
        ```
        """
        yield from self._goes_first(self.is_local_main_process)

    def on_main_process(self, function: Callable[..., Any] = None):
        """
        Decorator that only runs the decorated function on the main process.

        Args:
            function (`Callable`): The function to decorate.

        Raises:
            ValueError: If the state has not been initialized.

        Example:

        ```python
        >>> from accelerate.state import PartialState

        >>> state = PartialState()


        >>> @state.on_main_process
        ... def print_something():
        ...     print("This will be printed by process 0 only.")


        >>> print_something()
        "This will be printed by process 0 only"
        ```
        """
        if not self.initialized:
            raise ValueError("The `PartialState` or `Accelerator` must be initialized before calling this function.")
        if self.is_main_process or not self.use_distributed:
            return function
        return do_nothing

    def on_local_main_process(self, function: Callable[..., Any] = None):
        """
        Decorator that only runs the decorated function on the local main process.

        Args:
            function (`Callable`): The function to decorate.

        Example:
        ```python
        # Assume we have 2 servers with 4 processes each.
        from accelerate.state import PartialState

        state = PartialState()


        @state.on_local_main_process
        def print_something():
            print("This will be printed by process 0 only on each server.")


        print_something()
        # On server 1:
        "This will be printed by process 0 only"
        # On server 2:
        "This will be printed by process 0 only"
        ```
        """
        if self.is_local_main_process or not self.use_distributed:
            return function
        return do_nothing

    def on_last_process(self, function: Callable[..., Any]):
        """
        Decorator that only runs the decorated function on the last process.

        Args:
            function (`Callable`): The function to decorate.

        Example:
        ```python
        # Assume we have 4 processes.
        from accelerate.state import PartialState

        state = PartialState()


        @state.on_last_process
        def print_something():
            print(f"Printed on process {state.process_index}")


        print_something()
        "Printed on process 3"
        ```
        """
        if self.is_last_process or not self.use_distributed:
            return function
        return do_nothing

    def on_process(self, function: Callable[..., Any] = None, process_index: int = None):
        """
        Decorator that only runs the decorated function on the process with the given index.

        Args:
            function (`Callable`, `optional`):
                The function to decorate. When omitted, a decorator bound to `process_index` is returned.
            process_index (`int`, `optional`):
                The index of the process on which to run the function.

        Example:
        ```python
        # Assume we have 4 processes.
        from accelerate.state import PartialState

        state = PartialState()


        @state.on_process(process_index=2)
        def print_something():
            print(f"Printed on process {state.process_index}")


        print_something()
        "Printed on process 2"
        ```
        """
        if function is None:
            return partial(self.on_process, process_index=process_index)
        if (self.process_index == process_index) or (not self.use_distributed):
            return function
        return do_nothing

    def on_local_process(self, function: Callable[..., Any] = None, local_process_index: int = None):
        """
        Decorator that only runs the decorated function on the process with the given index on the current node.

        Args:
            function (`Callable`, *optional*):
                The function to decorate. When omitted, a decorator bound to `local_process_index` is returned.
            local_process_index (`int`, *optional*):
                The index of the local process on which to run the function.

        Example:
        ```python
        # Assume we have 2 servers with 4 processes each.
        from accelerate import Accelerator

        accelerator = Accelerator()


        @accelerator.on_local_process(local_process_index=2)
        def print_something():
            print(f"Printed on process {accelerator.local_process_index}")


        print_something()
        # On server 1:
        "Printed on process 2"
        # On server 2:
        "Printed on process 2"
        ```
        """
        if function is None:
            return partial(self.on_local_process, local_process_index=local_process_index)
        if (self.local_process_index == local_process_index) or (not self.use_distributed):
            return function
        return do_nothing

    def print(self, *args, **kwargs):
        "Forwards to the builtin `print`, but only on the local main process."
        if self.is_local_main_process:
            print(*args, **kwargs)

    @property
    def default_device(self) -> torch.device:
        """
        Returns the default device which is:
        - MPS if `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()` both return True.
        - CUDA if `torch.cuda.is_available()`
        - CPU otherwise

        Selecting MPS also sets `PYTORCH_ENABLE_MPS_FALLBACK=1` in the environment.
        """
        if is_mps_available():
            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
            return torch.device("mps")
        elif torch.cuda.is_available():
            return torch.device("cuda")
        else:
            return torch.device("cpu")
|
|
|
|
|
class AcceleratorState:
    """
    Singleton class that has information about the current training environment.

    All instances share one attribute dictionary (`_shared_state`), which starts as a copy of `PartialState`'s shared
    attributes and is then extended with accelerator-specific settings.

    **Available attributes:**

    - **device** (`torch.device`) -- The device to use.
    - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
      in use.
    - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
    - **local_process_index** (`int`) -- The index of the current process on the current server.
    - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
      of mixed precision being performed.
    - **num_processes** (`int`) -- The number of processes currently launched in parallel.
    - **process_index** (`int`) -- The index of the current process.
    - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
    - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
    - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
    """

    _shared_state = {}

    def __init__(
        self,
        mixed_precision: str = None,
        cpu: bool = False,
        dynamo_plugin=None,
        deepspeed_plugin=None,
        fsdp_plugin=None,
        megatron_lm_plugin=None,
        ipex_plugin=None,
        _from_accelerator: bool = False,
        **kwargs,
    ):
        self.__dict__ = self._shared_state
        if parse_flag_from_env("ACCELERATE_USE_CPU"):
            cpu = True
        if PartialState._shared_state == {}:
            PartialState(cpu, **kwargs)
        self.__dict__.update(PartialState._shared_state)
        self._check_initialized(mixed_precision, cpu)
        if not self.initialized:
            # Validate everything that can fail *before* writing to the shared dict: any attribute added here makes
            # `initialized` report True, so a late error would leave a half-built state behind for later callers.
            if not _from_accelerator:
                raise ValueError(
                    "Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
                    "before using any functionality from the `accelerate` library."
                )
            mixed_precision = (
                parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
                if mixed_precision is None
                else mixed_precision.lower()
            )
            if mixed_precision == "fp8" and not is_fp8_available():
                raise ValueError("Using `fp8` precision requires `transformer_engine` to be installed.")

            self.deepspeed_plugin = None
            self.ipex_plugin = None
            self.dynamo_plugin = dynamo_plugin
            # DeepSpeed's precision is read from its config instead (see the `mixed_precision` property).
            self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
            if self.distributed_type == DistributedType.TPU:
                if mixed_precision == "bf16":
                    if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
                        os.environ["XLA_USE_BF16"] = str(0)
                        os.environ["XLA_DOWNCAST_BF16"] = str(1)
                        self.downcast_bfloat = True
                    else:
                        os.environ["XLA_USE_BF16"] = str(1)
                        os.environ["XLA_DOWNCAST_BF16"] = str(0)
                        self.downcast_bfloat = False
            elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu:
                self.deepspeed_plugin = deepspeed_plugin
            elif self.distributed_type == DistributedType.MULTI_GPU:
                # FSDP and Megatron-LM are refinements of the multi-GPU setup selected through env flags.
                if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
                    self.distributed_type = DistributedType.FSDP
                    if self._mixed_precision != "no":
                        fsdp_plugin.set_mixed_precision(self._mixed_precision)
                    self.fsdp_plugin = fsdp_plugin
                if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true":
                    self.distributed_type = DistributedType.MEGATRON_LM
                    megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
                    self.megatron_lm_plugin = megatron_lm_plugin
            elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.NO]:
                if self.device.type == "cpu" and ipex_plugin is not None:
                    self.ipex_plugin = ipex_plugin if ipex_plugin.use_ipex else None
                    if self.ipex_plugin is not None:
                        self.ipex_plugin.set_mixed_precision(mixed_precision)
            # `dynamo_plugin` defaults to None, so guard before dereferencing it.
            if (
                self.dynamo_plugin is not None
                and self.dynamo_plugin.backend != DynamoBackend.NO
                and self._mixed_precision == "no"
                and self.device.type == "cuda"
            ):
                torch.backends.cuda.matmul.allow_tf32 = True
            # Propagate any FSDP/Megatron-LM promotion back to the process-level state.
            PartialState._shared_state["distributed_type"] = self.distributed_type

    @property
    def initialized(self) -> bool:
        "Returns whether the state holds more than the copy of `PartialState`'s attributes"
        return self._shared_state != PartialState._shared_state

    def __repr__(self):
        text = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n"
        if self.distributed_type == DistributedType.DEEPSPEED:
            text += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
        return text

    def _check_initialized(self, mixed_precision=None, cpu=None):
        "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
        if self.initialized:
            err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`."
            if cpu and self.device.type != "cpu":
                raise ValueError(err.format(flag="cpu=True"))
            if (
                mixed_precision is not None
                and mixed_precision != self._mixed_precision
                and self.distributed_type != DistributedType.DEEPSPEED
            ):
                raise ValueError(err.format(flag=f"mixed_precision='{mixed_precision}'"))

    # For backward compatibility
    @property
    def use_fp16(self):
        warnings.warn(
            "The `use_fp16` property is deprecated and will be removed in version 1.0 of Accelerate use "
            "`AcceleratorState.mixed_precision == 'fp16'` instead.",
            FutureWarning,
        )
        return self._mixed_precision != "no"

    @property
    def mixed_precision(self):
        "Returns the active mixed precision type; under DeepSpeed it is derived from the DeepSpeed config"
        if self.distributed_type == DistributedType.DEEPSPEED:
            config = self.deepspeed_plugin.deepspeed_config
            if config.get("fp16", {}).get("enabled", False):
                mixed_precision = "fp16"
            elif config.get("bf16", {}).get("enabled", False):
                mixed_precision = "bf16"
            else:
                mixed_precision = "no"
        else:
            mixed_precision = self._mixed_precision
        return mixed_precision

    @staticmethod
    def _reset_state(reset_partial_state: bool = False):
        "Resets `_shared_state`, is used internally and should not be called"
        AcceleratorState._shared_state = {}
        if reset_partial_state:
            PartialState._reset_state()

    @property
    def use_distributed(self):
        """
        Whether the Accelerator is configured for distributed training
        """
        return PartialState().use_distributed

    @property
    def is_last_process(self) -> bool:
        "Returns whether the current process is the last one"
        return PartialState().is_last_process

    @property
    def is_main_process(self) -> bool:
        "Returns whether the current process is the main process"
        return PartialState().is_main_process

    @property
    def is_local_main_process(self) -> bool:
        "Returns whether the current process is the main process on the local node"
        return PartialState().is_local_main_process

    def wait_for_everyone(self):
        "Blocks until every process reaches this point (delegates to `PartialState`)."
        PartialState().wait_for_everyone()

    @contextmanager
    def main_process_first(self):
        """
        Lets the main process go first inside a with block.

        The other processes will enter the with block after the main process exits.
        """
        with PartialState().main_process_first():
            yield

    @contextmanager
    def local_main_process_first(self):
        """
        Lets the local main process go first inside a with block.

        The other processes will enter the with block after the main process exits.
        """
        with PartialState().local_main_process_first():
            yield

    def print(self, *args, **kwargs):
        "Prints only on the local main process (delegates to `PartialState`)."
        PartialState().print(*args, **kwargs)
|
|
|
|
|
class GradientState:
    """
    Singleton class that has information related to gradient synchronization for gradient accumulation

    All instances share one attribute dictionary (`_shared_state`).

    **Available attributes:**

    - **end_of_dataloader** (`bool`) -- Whether we have reached the end of the current dataloader
    - **remainder** (`int`) -- The number of extra samples that were added from padding the dataloader
    - **sync_gradients** (`bool`) -- Whether the gradients should be synced across all devices
    - **active_dataloader** (`Optional[DataLoader]`) -- The dataloader that is currently being iterated over
    - **dataloader_references** (`List[Optional[DataLoader]]`) -- A list of references to the dataloaders that are
      being iterated over
    - **num_steps** (`int`) -- The number of steps to accumulate over
    - **adjust_scheduler** (`bool`) -- Whether the scheduler should be adjusted to account for the gradient
      accumulation
    """

    _shared_state = {}

    def __init__(self, gradient_accumulation_plugin: Optional["GradientAccumulationPlugin"] = None):
        self.__dict__ = self._shared_state
        if not self.initialized:
            self.sync_gradients = True
            self.end_of_dataloader = False
            self.remainder = -1
            self.active_dataloader = None
            # The leading `None` lets `_remove_dataloader` fall back to "no active dataloader".
            self.dataloader_references = [None]
            # Without a plugin, use an empty mapping so `num_steps`/`adjust_scheduler` return their defaults
            # (calling `.to_kwargs()` on the default `None` would crash and leave a partially-built state).
            self.plugin_kwargs = (
                gradient_accumulation_plugin.to_kwargs() if gradient_accumulation_plugin is not None else {}
            )

        # A later instantiation with a different plugin updates the shared accumulation settings.
        if gradient_accumulation_plugin is not None and self.plugin_kwargs != gradient_accumulation_plugin.to_kwargs():
            self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs()

    @property
    def num_steps(self) -> int:
        "Returns the number of steps to accumulate over"
        return self.plugin_kwargs.get("num_steps", 1)

    @property
    def adjust_scheduler(self) -> bool:
        "Returns whether the scheduler should be adjusted"
        return self.plugin_kwargs.get("adjust_scheduler", False)

    @property
    def initialized(self) -> bool:
        "Returns whether the `GradientState` has been initialized"
        return GradientState._shared_state != {}

    def __repr__(self):
        return (
            f"Sync Gradients: {self.sync_gradients}\n"
            f"At end of current dataloader: {self.end_of_dataloader}\n"
            f"Extra samples added: {self.remainder}\n"
            f"Gradient accumulation plugin: {self.plugin_kwargs}\n"
        )

    def _set_sync_gradients(self, sync_gradients):
        "Private function that sets whether gradients should be synchronized. Users should not have to call this."
        self.sync_gradients = sync_gradients

    def _set_end_of_dataloader(self, end_of_dataloader):
        "Private function that sets whether the end of the current dataloader has been reached. Users should not have to call this."
        self.end_of_dataloader = end_of_dataloader

    def _set_remainder(self, remainder):
        "Private function that sets the number of remaining samples at the end of the dataloader. Users should not have to call this."
        self.remainder = remainder

    def _add_dataloader(self, dataloader):
        "Private function that adds a dataloader to `self.dataloader_references` and sets `in_dataloader` to `True`. Users should not have to call this."
        self.active_dataloader = dataloader
        self.dataloader_references.append(self.active_dataloader)
        self._set_end_of_dataloader(False)

    def _remove_dataloader(self, dataloader):
        "Private function that removes a dataloader from `self.dataloader_references` and sets `in_dataloader` to `False` if there are no more dataloaders. Users should not have to call this."
        self.dataloader_references.remove(dataloader)
        self.active_dataloader = self.dataloader_references[-1]
        self._set_end_of_dataloader(True)

    @property
    def in_dataloader(self) -> bool:
        "Returns whether the current process is in a dataloader"
        return self.active_dataloader is not None

    @staticmethod
    def _reset_state():
        "Resets `_shared_state`, is used internally and should not be called"
        GradientState._shared_state = {}
|
|