|
|
|
# Names exported by a star-import of this module.
__all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync",
           "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"]
|
|
|
import collections |
|
import contextlib |
|
import functools |
|
import inspect |
|
import logging |
|
import threading |
|
from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING |
|
|
|
import torch |
|
from torch.futures import Future |
|
|
|
from torch._C._distributed_rpc import ( |
|
PyRRef, |
|
RemoteProfilerManager, |
|
WorkerInfo, |
|
TensorPipeAgent, |
|
get_rpc_timeout, |
|
_cleanup_python_rpc_handler, |
|
_delete_all_user_and_unforked_owner_rrefs, |
|
_destroy_rref_context, |
|
_get_current_rpc_agent, |
|
_invoke_remote_builtin, |
|
_invoke_remote_python_udf, |
|
_invoke_remote_torchscript, |
|
_invoke_rpc_builtin, |
|
_invoke_rpc_python_udf, |
|
_invoke_rpc_torchscript, |
|
_is_current_rpc_agent_set, |
|
_reset_current_rpc_agent, |
|
_set_and_start_rpc_agent, |
|
) |
|
|
|
from .internal import ( |
|
PythonUDF, |
|
RPCExecMode, |
|
_internal_rpc_pickler, |
|
_build_rpc_profiling_key, |
|
) |
|
|
|
from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT |
|
|
|
from ._utils import _group_membership_management, _update_group_membership |
|
|
|
# Module-level logger; used below to report barrier and shutdown failures.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Flag handed to ``_destroy_rref_context`` when the local shutdown is finalized.
_ignore_rref_leak = True
# Pickler used to serialize Python UDF calls; ``_use_rpc_pickler`` can
# temporarily replace it.
_default_pickler = _internal_rpc_pickler
|
|
|
@contextlib.contextmanager
def _use_rpc_pickler(rpc_pickler):
    r"""
    Temporarily replace the module-wide default RPC pickler.

    Args:
        rpc_pickler (.internal._InternalRPCPickler): pickler that overrides
            the default RPC pickler for the duration of the ``with`` block.

    On exit (normal or exceptional) the pickler that was active when the
    context was entered is restored, so nested overrides unwind correctly
    instead of always falling back to the built-in pickler.
    """
    global _default_pickler
    # Capture whatever is active right now; it may itself be an override
    # installed by an enclosing ``_use_rpc_pickler`` block.
    previous_pickler = _default_pickler
    _default_pickler = rpc_pickler
    try:
        yield
    finally:
        _default_pickler = previous_pickler
|
|
|
|
|
def _require_initialized(func): |
|
@functools.wraps(func) |
|
def wrapper(*args, **kwargs): |
|
if not _is_current_rpc_agent_set(): |
|
raise RuntimeError( |
|
"RPC has not been initialized. Call " |
|
"torch.distributed.rpc.init_rpc first." |
|
) |
|
return func(*args, **kwargs) |
|
|
|
return wrapper |
|
|
|
|
|
class AllGatherStates:
    """
    Per-round bookkeeping for one ``_all_gather`` sequence id.

    Attributes are created empty/unset here and filled in by
    ``_gather_to_leader`` and ``_broadcast_to_followers``.
    """

    def __init__(self):
        # Maps worker name -> object contributed by that worker. Starts empty;
        # each worker's entry is added as it reports to the leader, and on
        # followers the whole mapping is replaced by the leader's broadcast.
        self.gathered_objects = dict()

        # Unset at construction. Set once every expected worker has reported
        # (leader side) or once the leader's broadcast arrives (follower
        # side); ``_all_gather`` blocks on it.
        self.proceed_signal = threading.Event()
|
|
|
|
|
|
|
|
|
# Names of all known workers; populated by ``_init_rpc_states`` and used as
# the default participant set for the all-gather helpers.
_ALL_WORKER_NAMES: Set[Any] = set()
# Re-entrant lock guarding the two all-gather dictionaries below.
_all_gather_dict_lock = threading.RLock()
# Maps the concatenated, sorted participant names to the next sequence number
# to hand out for that participant set.
_all_gather_sequence_id: Dict[str, int] = {}
# Maps a sequence id to its ``AllGatherStates``; entries are created on first
# access and popped by ``_all_gather`` once a round completes.
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)
|
|
|
|
|
def _init_rpc_states(agent):
    """
    Record the names of all workers known to ``agent`` and make sure an RPC
    agent is set.

    Args:
        agent: RPC agent exposing ``get_worker_infos()``.
    """
    global _ALL_WORKER_NAMES

    known_workers = agent.get_worker_infos()
    _ALL_WORKER_NAMES = set(info.name for info in known_workers)

    # Only install ``agent`` when no agent has been set yet.
    if _is_current_rpc_agent_set():
        return
    _set_and_start_rpc_agent(agent)
|
|
|
|
|
def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
    """
    Record ``obj`` as ``worker_name``'s contribution to the all-gather round
    identified by ``sequence_id``.

    Args:
        sequence_id: identifier of the all-gather round.
        worker_name: name of the reporting worker.
        obj: object contributed by that worker.
        worker_names: expected participants; when falsy, the module-wide set
            of all worker names is used.

    Raises:
        AssertionError: if ``worker_name`` is not an expected participant or
            has already reported for this ``sequence_id``.
    """
    with _all_gather_dict_lock:
        expected = worker_names if worker_names else _ALL_WORKER_NAMES
        assert (
            worker_name in expected
        ), f"{worker_name} is not expected by leader."

        states = _all_gather_sequence_id_to_states[sequence_id]
        received = states.gathered_objects
        assert (
            worker_name not in received
        ), f"{worker_name} reported intent sequence id {sequence_id} twice. "

        received[worker_name] = obj
        # Release the waiters once every expected worker has reported.
        if expected == set(received):
            states.proceed_signal.set()
|
|
|
|
|
def _broadcast_to_followers(sequence_id, objects_map):
    """
    Install the gathered results for ``sequence_id`` and release the local
    waiter.

    Args:
        sequence_id: identifier of the all-gather round.
        objects_map: gathered objects to store for this round.

    Raises:
        AssertionError: if the proceed signal for this round is already set.
    """
    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]

    signal = states.proceed_signal
    assert (
        not signal.is_set()
    ), f"Termination signal sequence id {sequence_id} got set twice."
    states.gathered_objects = objects_map
    signal.set()
|
|
|
# Thread-local scratch space. ``_wait_all`` stores a ``future_list`` attribute
# here, and ``rpc_async`` appends to it when the attribute is present.
_thread_local_var = threading.local()
|
|
|
|
|
@contextlib.contextmanager
def _wait_all():
    r"""
    A context manager that collects all futures returned by ``rpc_async`` and
    waits them on the context manager's exit; relieving the user of needing
    to explicitly call wait.

    Collection is per thread. The context may be nested: an inner block
    collects and waits on its own futures, after which the enclosing block
    resumes collecting into its own list.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> with rpc._wait_all():
        >>>    fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        >>>    fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        >>> #fut_1 and fut_2 are waited on
    """
    # Remember whether an enclosing ``_wait_all`` on this thread already
    # installed a list, so that leaving this block does not destroy it.
    has_outer = hasattr(_thread_local_var, "future_list")
    outer_futures = _thread_local_var.future_list if has_outer else None

    collected = []
    _thread_local_var.future_list = collected
    try:
        yield
    finally:
        try:
            torch.futures.wait_all(collected)
        finally:
            # Restore the previous state even if waiting raised.
            if has_outer:
                _thread_local_var.future_list = outer_futures
            else:
                del _thread_local_var.future_list
