import threading
import torch
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

from ..modules import Module
from torch.cuda._utils import _get_device_index
from torch.cuda.amp import autocast
from torch._utils import ExceptionWrapper

__all__ = ['get_a_var', 'parallel_apply']
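

# get_a_var() depth-first searches a tensor / list / tuple / dict argument structure
# and returns the first tensor it finds; parallel_apply() uses it below to infer a
# target device when none is supplied for a replica.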
def get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]:
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None


def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`. |
|
|
|
Args: |
|
modules (Module): modules to be parallelized |
|
inputs (tensor): inputs to the modules |
|
devices (list of int or torch.device): CUDA devices |
|
|
|
:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and |
|
:attr:`devices` (if given) should all have same length. Moreover, each |
|
element of :attr:`inputs` can either be a single object as the only argument |
|
to a module, or a collection of positional arguments. |
|
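
    Example (a minimal illustrative sketch, not taken from this module: it
    assumes two CUDA devices and uses arbitrary linear layers and shapes)::

        >>> # xdoctest: +SKIP("requires multiple CUDA devices")
        >>> modules = [torch.nn.Linear(8, 4).cuda(0), torch.nn.Linear(8, 4).cuda(1)]
        >>> inputs = [torch.randn(2, 8, device='cuda:0'),
        ...           torch.randn(2, 8, device='cuda:1')]
        >>> outputs = parallel_apply(modules, inputs, devices=[0, 1])
        >>> len(outputs)
        2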
""" |
|
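    # Validate that there is one input per module, then broadcast defaults: an empty
    # kwargs dict for every module and, if no devices were given, a None placeholder
    # per module to be resolved later.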
    assert len(modules) == len(inputs), f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}'
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
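    # Map each entry of ``devices`` to a device index and record the CUDA stream that
    # is currently active on it; the caller's grad and autocast modes are captured here
    # so every worker thread can re-apply them. Results are collected by replica index
    # under a shared lock.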
    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    lock = threading.Lock()
    results = {}
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()
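    # Each worker runs one replica: it re-applies the captured grad/autocast modes,
    # resolves the target device from the input tensors if necessary, and stores the
    # module's output (or a wrapped exception) in results[i].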
    def _worker(
        i: int,
        module: Module,
        input: Any,
        kwargs: Dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            t = get_a_var(input)
            if t is None:
                with lock:
                    results[i] = ExceptionWrapper(
                        where=f"in replica {i}, no device was provided and no tensor input was found; "
                        "device cannot be resolved")
                return
            device = t.get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
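        # Run the module on its assigned device and stream, mirroring the caller's
        # autocast setting, and publish the output under the lock so the main thread
        # can collect it by replica index.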
        try:
            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
                # Wrap a bare argument in a one-element tuple; this also avoids
                # accidentally unpacking `input` itself when it is a single Tensor.
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where=f"in replica {i} on device {device}")
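    # With more than one replica, run every worker in its own thread so per-device work
    # can overlap; with a single replica there is nothing to overlap, so call the worker
    # inline.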
    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])
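    # Collect outputs in input order; if any replica failed, re-raise its exception
    # (with the original traceback) in the calling thread.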
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs