import torch
from ..modules import Module
from . import comm
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union, cast

from torch._utils import _get_device_index

from collections import OrderedDict

if TYPE_CHECKING:
    from torch.jit import ScriptModule
    from torch.jit._state import EnabledProxy

__all__ = ['replicate']


def _is_script_module(module: Module) -> bool:
    import torch.jit
    return isinstance(module, torch.jit.ScriptModule)


def _is_script_method(module: Module) -> bool:
    import torch.jit
    return isinstance(module, torch._C.ScriptMethod)


def _init_script_module() -> "ScriptModule":
    import torch.jit
    return torch.jit.ScriptModule()


def _is_jit_enabled() -> "EnabledProxy":
    import torch.jit._state
    return torch.jit._state._enabled


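# Check whether a module can be replicated safely. There are two kinds of
# modules: plain Python modules and ScriptModules. A module cannot be
# replicated properly if any ScriptModule in it has a plain Python module
# among its descendants.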
def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:

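    # module.modules() yields the module itself as its first element; skip
    # it so only proper descendants are returned.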
    def descendant_modules(module: Module) -> Iterator[Module]:
        gen = module.modules()
        next(gen)
        return gen

    if not _is_jit_enabled():
        return True
    if memo is None:
        memo = set()

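    # Memoize visited modules so shared submodules are checked only once.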
    memo.add(module)
    if _is_script_module(module):
        memo.update(descendant_modules(module))
        return all(_is_script_module(descendant) for
                   descendant in descendant_modules(module))

    for child in module.children():
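        # Any unreplicatable module makes the check return False early, so
        # modules that were already visited can be safely skipped here.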
        if child in memo:
            continue
        if not _replicatable_module(child, memo):
            return False

    return True


def _broadcast_coalesced_reshape(
    tensors: Sequence[torch.Tensor],
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[List[torch.Tensor]]:
    from ._functions import Broadcast
    if detach:
        return comm.broadcast_coalesced(tensors, devices)
    else:
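        # Without detach, broadcast through the autograd-aware Broadcast
        # function so the copies stay connected to the source tensors in
        # the autograd graph.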
        if len(tensors) > 0:
            tensor_copies = Broadcast.apply(devices, *tensors)
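            # Broadcast.apply returns a flat sequence of copies grouped by
            # device; slice it into one list of tensors per device.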
            return [tensor_copies[i:i + len(tensors)]
                    for i in range(0, len(tensor_copies), len(tensors))]
        else:
            return []


T = TypeVar("T", bound=Module)


def replicate(
    network: T,
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[T]:
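    """Replicate ``network`` onto each device in ``devices``.

    Returns a list with one replica per device. With ``detach=False`` the
    replicated parameters are produced by an autograd-aware broadcast, so
    they stay connected to the originals; with ``detach=True`` the copies
    are detached from the autograd graph.

    A minimal sketch, assuming at least two CUDA devices are available::

        >>> net = torch.nn.Linear(4, 4).cuda(0)
        >>> replicas = replicate(net, devices=[0, 1])
        >>> len(replicas)
        2
    """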
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "children of a ScriptModule")

    if not devices:
        return []

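    # Resolve each entry (torch.device or int) to a plain device index.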
    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)

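    # Broadcast all parameters to every device in one coalesced transfer,
    # remembering each parameter's position so replicas can look it up.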
    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

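    # Split buffers by whether they require gradients: only those that do
    # (and only when not detaching) must go through the autograd-aware
    # broadcast; the rest can always be broadcast detached.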
    buffers = list(network.buffers())
    buffers_rg: List[torch.Tensor] = []
    buffers_not_rg: List[torch.Tensor] = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

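    # Create one bare replica of every submodule per device; parameters,
    # buffers, and child links are wired up afterwards.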
    modules = list(network.modules())
    module_copies: List[List[Module]] = [[] for _ in devices]
    module_indices: Dict[Module, int] = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module._replicate_for_data_parallel()
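            # Replicas no longer expose their parameters via `parameters()`,
            # but DistributedDataParallel still needs to reach them, so they
            # are also stashed in a `_former_parameters` dict.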
            replica._former_parameters = OrderedDict()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, module_copies[j][module_idx])
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    param_copy = param_copies[j][param_idx]
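                    # Parameters in replicas are no longer leaf tensors, so
                    # attach them as plain (non-parameter) attributes.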
                    setattr(replica, key, param_copy)
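                    # Also expose the copy for DistributedDataParallel.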
                    replica._former_parameters[key] = param_copy
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, buffer_copies[j][buffer_idx])
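
    # Entry 0 of each per-device list is the replica of `network` itself.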
    return [cast(T, module_copies[j][0]) for j in range(num_replicas)]