| from functools import partial |
| |
| from . import functions |
| from . import rpc_async |
| |
| import torch |
| from .constants import UNSET_RPC_TIMEOUT |
| from torch.futures import Future |
| |
def _local_invoke(rref, func_name, args, kwargs):
    """Call the method named ``func_name`` on the value held locally by ``rref``.

    Args:
        rref: object exposing ``local_value()``; the returned value is the
            call target.
        func_name: name of the attribute to look up on the local value.
        args: positional arguments, unpacked into the call.
        kwargs: keyword arguments, unpacked into the call.

    Returns:
        Whatever the looked-up callable returns.
    """
    target = rref.local_value()
    method = getattr(target, func_name)
    return method(*args, **kwargs)
| |
@functions.async_execution
def _local_invoke_async_execution(rref, func_name, args, kwargs):
    """Variant of the local invoke helper wrapped in ``functions.async_execution``.

    The body is the same lookup-and-call as the plain helper; only the
    decorator differs. ``_invoke_rpc`` picks this variant when the target
    attribute carries ``_wrapped_async_rpc_function``.

    Args:
        rref: object exposing ``local_value()``; the returned value is the
            call target.
        func_name: name of the attribute to look up on the local value.
        args: positional arguments, unpacked into the call.
        kwargs: keyword arguments, unpacked into the call.

    Returns:
        Whatever the looked-up callable returns.
    """
    target = rref.local_value()
    method = getattr(target, func_name)
    return method(*args, **kwargs)
| |
def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
    """Run ``func_name`` on the owner of ``rref`` through the given RPC API.

    The type held by ``rref`` is fetched first (non-blocking) so that the
    right local-invoke helper can be chosen before the call is issued.

    Args:
        rref: the remote reference; ``_get_type`` and ``owner`` are used.
        rpc_api: callable invoked as
            ``rpc_api(owner, func, args=..., timeout=...)``.
        func_name: name of the method to invoke on the referenced value.
        timeout: forwarded to both the type lookup and the RPC call.
        *args: positional arguments for the remote method.
        **kwargs: keyword arguments for the remote method.

    Returns:
        For ``rpc_async``, a ``Future`` completed with the value (or
        exception) of the RPC's own future. For any other ``rpc_api``,
        whatever ``rpc_api`` returns.
    """

    def _dispatch(type_fut):
        # Issue the actual RPC once the type of the referenced value is known.
        owner_type = type_fut.value()

        # ScriptModules are skipped when probing for the async function attribute.
        is_script_module = issubclass(
            owner_type, (torch.jit.ScriptModule, torch._C.ScriptModule)
        )

        target = _local_invoke
        if not is_script_module and hasattr(
            getattr(owner_type, func_name), "_wrapped_async_rpc_function"
        ):
            target = _local_invoke_async_execution

        return rpc_api(
            rref.owner(),
            target,
            args=(rref, func_name, args, kwargs),
            timeout=timeout,
        )

    type_fut = rref._get_type(timeout=timeout, blocking=False)

    if rpc_api != rpc_async:
        # Not the async API: block on the type lookup, then call directly.
        type_fut.wait()
        return _dispatch(type_fut)

    # rpc_async yields a Future[T] for the return value of `func_name`.
    # Returning that future from a `then` callback would wrap it, giving
    # Future[Future[T]]. Instead, hand back a separate Future that is
    # completed with the outcome of the async call.
    outer: Future = Future()

    def _forward(call_fut):
        # Copy the RPC's outcome (value or error) onto the returned future.
        try:
            outer.set_result(call_fut.value())
        except BaseException as err:
            outer.set_exception(err)

    def _on_type_ready(ready_fut):
        # A failure while dispatching is reported through the returned future.
        try:
            _dispatch(ready_fut).then(_forward)
        except BaseException as err:
            outer.set_exception(err)

    type_fut.then(_on_type_ready)
    return outer
| |
# This class manages proxied RPC API calls for RRefs. It is entirely used from
# C++ (see python_rpc_handler.cpp).
class RRefProxy:
    """Proxy that turns attribute access into an RPC invocation on an RRef."""

    def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
        """Store the RRef, the RPC API callable and the timeout to use.

        Args:
            rref: the remote reference calls are proxied to.
            rpc_api: the RPC API callable handed to ``_invoke_rpc``.
            timeout: timeout passed along with every proxied call.
        """
        self.rref = rref
        self.rpc_api = rpc_api
        self.rpc_timeout = timeout

    def __getattr__(self, func_name):
        """Return a callable that invokes ``func_name`` through ``_invoke_rpc``.

        The RRef, RPC API, method name and timeout are pre-bound; the caller
        supplies the method's own positional and keyword arguments.
        """
        bound = (self.rref, self.rpc_api, func_name, self.rpc_timeout)
        return partial(_invoke_rpc, *bound)