|
|
|
|
|
@_require_initialized
def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
    r"""
    This is similar to torch.distributed.all_gather(), but is using RPC. It
    picks the worker with the smallest name (alphabetic order) as the leader.
    Then all followers send their data ``obj`` to the leader. After the leader
    has received all, it will broadcast the results back to all followers. This
    function blocks until all workers have received the gathered results.

    Args:
        obj: the object this worker contributes.
        worker_names: set of participating worker names; when falsy, the
            module-wide set of all worker names is used.
        timeout (float): ``UNSET_RPC_TIMEOUT`` uses the value of
            ``get_rpc_timeout()`` for the RPCs and waits on the proceed
            signal without a limit; ``DEFAULT_SHUTDOWN_TIMEOUT`` is passed to
            the RPCs as-is and also waits on the signal without a limit; any
            other value bounds both the RPCs and the signal wait.

    Returns:
        The gathered-objects mapping stored for this round (worker name ->
        contributed object).

    Raises:
        RuntimeError: on the leader, if waiting on any follower's broadcast
            future raises ``RuntimeError``.
    """
    if not worker_names:
        # NOTE(review): ``_ALL_WORKER_NAMES`` is initialized to an empty set
        # at module level, so within this file the check below cannot fail.
        assert (
            _ALL_WORKER_NAMES is not None
        ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
        worker_names = _ALL_WORKER_NAMES
    # The smallest name is the leader.
    leader_name = min(worker_names)

    self_name = _get_current_rpc_agent().get_worker_info().name

    # Allocate a per-participant-set sequence number so that successive rounds
    # over the same workers get distinct ids.
    with _all_gather_dict_lock:
        concat_names = "".join(sorted(worker_names))
        sequence_num = _all_gather_sequence_id.get(concat_names, 0)
        _all_gather_sequence_id[concat_names] = sequence_num + 1
        sequence_id = concat_names + str(sequence_num)

    is_leader = leader_name == self_name

    if timeout == UNSET_RPC_TIMEOUT:
        # Caller gave no timeout: use the current RPC timeout for the RPCs
        # and wait on the signal without a limit.
        rpc_timeout = get_rpc_timeout()

        signal_timeout = None
    elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
        # Shutdown default: forwarded unchanged to the RPCs, no limit on the
        # signal wait. (presumably this constant means "no RPC timeout" —
        # TODO confirm against .constants)
        rpc_timeout = timeout

        signal_timeout = None
    else:
        # Explicit timeout: bound both the RPCs and the signal wait.
        signal_timeout = rpc_timeout = timeout

    # Phase 1: every worker reports its object to the leader. The leader
    # records its own object locally; followers do it through a blocking RPC.
    if is_leader:
        _gather_to_leader(sequence_id, self_name, obj, worker_names)
    else:
        rpc_sync(
            leader_name,
            _gather_to_leader,
            args=(sequence_id, self_name, obj, worker_names),
            timeout=rpc_timeout,
        )

    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]

    # Block until the round's proceed signal is set.
    # NOTE(review): the boolean returned by ``wait`` is ignored, so when
    # ``signal_timeout`` elapses execution continues with whatever has been
    # gathered so far — confirm this is intended.
    states.proceed_signal.wait(timeout=signal_timeout)

    # Phase 2: the leader sends the gathered mapping to every follower and
    # waits for all of those RPCs to finish.
    if is_leader:
        worker_name_to_response_future_dict = {}
        for follower_name in worker_names - {leader_name}:
            fut = rpc_async(
                follower_name,
                _broadcast_to_followers,
                args=(sequence_id, states.gathered_objects),
                timeout=rpc_timeout
            )
            worker_name_to_response_future_dict[follower_name] = fut

        # Collect failures from all followers before raising, so the error
        # names every follower whose future raised.
        errors = []
        for follower_name, fut in worker_name_to_response_future_dict.items():
            try:
                fut.wait()
            except RuntimeError as ex:
                errors.append((follower_name, ex))

        if errors:
            raise RuntimeError(
                f"Followers {[e[0] for e in errors]} timed out in _all_gather "
                f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
            )

    # Drop the bookkeeping for this round and hand back its results.
    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states.pop(sequence_id)
    return states.gathered_objects
|
|
|
|
|
@_require_initialized
def _barrier(worker_names):
    r"""
    Synchronizes local and remote RPC processes.

    This will block until all local and remote RPC processes specified under worker_names
    reach this method to wait for all outstanding work to complete.

    A ``RuntimeError`` raised while synchronizing is logged and not
    propagated to the caller.

    Args:
        worker_names (List[str]): The set of workers to synchronize.

    """
    failure_format = "Failed to complete barrier, got error %s"
    try:
        # An all-gather of ``None`` acts as the rendezvous.
        _all_gather(None, set(worker_names))
    except RuntimeError as ex:
        logger.error(failure_format, ex)
