from __future__ import annotations

from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union

import torch

__all__ = ['Future', 'collect_all', 'wait_all']

T = TypeVar("T")
S = TypeVar("S")


# torch._C.Future (a pybind-defined type) and typing.Generic have different
# metaclasses; combining them here avoids a metaclass conflict when Future
# inherits from both.
class _PyFutureMeta(type(torch._C.Future), type(Generic)):
    pass


class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
    r"""
    Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
    execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
    also exposes a set of APIs to add callback functions and set results.

    .. warning:: GPU support is a beta feature, subject to change.
    """

    def __init__(self, *, devices: Optional[List[Union[int, str, torch.device]]] = None):
        r"""
        Create an empty unset ``Future``. If the future is intended to hold
        values containing CUDA tensors, (a superset of) their CUDA devices must
        be specified at construction. (This is only supported if
        ``torch.cuda.is_available()`` returns ``True``.) This is needed to
        ensure proper CUDA stream synchronization. The child futures, returned
        by the ``then`` method, will inherit these devices.

        Args:
            devices(``List[Union[int, str, torch.device]]``, optional): the set
                of devices on which tensors contained in this future's value are
                allowed to reside and on which callbacks are allowed to operate.
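
        Example (a minimal sketch, assuming a CUDA device is available)::
            >>> # xdoctest: +SKIP
            >>> # requires a CUDA device
            >>> fut = torch.futures.Future(devices=["cuda:0"])
            >>> fut.set_result(torch.ones(2, device="cuda:0"))
            >>> fut.wait()
            tensor([1., 1.], device='cuda:0')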
        """
        if devices is None:
            devices = []
        super().__init__([torch.device(d) for d in devices])

    def done(self) -> bool:
        r"""
        Return ``True`` if this ``Future`` is done. A ``Future`` is done if it
        has a result or an exception.

        If the value contains tensors that reside on GPUs, ``Future.done()``
        will return ``True`` even if the asynchronous kernels that are
        populating those tensors haven't yet completed running on the device,
        because at that stage the result is already usable, provided one
        performs the appropriate synchronizations (see :meth:`wait`).
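
        Example (a minimal CPU-only sketch; the value ``42`` is arbitrary)::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.done()
            False
            >>> fut.set_result(42)
            >>> fut.done()
            True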
        """
        return super().done()

    def wait(self) -> T:
        r"""
        Block until the value of this ``Future`` is ready.

        If the value contains tensors that reside on GPUs, then an additional
        synchronization is performed with the kernels (executing on the device)
        which may be asynchronously populating those tensors. Such
        synchronization is non-blocking: ``wait()`` inserts the necessary
        instructions in the current streams to ensure that further operations
        enqueued on those streams will be properly scheduled after the async
        kernels; once that is done, ``wait()`` returns even if those kernels
        are still running. No further synchronization is required when
        accessing and using the values, as long as one doesn't change streams.

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``wait`` method will
            also throw an error.
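
        Example (a minimal CPU-only sketch; the result is already set, so
        ``wait()`` returns immediately)::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_result(torch.ones(2))
            >>> fut.wait()
            tensor([1., 1.])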
        """
        return super().wait()

    def value(self) -> T:
        r"""
        Obtain the value of an already-completed future.

        This method should only be called after a call to :meth:`wait` has
        completed, or inside a callback function passed to :meth:`then`. In
        other cases this ``Future`` may not yet hold a value and calling
        ``value()`` could fail.

        If the value contains tensors that reside on GPUs, then this method will
        *not* perform any additional synchronization. This should be done
        beforehand, separately, through a call to :meth:`wait` (except within
        callbacks, where :meth:`then` already takes care of it).

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``value()`` method will
            also throw an error.
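
        Example (a minimal sketch; ``value()`` is safe here only because the
        result has already been set)::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_result("done")
            >>> fut.wait()
            'done'
            >>> fut.value()
            'done'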
        """
        return super().value()

    def then(self, callback: Callable[[Future[T]], S]) -> Future[S]:
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed. Multiple callbacks can be added to
        the same ``Future``, but the order in which they will be executed cannot
        be guaranteed (to enforce a certain order consider chaining:
        ``fut.then(cb1).then(cb2)``). The callback must take one argument, which
        is the reference to this ``Future``. The callback function can use the
        :meth:`value` method to get the value. Note that if this ``Future`` is
        already completed, the given callback will be run immediately inline.

        If the ``Future``'s value contains tensors that reside on GPUs, the
        callback might be invoked while the async kernels that are populating
        those tensors haven't yet finished executing on the device. However, the
        callback will be invoked with some dedicated streams set as current
        (fetched from a global pool) which will be synchronized with those
        kernels. Hence any operation performed by the callback on these tensors
        will be scheduled on the device after the kernels complete. In other
        words, as long as the callback doesn't switch streams, it can safely
        manipulate the result without any additional synchronization. This is
        similar to the non-blocking behavior of :meth:`wait`.

        Similarly, if the callback returns a value that contains tensors that
        reside on a GPU, it can do so even if the kernels that are producing
        these tensors are still running on the device, as long as the callback
        didn't change streams during its execution. If one wants to change
        streams, one must be careful to re-synchronize them with the original
        streams, that is, those that were current when the callback was invoked.

        Args:
            callback(``Callable``): a ``Callable`` that takes this ``Future`` as
                the only argument.

        Returns:
            A new ``Future`` object that holds the return value of the
            ``callback`` and will be marked as completed when the given
            ``callback`` finishes.

        .. note:: If the callback function throws, either because the original
            future completed with an exception (raised when the callback calls
            ``fut.wait()``) or because of other code in the callback, the
            future returned by ``then`` will be marked appropriately with the
            encountered error. However, if this callback later completes
            additional futures, those futures are not marked as completed with
            an error and the user is responsible for handling completion/waiting
            on those futures independently.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> def callback(fut):
            ...     print(f"RPC return value is {fut.wait()}.")
            >>> fut = torch.futures.Future()
            >>> # The inserted callback will print the return value once
            >>> # fut receives a result (set further below)
            >>> cb_fut = fut.then(callback)
            >>> chain_cb_fut = cb_fut.then(
            ...     lambda x : print(f"Chained cb done. {x.wait()}")
            ... )
            >>> fut.set_result(5)
            RPC return value is 5.
            Chained cb done. None
        """
        return cast(Future[S], super().then(callback))

    def add_done_callback(self, callback: Callable[[Future[T]], None]) -> None:
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed. Multiple callbacks can be added to
        the same ``Future``, but the order in which they will be executed cannot
        be guaranteed. The callback must take one argument, which is the
        reference to this ``Future``. The callback function can use the
        :meth:`value` method to get the value. Note that if this ``Future`` is
        already completed, the given callback will be run inline.

        We recommend that you use the :meth:`then` method as it provides a way
        to synchronize after your callback has completed. ``add_done_callback``
        can be cheaper if your callback does not return anything. But both
        :meth:`then` and ``add_done_callback`` use the same callback
        registration API under the hood.

        With respect to GPU tensors, this method behaves in the same way as
        :meth:`then`.

        Args:
            callback(``Callable``): a ``Callable`` that takes in one argument,
                which is the reference to this ``Future``.

        .. note:: If the callback function throws, either because the original
            future completed with an exception (raised when the callback calls
            ``fut.wait()``) or because of other code in the callback, error
            handling must be carefully taken care of. For example, if this
            callback later completes additional futures, those futures are not
            marked as completed with an error and the user is responsible for
            handling completion/waiting on those futures independently.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> def callback(fut):
            ...     print("This will run after the future has finished.")
            ...     print(fut.wait())
            >>> fut = torch.futures.Future()
            >>> fut.add_done_callback(callback)
            >>> fut.set_result(5)
            This will run after the future has finished.
            5
        """
        super().add_done_callback(callback)

    def set_result(self, result: T) -> None:
        r"""
        Set the result for this ``Future``, which will mark this ``Future`` as
        completed and trigger all attached callbacks. Note that a ``Future``
        cannot be marked completed twice.

        If the result contains tensors that reside on GPUs, this method can be
        called even if the asynchronous kernels that are populating those
        tensors haven't yet completed running on the device, provided that the
        streams on which those kernels were enqueued are set as the current ones
        when this method is called. Put simply, it's safe to call this method
        immediately after launching those kernels, without any additional
        synchronization, as long as one doesn't change streams in between. This
        method will record events on all the relevant current streams and will
        use them to ensure proper scheduling for all the consumers of this
        ``Future``.

        Args:
            result (object): the result object of this ``Future``.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> import threading
            >>> import time
            >>> def slow_set_future(fut, value):
            ...     time.sleep(0.5)
            ...     fut.set_result(value)
            >>> fut = torch.futures.Future()
            >>> t = threading.Thread(
            ...     target=slow_set_future,
            ...     args=(fut, torch.ones(2) * 3)
            ... )
            >>> t.start()
            >>> print(fut.wait())
            tensor([3., 3.])
            >>> t.join()
        """
        super().set_result(result)

    def set_exception(self, result: T) -> None:
        r"""
        Set an exception for this ``Future``, which will mark this ``Future`` as
        completed with an error and trigger all attached callbacks. Note that
        when calling ``wait()``/``value()`` on this ``Future``, the exception
        set here will be raised inline.

        Args:
            result (BaseException): the exception for this ``Future``.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_exception(ValueError("foo"))
            >>> fut.wait()
            Traceback (most recent call last):
            ...
            ValueError: foo
        """
        assert isinstance(result, Exception), f"{result} is of type {type(result)}, not an Exception."

        def raise_error(fut_result):
            raise fut_result

        # Store the exception as the "result" and install an unwrap function so
        # that wait()/value() re-raise it instead of returning it.
        super()._set_unwrap_func(raise_error)
        self.set_result(result)


def collect_all(futures: List[Future]) -> Future[List[Future]]:
    r"""
    Collects the provided :class:`~torch.futures.Future` objects into a single
    combined :class:`~torch.futures.Future` that is completed when all of the
    sub-futures are completed.

    Args:
        futures (list): a list of :class:`~torch.futures.Future` objects.

    Returns:
        A :class:`~torch.futures.Future` object whose value is a list of the
        passed-in Futures.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
        >>> fut0 = torch.futures.Future()
        >>> fut1 = torch.futures.Future()
        >>> fut = torch.futures.collect_all([fut0, fut1])
        >>> fut0.set_result(0)
        >>> fut1.set_result(1)
        >>> fut_list = fut.wait()
        >>> print(f"fut0 result = {fut_list[0].wait()}")
        fut0 result = 0
        >>> print(f"fut1 result = {fut_list[1].wait()}")
        fut1 result = 1
    """
    return cast(Future[List[Future]], torch._C._collect_all(cast(List[torch._C.Future], futures)))


def wait_all(futures: List[Future]) -> List:
    r"""
    Waits for all provided futures to complete, and returns
    the list of completed values. If any of the futures encounters an error,
    the method will exit early and report the error, without waiting for the
    other futures to complete.

    Args:
        futures (list): a list of :class:`~torch.futures.Future` objects.

    Returns:
        A list of the completed :class:`~torch.futures.Future` results. This
        method will throw an error if ``wait`` on any
        :class:`~torch.futures.Future` throws.
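
    Example (a minimal sketch with two locally completed futures)::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
        >>> fut0 = torch.futures.Future()
        >>> fut1 = torch.futures.Future()
        >>> fut0.set_result(0)
        >>> fut1.set_result(1)
        >>> torch.futures.wait_all([fut0, fut1])
        [0, 1]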
    """
    return [fut.wait() for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()]