blob: 102b5cc9fc38b66ebf81bc8d41495b1053f0ac39 [file] [log] [blame]
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/distributed/autograd/autograd.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>
namespace torch {
namespace distributed {
namespace autograd {
namespace {
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
PyObject* dist_autograd_init(PyObject* _unused, PyObject* noargs) {
auto autograd_module =
THPObjectPtr(PyImport_ImportModule("torch.distributed.autograd"));
if (!autograd_module) {
throw python_error();
}
auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
if (!torch_C_module) {
throw python_error();
}
auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
auto m = torch_C_m.def_submodule(
"_distributed_autograd", "distributed autograd bindings");
auto module = py::handle(m).cast<py::module>();
auto distAutogradContext =
shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext")
.def(
"_context_id",
&DistAutogradContext::contextId,
py::call_guard<py::gil_scoped_release>())
.def(
"_recv_functions",
[](const DistAutogradContext& ctx) {
std::map<int64_t, py::object> funcs;
auto recvFunctions = ctx.recvFunctions();
// Acquire GIL only when necessary to avoid deadlocks.
pybind11::gil_scoped_acquire ag;
for (const auto& map_entry : recvFunctions) {
funcs.emplace(
map_entry.first,
py::reinterpret_steal<py::object>(
torch::autograd::functionToPyObject(
map_entry.second)));
}
return funcs;
},
py::call_guard<py::gil_scoped_release>())
.def(
"_send_functions",
[](const ContextPtr& ctx) {
std::map<int64_t, py::object> funcs;
auto sendFunctions = ctx->sendFunctions();
// Acquire GIL only when necessary to avoid deadlocks.
pybind11::gil_scoped_acquire ag;
for (const auto& map_entry : sendFunctions) {
funcs.emplace(
map_entry.first,
py::reinterpret_steal<py::object>(
torch::autograd::functionToPyObject(
map_entry.second)));
}
return funcs;
},
py::call_guard<py::gil_scoped_release>())
.def(
"_known_worker_ids",
&DistAutogradContext::getKnownWorkerIds,
py::call_guard<py::gil_scoped_release>());
module.def(
"_new_context",
[]() -> const ContextPtr {
return DistAutogradContainer::getInstance().newContext();
},
py::return_value_policy::reference,
py::call_guard<py::gil_scoped_release>());
module.def(
"_release_context",
[](int64_t context_id) {
return DistAutogradContainer::getInstance().releaseContext(context_id);
},
py::call_guard<py::gil_scoped_release>());
module.def(
"_get_max_id",
[]() { return DistAutogradContainer::getInstance().getMaxId(); },
py::call_guard<py::gil_scoped_release>());
module.def(
"_is_valid_context",
[](int64_t worker_id) {
DistAutogradContainer::getInstance().isValidContext(worker_id);
},
py::call_guard<py::gil_scoped_release>());
module.def(
"_retrieve_context",
[](int64_t context_id) -> const ContextPtr {
return DistAutogradContainer::getInstance().retrieveContext(context_id);
},
py::return_value_policy::reference,
py::call_guard<py::gil_scoped_release>());
module.def(
"_current_context",
[]() -> const ContextPtr {
return DistAutogradContainer::getInstance().currentContext();
},
py::return_value_policy::reference,
py::call_guard<py::gil_scoped_release>());
module.def(
"_init",
[](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
py::call_guard<py::gil_scoped_release>());
module.def(
"_get_debug_info",
[]() { return DistEngine::getInstance().getDebugInfo(); },
py::call_guard<py::gil_scoped_release>());
py::options options;
options.disable_function_signatures();
module.def(
"backward",
backward,
R"(
backward(context_id: int, roots: List[Tensor], retain_graph = False) -> None
Kicks off the distributed backward pass using the provided roots. This
currently implements the :ref:`fast-mode-algorithm` which
assumes all RPC messages sent in the same distributed autograd context
across workers would be part of the autograd graph during the backward pass.
We use the provided roots to discover the autograd graph and compute
appropriate dependencies. This method blocks until the entire
autograd computation is done.
We accumulate the gradients in the appropriate
:class:`torch.distributed.autograd.context` on each of the nodes. The autograd
context to be used is looked up given the ``context_id`` that is passed in when
:meth:`torch.distributed.autograd.backward` is called. If there is no valid
autograd context corresponding to the given ID, we throw an error. You can
retrieve the accumulated gradients using the
:meth:`~torch.distributed.autograd.get_gradients` API.
Arguments:
context_id (int): The autograd context id for which we should retrieve the gradients.
roots (list): Tensors which represent the roots of the autograd
computation. All the tensors should be scalars.
retain_graph(bool, optional): If False, the graph used to compute the grad
will be freed. Note that in nearly all cases setting this
option to True is not needed and often can be worked around
in a much more efficient way. Usually, you need to set this
to True to run backward multiple times.
Example::
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>> pred = model.forward()
>>> loss = loss_func(pred, loss)
>>> dist_autograd.backward(context_id, loss)
)",
py::arg("contextId"),
py::arg("roots"),
py::arg("retain_graph") = false,
py::call_guard<py::gil_scoped_release>());
module.def(
"get_gradients",
[](int64_t contextId) -> py::dict {
const auto& autogradContext =
DistAutogradContainer::getInstance().retrieveContext(contextId);
auto ival = IValue(autogradContext->getGradients());
// Acquire GIL only for pyobject conversion.
pybind11::gil_scoped_acquire ag;
return torch::jit::toPyObject(ival);
},
R"(
get_gradients(context_id: int) -> Dict[Tensor, Tensor]
Retrieves a map from Tensor to the appropriate gradient for that Tensor
accumulated in the provided context corresponding to the given ``context_id``
as part of the distributed autograd backward pass.
Arguments:
context_id(int): The autograd context id for which we should retrieve the
gradients.
Returns:
A map where the key is the Tensor and the value is the associated gradient
for that Tensor.
Example::
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>> t1 = torch.rand((3, 3), requires_grad=True)
>>> t2 = torch.rand((3, 3), requires_grad=True)
>>> loss = t1 + t2
>>> dist_autograd.backward(context_id, [loss.sum()])
>>> grads = dist_autograd.get_gradients(context_id)
>>> print(grads[t1])
>>> print(grads[t2])
)",
py::arg("context_id"),
py::call_guard<py::gil_scoped_release>());
Py_RETURN_TRUE;
}
} // namespace
// CPython method table handed out by python_functions(). It contains one
// no-argument entry, `_dist_autograd_init`, followed by the value-initialized
// (all-null) entry that CPython uses as the end-of-table sentinel.
static PyMethodDef methods[] = { // NOLINT
    {"_dist_autograd_init", dist_autograd_init, METH_NOARGS, nullptr},
    {} // sentinel: every field null/zero
};
// Returns a pointer to the first entry of this file's null-terminated
// PyMethodDef table. Presumably the caller adds these functions to a Python
// module; verify at the call site.
PyMethodDef* python_functions() {
  PyMethodDef* const table = &methods[0];
  return table;
}
} // namespace autograd
} // namespace distributed
} // namespace torch