|
|
|
|
|
@_require_initialized
def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Block until all local and remote RPC processes reach this method and wait
    for all outstanding work to complete. Every RPC process must call this
    method before exit to perform a graceful shutdown. This should be used to
    terminate the RPC framework, and there is no guarantee that the RPC
    framework will work after this method returns.

    Args:
        timeout: forwarded to ``_all_gather``.

    Raises:
        RuntimeError: re-raised (after being logged) when the underlying
            all-gather fails with ``RuntimeError``.
    """
    failure_format = "Failed to respond to 'Shutdown Proceed' in time, got error %s"
    try:
        _all_gather(None, timeout=timeout)
    except RuntimeError as ex:
        logger.error(failure_format, ex)
        raise ex
|
|
|
|
|
@_require_initialized
def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
    stops the local agent from accepting outstanding requests, and shuts
    down the RPC framework by terminating all RPC threads. If ``graceful=True``,
    this will block until all local and remote RPC processes reach this method
    and wait for all outstanding work to complete. Otherwise, if
    ``graceful=False``, this is a local shutdown, and it does not wait for other
    RPC processes to reach this method.

    .. warning::
        For :class:`~torch.futures.Future` objects returned by
        :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
        be called after ``shutdown()``.

    Args:
        graceful (bool): Whether to do a graceful shutdown or not. If True,
                         this will 1) wait until there is no pending system
                         messages for ``UserRRefs`` and delete them; 2) block
                         until all local and remote RPC processes have reached
                         this method and wait for all outstanding work to
                         complete.
        timeout: only used when ``graceful=True``; forwarded to the
                 worker synchronization step (static groups) and to the
                 agent's ``join`` call.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """
    # Local-only shutdown: nothing to coordinate with other processes.
    if not graceful:
        _finalize_shutdown()
        return

    try:
        agent = _get_current_rpc_agent()
        in_dynamic_group = isinstance(agent, TensorPipeAgent) and not agent.is_static_group

        if in_dynamic_group:
            # Tell every other worker that this one is leaving, then join,
            # all under the group-membership management context.
            my_worker_info = agent.get_worker_info()
            my_name = my_worker_info.name
            with _group_membership_management(agent.store, my_name, False):
                for peer in agent.get_worker_infos():
                    if peer.name == my_name:
                        continue
                    rpc_sync(peer.name, _update_group_membership, args=(my_worker_info, [], {}, False))
                agent.join(shutdown=True, timeout=timeout)
        else:
            _wait_all_workers(timeout)
            _delete_all_user_and_unforked_owner_rrefs()
            agent.join(shutdown=True, timeout=timeout)
    finally:
        # The local teardown runs even when the coordinated part failed.
        _finalize_shutdown()
|
|
|
|
|
def _finalize_shutdown():
    """
    Tear down the local RPC state: destroy the RRef context, shut down the
    current agent, clean up the Python RPC handler and reset the current
    agent.

    The three steps in the ``finally`` clause run even if destroying the RRef
    context raises; that exception then propagates after they complete.
    """
    try:
        # ``_ignore_rref_leak`` is the module-level flag defined above.
        _destroy_rref_context(_ignore_rref_leak)
    finally:
        _get_current_rpc_agent().shutdown()
        # NOTE(review): the order of the next two calls relative to the agent
        # shutdown looks deliberate — confirm before reordering.
        _cleanup_python_rpc_handler()
        _reset_current_rpc_agent()
|
|
|
|
|
@_require_initialized
def get_worker_info(worker_name=None):
    r"""
    Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name.
    Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an
    expensive string on every invocation.

    Args:
        worker_name (str): the string name of a worker. If ``None``, return
                           the id of the current worker. (default ``None``)

    Returns:
        :class:`~torch.distributed.rpc.WorkerInfo` instance for the given
        ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
        current worker if ``worker_name`` is ``None``.
    """
    agent = _get_current_rpc_agent()
    if worker_name is None:
        return agent.get_worker_info()
    return agent.get_worker_info(worker_name)
|
|
|
|
|
def _to_worker_info(to):
    """
    Normalize ``to`` into a ``WorkerInfo``.

    A ``WorkerInfo`` is returned unchanged; a ``str`` or ``int`` is resolved
    through ``get_worker_info``.

    Raises:
        ValueError: if ``to`` is none of the accepted types.
    """
    if isinstance(to, WorkerInfo):
        return to
    if isinstance(to, (str, int)):
        return get_worker_info(to)
    raise ValueError(f"Cannot get WorkerInfo from name {to}")
