# torch/distributed/_composable/replicate.py
# mypy: allow-untyped-defs
import weakref
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Set, Tuple

import torch
import torch.nn as nn
from torch.distributed._composable_state import _State
from torch.nn.parallel import DistributedDataParallel

from .contract import _get_registry, contract


_ROOT_MODULE_PREFIX = ""


class _ReplicateState(_State):
    def __init__(self) -> None:
        super().__init__()
        self.module: nn.Module = nn.ParameterList()
        self.has_initialized: bool = False
        self._param_list: nn.ParameterList = nn.ParameterList()
        # TODO(@fegin): this variable was originally created for testing, we
        # should remove this if possible.
        self._orig_module = self.module
        self._param_names: List[str] = []
        self._no_sync: bool = False
        self._init_args: Optional[Tuple[Any, ...]] = None
        self._init_kwargs: Dict[str, Any] = {}
        self._comm_hook_args: List[Any] = []

    def _collect_params(
        self,
        module: nn.Module,
        ignored_modules: Set[nn.Module],
        ignored_params: Set[nn.Parameter],
        prefix: str = _ROOT_MODULE_PREFIX,
    ) -> None:
        # skip if managed by the fully_shard API
        if _is_fully_sharded(module):
            return

        # if a module is ignored, all descendants of the module are ignored.
        if module in ignored_modules:
            return

        recurse_prefix = (
            f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
        )

        for n, p in module.named_parameters(recurse=False):
            if p not in ignored_params:
                self._param_list.append(p)
                self._param_names.append(f"{recurse_prefix}{n}")

        for name, child_module in module.named_children():
            self._collect_params(
                child_module,
                ignored_modules,
                ignored_params,
                prefix=f"{recurse_prefix}{name}",
            )

    def lazy_init(self) -> None:
        def _lazy_init():
            assert self._init_args is not None
            self.init(*self._init_args, **self._init_kwargs)
            self.register_comm_hook()
            self._init_args = tuple()
            self._init_kwargs = {}

        _lazy_init()

    def init(
        self,
        module: nn.Module,
        ignored_modules: Set[nn.Module],
        **kwargs,
    ) -> None:
        if self.has_initialized:
            return

        self.has_initialized = True

        device_mesh = kwargs.get("device_mesh", None)
        self.module = module
        ignored_params = {p for m in ignored_modules for p in m.parameters()}
        from torch.distributed.tensor.parallel.ddp import _localize_dtensor

        _localize_dtensor(module)
        self._collect_params(module, ignored_modules, ignored_params)

        if "device_id" in kwargs:
            # replicate() supports a small usability enhancement where
            # user can pass in device_id as a Union[int, torch.device] even for
            # CPU devices so users don't have to change code for CPU/GPU runs.
            # We derive the right device_ids to feed into DDP to support this.
            if kwargs["device_id"] is not None:
                device_id = kwargs["device_id"]
                # Convert to device_ids that DDP expects.
                if isinstance(device_id, torch.device) and device_id.type == "cpu":
                    # CPU modules receive device_ids None
                    kwargs["device_ids"] = None
                else:
                    # GPU modules expect device_ids=[cuda_device]
                    kwargs["device_ids"] = [device_id]
            else:
                kwargs["device_ids"] = None
            kwargs.pop("device_id")
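            # Illustrative mapping (not in the original source): device_id=0 or
            # torch.device("cuda", 0) becomes device_ids=[device_id], while
            # device_id=torch.device("cpu") becomes device_ids=None.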

        self._ddp = DistributedDataParallel(self._param_list, **kwargs)
        # Weakref to the DDP instance is currently only used for testing.
        replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)

    def register_comm_hook(self) -> None:
        for comm_args, comm_kwargs in self._comm_hook_args:
            self._ddp.register_comm_hook(*comm_args, **comm_kwargs)
        self._comm_hook_args.clear()

    def record_init_args(self, *args, **kwargs) -> None:
        self._init_args = args
        self._init_kwargs = kwargs

    def forward_pre_hook(
        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        # The underlying DDP instance is constructed lazily on the first forward.
        if self._init_args or self._init_kwargs:
            self.lazy_init()
        self._ddp.require_backward_grad_sync = not self._no_sync
        return self._ddp._pre_forward(*args, **kwargs)

    def forward_post_hook(
        self,
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
        return self._ddp._post_forward(output)


def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    raise AssertionError(
        "DDP does not support deepcopy. Please use state dict for serialization."
    )


# Follow the same pattern as FSDP/fully_shard
class DDP:
    def __new__(cls, *args, **kwargs):
        """
        Override ``__new__`` to remove the DDP class and directly construct
        the original class for cases like indexing into a container module.
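
        For example (an illustrative note, not from the original docstring):
        after ``replicate`` rewrites a module's class, an ``nn.Sequential``
        gets an MRO of roughly ``(DDPSequential, DDP, Sequential, Module,
        object)``, so ``cls.__mro__[2]`` is the original ``Sequential`` class.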
""" | |
# Use index 2 since 0 is the dynamically constructed `DDP<...>` class | |
# and index 1 is the `DDP` class itself | |
orig_cls = cls.__mro__[2] | |
return orig_cls.__new__(orig_cls, *args, **kwargs) | |

    def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
        """
        Sets if the module should sync gradients. This can be used to implement
        gradient accumulation without communication.

        Args:
            requires_gradient_sync (bool): Whether to reduce gradients for the
                module's parameters.
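
        Example (an illustrative sketch, not from the original docstring; assumes
        ``model`` was wrapped with ``replicate`` and ``microbatches``/``optim``
        are provided by the caller)::

            >>> # xdoctest: +SKIP("requires an initialized process group")
            >>> for i, microbatch in enumerate(microbatches):
            ...     # all-reduce gradients only on the last microbatch
            ...     model.set_requires_gradient_sync(i == len(microbatches) - 1)
            ...     model(microbatch).sum().backward()
            >>> optim.step()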
""" | |
replicate.state(self)._no_sync = not requires_gradient_sync | |

    def register_comm_hook(self, *args, **kwargs) -> None:
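        """
        Queue a ``register_comm_hook`` call for the underlying
        ``DistributedDataParallel`` instance; queued hooks are applied during
        lazy init on the first forward. Example (an illustrative sketch, not
        from the original docstring)::

            >>> # xdoctest: +SKIP("requires an initialized process group")
            >>> from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
            >>> model.register_comm_hook(None, default_hooks.fp16_compress_hook)
        """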
        replicate.state(self)._comm_hook_args.append((args, kwargs))


@contract(state_cls=_ReplicateState)
def replicate(
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    **kwargs,
) -> nn.Module:
    r"""Replicates a module

    Args:
        module (torch.nn.Module): module to replicate

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
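        >>> # Keyword arguments are forwarded to the underlying
        >>> # DistributedDataParallel wrapper, e.g. (an illustrative sketch,
        >>> # not from the original docstring):
        >>> # replicate(nn.Linear(3, 3), device_id=torch.device("cuda", 0))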
""" | |
torch._C._log_api_usage_once("torch.distributed.replicate") | |
# TODO(fegin): using kwargs is not a good idea if we would like to make | |
# replicate a formal API to replace DDP. | |
if "device_id" in kwargs: | |
if not isinstance(kwargs["device_id"], (int, torch.device)): | |
raise RuntimeError( | |
"Expected device_id to be int or torch.device, " | |
f"but got {type(kwargs['device_id'])}" | |
) | |
if _is_fully_sharded(module): | |
raise RuntimeError( | |
"Cannot apply `replicate()` on a Module already managed by `fully_shard`" | |
) | |
if ignored_modules is None: | |
ignored_modules = {} | |
else: | |
ignored_modules = set(ignored_modules) | |
state = cast(_ReplicateState, replicate.state(module)) | |
module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True) | |
device_mesh = kwargs.get("device_mesh", None) | |
if device_mesh is not None: | |
from torch.distributed.device_mesh import _mesh_resources | |
if _mesh_resources.get_parent_mesh(device_mesh) is not None: | |
# TODO: This is a temporary work around to enable DDP + TP. | |
# We should do the logic in DDP so that the 2D implementation is | |
# sound and the state_dict works out of the box. | |
# | |
# This won't conflict with what is done in DDP class as the module | |
# replicate is going to pass is NOT the original module. | |
from torch.distributed.tensor.parallel.ddp import ( | |
_localize_dtensor, | |
_reconstruct_dtensor, | |
) | |
module.register_forward_pre_hook(_reconstruct_dtensor) | |
module.register_forward_hook(_localize_dtensor) | |
module.register_forward_hook(state.forward_post_hook) # type: ignore[arg-type] | |
state.record_init_args(module, ignored_modules, **kwargs) | |
# Place DDP leftmost for highest priority in the method resolution order | |
cls = module.__class__ | |
dct = {"__deepcopy__": unimplemented_deepcopy} | |
new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct) | |
module.__class__ = new_cls | |
return module | |


def _is_fully_sharded(module: nn.Module) -> bool:
    r"""Check if module is marked with fully_shard."""
    registry = _get_registry(module)
    if registry is None:
        return False
    return "fully_shard" in registry