|
|
|
import torch
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, overload
from typing_extensions import deprecated

from ._functions import Scatter, Gather

__all__ = ['scatter', 'scatter_kwargs', 'gather']
|
|
|
|
|
@deprecated(
    "`is_namedtuple` is deprecated, please use the python checks instead",
    category=FutureWarning,
)
def is_namedtuple(obj: Any) -> bool:
    # Check whether the object was created by collections.namedtuple or typing.NamedTuple.
    return _is_namedtuple(obj)


def _is_namedtuple(obj: Any) -> bool:
    # Check whether the object was created by collections.namedtuple or typing.NamedTuple.
    return (
        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
    )
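
# Editor's illustrative sketch (not part of the module): how the duck-typed check
# above distinguishes a namedtuple from a plain tuple.
#
#     import collections
#     Point = collections.namedtuple("Point", ["x", "y"])
#     _is_namedtuple(Point(1, 2))  # True: tuple subclass with _asdict and _fields
#     _is_namedtuple((1, 2))       # False: a plain tuple has neither attribute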
|
|
|
|
|
T = TypeVar("T", dict, list, tuple)


# Note: 'scatter' returns a tuple when given a single tensor input, but a list
# when given multiple inputs; the asymmetry is kept for backward compatibility.
@overload
def scatter(
    inputs: torch.Tensor,
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = ...,
) -> Tuple[torch.Tensor, ...]:
    ...


@overload
def scatter(
    inputs: T,
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = ...,
) -> List[T]:
    ...
|
|
def scatter(inputs, target_gpus, dim=0):
    r"""Slice tensors into approximately equal chunks and distribute them across given GPUs.

    Duplicates references to objects that are not tensors.
    """

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
        # Non-tensor leaves are not split; the same reference is handed to every device.
        return [obj for _ in target_gpus]

    # scatter_map is recursive, so its closure holds a reference back to the
    # function itself. Setting the name to None below breaks that reference
    # cycle once the result has been computed.
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None
    return res
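
# Editor's illustrative sketch (not part of the module), assuming two visible
# CUDA devices: ``scatter`` splits each tensor along ``dim`` and replicates
# every non-tensor object by reference.
#
#     batch = torch.randn(8, 3)
#     chunks = scatter(batch, target_gpus=[0, 1], dim=0)
#     # chunks is a tuple of two (4, 3) tensors, one on cuda:0 and one on cuda:1
#     nested = scatter({"x": batch, "flag": True}, target_gpus=[0, 1])
#     # nested is a list of two dicts; each "x" is a tensor chunk, "flag" is shared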
|
|
|
|
|
def scatter_kwargs(
    inputs: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = 0,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
    r"""Scatter with support for kwargs dictionary."""
    scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
    scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    # Pad the shorter list with empty tuples/dicts so both have one entry per device.
    if len(scattered_inputs) < len(scattered_kwargs):
        scattered_inputs.extend(() for _ in range(len(scattered_kwargs) - len(scattered_inputs)))
    elif len(scattered_kwargs) < len(scattered_inputs):
        scattered_kwargs.extend({} for _ in range(len(scattered_inputs) - len(scattered_kwargs)))
    return tuple(scattered_inputs), tuple(scattered_kwargs)
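
# Editor's illustrative sketch (not part of the module): positional args and
# keyword args are scattered together, yielding one (args, kwargs) pair per device.
#
#     x = torch.randn(8, 3)
#     args, kwargs = scatter_kwargs((x,), {"scale": 2.0}, target_gpus=[0, 1], dim=0)
#     # args is a tuple of two 1-tuples, each holding a (4, 3) tensor chunk;
#     # kwargs is a tuple of two dicts, each {"scale": 2.0}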
|
|
|
|
|
def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0) -> Any:
    r"""Gather tensors from different GPUs on a specified device.

    Use 'cpu' for CPU to avoid a deprecation warning.
    """

    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all(len(out) == len(d) for d in outputs):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)((k, gather_map([d[k] for d in outputs]))
                             for k in out)
        if _is_namedtuple(out):
            return type(out)._make(map(gather_map, zip(*outputs)))
        return type(out)(map(gather_map, zip(*outputs)))

    # gather_map is recursive, so its closure holds a reference back to the
    # function itself. Setting the name to None below breaks that reference
    # cycle once the result has been computed.
    try:
        res = gather_map(outputs)
    finally:
        gather_map = None
    return res
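
# Editor's illustrative sketch (not part of the module): ``gather`` is the
# inverse of ``scatter``, concatenating per-device results along ``dim`` on the
# target device while recursing into dicts, lists, tuples, and namedtuples.
#
#     per_device = [torch.randn(4, 3, device='cuda:0'), torch.randn(4, 3, device='cuda:1')]
#     merged = gather(per_device, target_device=0, dim=0)
#     # merged is a single (8, 3) tensor on cuda:0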
|
|