|
|
|
|
|
def _rref_typeof_on_owner(rref, blocking: bool = True): |
|
rref_type = type(rref.local_value()) |
|
if blocking: |
|
return rref_type |
|
else: |
|
|
|
|
|
|
|
future = Future[type]() |
|
future.set_result(rref_type) |
|
return future |
|
|
|
|
|
def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True):
    """
    Ask the owner of ``rref`` for the type of the value it holds.

    Args:
        rref: the reference whose owner is queried.
        timeout: timeout passed to the underlying ``rpc_async`` call.
        blocking: when True, wait for and return the type; when False,
            return the pending future.
    """
    type_future = rpc_async(
        rref.owner(),
        _rref_typeof_on_owner,
        args=(rref,),
        timeout=timeout
    )
    return type_future.wait() if blocking else type_future
|
|
|
|
|
# Type parameter for the value type referenced by an RRef.
T = TypeVar("T")
# Named alias of ``Generic[T]``; its metaclass is needed by the fallback below.
GenericWithOneTypeVar = Generic[T]


# RRef is the user-facing class: PyRRef combined with ``Generic[T]`` so it can
# be parameterized in annotations. Its methods are attached further down in
# this module.
if TYPE_CHECKING:
    # Definition seen by static type checkers only.
    class RRef(PyRRef[T], Generic[T]):
        pass
else:
    try:
        # Straightforward combination of the implementation class and the
        # generic base.
        class RRef(PyRRef, Generic[T]):
            pass
    except TypeError:
        # The class statement above raised TypeError (presumably a metaclass
        # conflict between PyRRef and Generic — TODO confirm). Build a
        # metaclass deriving from both bases' metaclasses and retry with it.
        class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):
            pass

        class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):
            pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def method_factory(method_name, docstring):
    """
    Build a forwarding method for :class:`RRef`.

    The returned function looks up ``method_name`` on the parent of ``RRef``
    (via ``super(RRef, self)``) and calls it with the arguments it receives.

    Args:
        method_name (str): name of the parent-class method to delegate to.
        docstring: text to expose as the new method's ``__doc__``.

    Returns:
        A function suitable for attaching to ``RRef`` with ``setattr``.
    """

    def method(self, *args, **kwargs):
        return getattr(super(RRef, self), method_name)(*args, **kwargs)

    # ``method`` is defined without a docstring, so testing ``method.__doc__``
    # before assigning would never attach anything; assign unconditionally.
    method.__doc__ = docstring
    return method
|
|
|
|
|
# Attach a forwarding method to RRef for every selected member of PyRRef.
# This runs at import time; the loop variables (``method_name``, ``method``,
# ``docstring``, ``new_method``) remain module-level names afterwards, and
# ``new_method`` is listed in ``__all__``.
for method_name, method in inspect.getmembers(PyRRef):
    # Skip underscore-prefixed members, with ``__str__`` as the one exception.
    if method_name.startswith("_") and method_name != "__str__":
        continue

    # The bare string below is an expression statement with no effect. It
    # appears to illustrate the shape of the docstrings being read
    # (presumably generated by the binding layer — TODO confirm).
    """
    to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object

        Blocking call that copies the value of the RRef from the owner
        to the local node and returns it. If the current node is the
        owner, returns a reference to the local value.
    """
    docstring = getattr(method, "__doc__", None)
    assert docstring is not None, "RRef user-facing methods should all have docstrings."

    # Make the docstring refer to the user-facing class name.
    docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef")

    # Build the forwarding method and install it on RRef under the same name.
    new_method = method_factory(method_name, docstring)
    setattr(RRef, method_name, new_method)
|
|
|
|
|
@_require_initialized
def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    Make a remote call to run ``func`` on worker ``to`` and return an
    :class:`~torch.distributed.rpc.RRef` to the result value immediately.
    Worker ``to`` will be the owner of the returned
    :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is
    a user. The owner manages the global reference count of its
    :class:`~torch.distributed.rpc.RRef`, and the owner
    :class:`~torch.distributed.rpc.RRef` is only destructed when globally there
    are no living references to it.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): is a dictionary of keyword arguments for the ``func``
                       invocation.

        timeout (float, optional): timeout in seconds for this remote call. If the
                                   creation of this
                                   :class:`~torch.distributed.rpc.RRef` on worker
                                   ``to`` is not successfully processed on this
                                   worker within this timeout, then the next time
                                   there is an attempt to use the RRef (such as
                                   ``to_here()``), a timeout will be raised
                                   indicating this failure. A value of 0 indicates
                                   an infinite timeout, i.e. a timeout error will
                                   never be raised. If not provided, the default
                                   value set during initialization or with
                                   ``_set_rpc_timeout`` is used.

    Returns:
        A user :class:`~torch.distributed.rpc.RRef` instance to the result
        value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here`
        to retrieve the result value locally.

    .. warning ::
        The ``remote`` API does not copy storages of argument tensors until
        sending them over the wire, which could be done by a different thread
        depending on the RPC backend type. The caller should make sure that the
        contents of those tensors stay intact until the returned RRef is
        confirmed by the owner, which can be checked using the
        :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API.

    .. warning ::
        Errors such as timeouts for the ``remote`` API are handled on a
        best-effort basis. This means that when remote calls initiated by
        ``remote`` fail, such as with a timeout error, we take a best-effort
        approach to error handling. This means that errors are handled and set
        on the resulting RRef on an asynchronous basis. If the RRef has not been
        used by the application before this handling (such as ``to_here`` or
        fork call), then future uses of the ``RRef`` will appropriately raise
        errors. However, it is possible that the user application will use the
        ``RRef`` before the errors are handled. In this case, errors may not be
        raised as they have not yet been handled.

    Example::

        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>> x = rref1.to_here() + rref2.to_here()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>    return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> rref.to_here()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
    torch._C._log_api_usage_once("torch.distributed.rpc_remote")
    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
    should_profile = _get_should_profile()

    profiler_ctx = _enable_rpc_profiler(
        should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info
    )

    with profiler_ctx as rf:
        # Treat missing/empty arguments as "no arguments".
        call_args = args or ()
        call_kwargs = kwargs or {}

        # A function carrying ``_wrapped_async_rpc_function`` is executed in
        # async mode; if the wrapped target is a ScriptFunction, dispatch on
        # the target itself.
        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
        if is_async_exec:
            inner = func._wrapped_async_rpc_function
            if isinstance(inner, torch.jit.ScriptFunction):
                func = inner

        # Three dispatch paths: builtin operator, TorchScript function, or a
        # pickled Python UDF.
        if qualified_name is not None:
            rref = _invoke_remote_builtin(
                dst_worker_info, qualified_name, timeout, *call_args, **call_kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            rref = _invoke_remote_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                timeout,
                is_async_exec,
                *call_args,
                **call_kwargs,
            )
        else:
            udf = PythonUDF(func, call_args, call_kwargs)
            (pickled_python_udf, tensors) = _default_pickler.serialize(udf)
            rref = _invoke_remote_python_udf(
                dst_worker_info,
                pickled_python_udf,
                tensors,
                timeout,
                is_async_exec
            )

        # Attach profiling information to the RRef when profiling is active.
        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            profiling_future = rf._call_end_callbacks_on_future(rref._get_future())
            rref._set_profiling_future(profiling_future)

    return rref
|
|
|
|
|
def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT):
    """
    Dispatch ``func`` to worker ``to`` and return the future for its result.

    Args:
        to: destination worker, in any form accepted by ``_to_worker_info``.
        func: the callable to run remotely.
        rpc_type: execution mode passed to the RPC profiler setup.
        args: positional arguments for ``func``; falsy means none.
        kwargs: keyword arguments for ``func``; falsy means none.
        rpc_timeout (float): timeout forwarded to the underlying invoke call.

    Returns:
        The future produced by the underlying invoke call; when profiling is
        active, the future returned by the profiler's end-callback hook.

    Raises:
        TypeError: if ``func`` is not callable.
    """
    if not callable(func):
        raise TypeError("function should be callable.")

    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)

    should_profile = _get_should_profile()

    profiler_ctx = _enable_rpc_profiler(
        should_profile, qualified_name, func, rpc_type, dst_worker_info
    )

    with profiler_ctx as rf:
        # Treat missing/empty arguments as "no arguments".
        call_args = args or ()
        call_kwargs = kwargs or {}

        # A function carrying ``_wrapped_async_rpc_function`` is executed in
        # async mode; if the wrapped target is a ScriptFunction, dispatch on
        # the target itself.
        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
        if is_async_exec:
            inner = func._wrapped_async_rpc_function
            if isinstance(inner, torch.jit.ScriptFunction):
                func = inner

        # Three dispatch paths: builtin operator, TorchScript function, or a
        # pickled Python UDF.
        if qualified_name is not None:
            fut = _invoke_rpc_builtin(
                dst_worker_info,
                qualified_name,
                rpc_timeout,
                *call_args,
                **call_kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            fut = _invoke_rpc_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                call_args,
                call_kwargs,
                rpc_timeout,
                is_async_exec
            )
        else:
            udf = PythonUDF(func, call_args, call_kwargs)
            (pickled_python_udf, tensors) = _default_pickler.serialize(udf)
            fut = _invoke_rpc_python_udf(
                dst_worker_info,
                pickled_python_udf,
                tensors,
                rpc_timeout,
                is_async_exec
            )

        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            # Hand back the future produced by the profiler hook in place of
            # the raw RPC future.
            fut = rf._call_end_callbacks_on_future(fut)

    return fut
|
|
|
|
|
@_require_initialized
def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT):
    r"""
    Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel to execution of Python code. This
    method is thread-safe.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): is a dictionary of keyword arguments for the ``func``
                       invocation.
        timeout (float, optional): timeout in seconds to use for this RPC. If
                                   the RPC does not complete in this amount of
                                   time, an exception indicating it has
                                   timed out will be raised. A value of 0
                                   indicates an infinite timeout, i.e. a timeout
                                   error will never be raised. If not provided,
                                   the default value set during initialization
                                   or with ``_set_rpc_timeout`` is used.

    Returns:
        Returns the result of running ``func`` with ``args`` and ``kwargs``.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>    return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

    """
    torch._C._log_api_usage_once("torch.distributed.rpc_sync")
    # Issue the call, then block on its future for the result.
    pending = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
    return pending.wait()
|
|
|
|
|
@_require_initialized
def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel to execution of Python code. This
    method is thread-safe. This method will immediately return a
    :class:`~torch.futures.Future` that can be awaited on.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): is a dictionary of keyword arguments for the ``func``
                       invocation.
        timeout (float, optional): timeout in seconds to use for this RPC. If
                                   the RPC does not complete in this amount of
                                   time, an exception indicating it has
                                   timed out will be raised. A value of 0
                                   indicates an infinite timeout, i.e. a timeout
                                   error will never be raised. If not provided,
                                   the default value set during initialization
                                   or with ``_set_rpc_timeout`` is used.


    Returns:
        Returns a :class:`~torch.futures.Future` object that can be waited
        on. When completed, the return value of ``func`` on ``args`` and
        ``kwargs`` can be retrieved from the :class:`~torch.futures.Future`
        object.

    .. warning ::
        Using GPU tensors as arguments or return values of ``func`` is not
        supported since we don't support sending GPU tensors over the wire. You
        need to explicitly copy GPU tensors to CPU before using them as
        arguments or return values of ``func``.

    .. warning ::
        The ``rpc_async`` API does not copy storages of argument tensors until
        sending them over the wire, which could be done by a different thread
        depending on the RPC backend type. The caller should make sure that the
        contents of those tensors stay intact until the returned
        :class:`~torch.futures.Future` completes.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
        >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
        >>> result = fut1.wait() + fut2.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>    return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> ret = fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
    torch._C._log_api_usage_once("torch.distributed.rpc_async")
    pending = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)

    # When this thread is inside a ``_wait_all`` block, register the future so
    # that block waits on it at exit.
    inside_wait_all = hasattr(_thread_local_var, "future_list")
    if inside_wait_all:
        _thread_local_var.future_list.append(pending)
    return pending
|
|
|
|
|
def _get_should_profile(): |
|
|
|
|
|
ActiveProfilerType = torch._C._profiler.ActiveProfilerType |
|
return ( |
|
torch.autograd._profiler_enabled() and |
|
torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY |
|
) |
|
|
|
|
|
def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info): |
|
ctx_manager = contextlib.nullcontext() |
|
|
|
if should_profile: |
|
|
|
|
|
if qualified_name is None: |
|
func_name = ( |
|
torch._jit_internal._qualified_name(func) |
|
if isinstance(func, torch.jit.ScriptFunction) |
|
else func.__qualname__ |
|
) |
|
else: |
|
func_name = qualified_name |
|
|
|
rpc_profiling_key = _build_rpc_profiling_key( |
|
rpc_type, |
|
func_name, |
|
get_worker_info().name, |
|
dst_worker_info.name, |
|
) |
|
RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key) |
|
|
|
ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) |
|
|
|
return ctx_manager |
|
|