| #include <torch/csrc/python_headers.h> |
| |
| #include <c10/util/intrusive_ptr.h> |
| #include <c10/util/string_view.h> |
| #include <torch/csrc/distributed/c10d/FileStore.hpp> |
| #include <torch/csrc/distributed/c10d/TCPStore.hpp> |
| #include <torch/csrc/distributed/c10d/Utils.hpp> |
| #ifndef _WIN32 |
| #include <torch/csrc/distributed/c10d/HashStore.hpp> |
| #include <torch/csrc/distributed/c10d/ProcessGroupRoundRobin.hpp> |
| #endif |
| #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> |
| #include <torch/csrc/distributed/c10d/PyProcessGroup.hpp> |
| |
| #ifdef USE_C10D_GLOO |
| #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp> |
| #include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp> |
| #endif |
| |
| #ifdef USE_C10D_NCCL |
| #include <torch/csrc/distributed/c10d/NCCLUtils.hpp> |
| #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp> |
| #endif |
| |
| #ifdef USE_C10D_MPI |
| #include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp> |
| #endif |
| |
| #ifdef USE_C10D_UCC |
| #include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp> |
| #endif |
| |
| #include <fmt/format.h> |
| #include <pybind11/chrono.h> |
| #include <torch/csrc/distributed/c10d/PrefixStore.hpp> |
| |
| #include <torch/csrc/distributed/c10d/comm.hpp> |
| #include <torch/csrc/distributed/c10d/debug.h> |
| #include <torch/csrc/distributed/c10d/logger.hpp> |
| #include <torch/csrc/distributed/c10d/reducer.hpp> |
| |
| #include <torch/csrc/Exceptions.h> |
| #include <torch/csrc/distributed/c10d/python_comm_hook.h> |
| #include <torch/csrc/jit/python/pybind_utils.h> |
| #include <torch/csrc/utils/object_ptr.h> |
| #include <torch/csrc/utils/pybind.h> |
| |
| #include <torch/custom_class.h> |
| |
| namespace { |
| |
| // Wrapper to ensure GIL is released before destructing ProcessGroupGloo |
| // TODO: move this somewhere more generally useful |
| template <typename T> |
| class IntrusivePtrNoGilDestructor { |
| c10::intrusive_ptr<T> impl_; |
| |
| public: |
| IntrusivePtrNoGilDestructor() = default; |
| IntrusivePtrNoGilDestructor(const IntrusivePtrNoGilDestructor&) = default; |
| IntrusivePtrNoGilDestructor(IntrusivePtrNoGilDestructor&&) = default; |
| IntrusivePtrNoGilDestructor& operator=(const IntrusivePtrNoGilDestructor&) = |
| default; |
| IntrusivePtrNoGilDestructor& operator=(IntrusivePtrNoGilDestructor&&) = |
| default; |
| /* implicit */ IntrusivePtrNoGilDestructor(c10::intrusive_ptr<T> impl) |
| : impl_(std::move(impl)) {} |
| // This ctor is very important; see |
| // https://github.com/pybind/pybind11/issues/2957 |
| explicit IntrusivePtrNoGilDestructor(T* impl) |
| : impl_(c10::intrusive_ptr<T>::unsafe_steal_from_new(impl)) {} |
| ~IntrusivePtrNoGilDestructor() { |
| if (impl_) { |
| if (PyGILState_Check()) { |
| pybind11::gil_scoped_release release; |
| impl_.reset(); |
| } else { |
| impl_.reset(); |
| } |
| } |
| } |
| T& operator*() const noexcept { |
| return *impl_; |
| } |
| T* operator->() const noexcept { |
| return impl_.get(); |
| } |
| C10_NODISCARD T* get() const noexcept { |
| return impl_.get(); |
| } |
| void reset() noexcept { |
| impl_.reset(); |
| } |
| operator bool() const noexcept { |
| return impl_; |
| } |
| }; |
| |
| } // anonymous namespace |
| |
| PYBIND11_DECLARE_HOLDER_TYPE(T, IntrusivePtrNoGilDestructor<T>, true); |
| |
| namespace torch { |
| namespace distributed { |
| namespace c10d { |
| |
| namespace { |
| |
| template <typename T> |
| using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>; |
| |
| constexpr auto kDeprecationWarning = |
| "{} API is being deprecated, please ping " |
| "https://github.com/pytorch/pytorch/issues/46291 " |
| "if you see this warning"; |
| template <typename T> |
| using intrusive_ptr_class_ = py::class_<T, c10::intrusive_ptr<T>>; |
| |
| template <typename T> |
| using intrusive_ptr_no_gil_destructor_class_ = |
| py::class_<T, IntrusivePtrNoGilDestructor<T>>; |
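| |
| // Illustrative sketch (not part of the bindings; "SomeBackend" and "backend" |
| // are placeholder names): binding a class through the alias above makes |
| // pybind11 hold it via IntrusivePtrNoGilDestructor<T>, so dropping the last |
| // Python reference releases the GIL before the C++ destructor runs: |
| // |
| //   intrusive_ptr_no_gil_destructor_class_<::c10d::SomeBackend>( |
| //       module, "SomeBackend", backend); |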
| |
| // PythonStore is a pybind11 trampoline class to allow a Python |
| // class to inherit from c10d.Store and implement its interface. |
| class PythonStore : public ::c10d::Store { |
| public: |
| using ::c10d::Store::Store; |
| |
| // Note: this function manually calls the Python-side overload |
| // for this function instead of using the PYBIND11_OVERLOAD_XYZ |
| // macros. This is done so that we can call the Python-side |
| // function with a py::bytes object instead of a std::vector<uint8_t>. |
| void set(const std::string& key, const std::vector<uint8_t>& value) override { |
| pybind11::gil_scoped_acquire gil; |
| pybind11::function fn = |
| pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "set"); |
| TORCH_INTERNAL_ASSERT(fn, "Not implemented."); |
| // Call function with a py::bytes object for the value. |
| fn(key, |
| py::bytes(reinterpret_cast<const char*>(value.data()), value.size())); |
| } |
| |
| // Note: this function manually calls the Python-side overload |
| // for this function instead of using the PYBIND11_OVERLOAD_XYZ |
| // macros. This is done so that the Python-side function can |
| // return a py::bytes instead of a std::vector<uint8_t>. |
| std::vector<uint8_t> get(const std::string& key) override { |
| pybind11::gil_scoped_acquire gil; |
| pybind11::function fn = |
| pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "get"); |
| TORCH_INTERNAL_ASSERT(fn, "Not implemented."); |
| // Cast return value from Python to py::bytes, then implicitly |
| // convert that to a std::string, so that we can construct a |
| // std::vector<uint8_t>. There is no API for directly accessing |
| // the contents of the py::bytes object. |
| std::string str = pybind11::cast<py::bytes>(fn(key)); |
| return std::vector<uint8_t>(str.begin(), str.end()); |
| } |
| |
| // Note: this function manually calls the Python-side overload |
| // for this function instead of using the PYBIND11_OVERLOAD_XYZ |
| // macros. This is done so that the Python-side function can |
| // return a py::bytes instead of a std::vector<uint8_t>. |
| std::vector<uint8_t> compareSet( |
| const std::string& key, |
| const std::vector<uint8_t>& expectedValue, |
| const std::vector<uint8_t>& desiredValue) override { |
| pybind11::gil_scoped_acquire gil; |
| pybind11::function fn = pybind11::get_overload( |
| static_cast<const ::c10d::Store*>(this), "compare_set"); |
| TORCH_INTERNAL_ASSERT(fn, "Not implemented."); |
| // Cast return value from Python to py::bytes, then implicitly |
| // convert that to a std::string, so that we can construct a |
| // std::vector<uint8_t>. There is no API for directly accessing |
| // the contents of the py::bytes object. |
| std::string str = pybind11::cast<py::bytes>( |
| fn(key, |
| py::bytes( |
| reinterpret_cast<const char*>(expectedValue.data()), |
| expectedValue.size()), |
| py::bytes( |
| reinterpret_cast<const char*>(desiredValue.data()), |
| desiredValue.size()))); |
| return std::vector<uint8_t>(str.begin(), str.end()); |
| } |
| |
| int64_t add(const std::string& key, int64_t value) override { |
| PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, add, key, value); |
| } |
| |
| int64_t getNumKeys() override { |
| PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, getNumKeys); |
| } |
| |
| bool deleteKey(const std::string& key) override { |
| PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, deleteKey, key); |
| } |
| |
| bool check(const std::vector<std::string>& keys) override { |
| PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, check, keys); |
| } |
| |
| void wait(const std::vector<std::string>& keys) override { |
| PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys); |
| } |
| |
| void wait( |
| const std::vector<std::string>& keys, |
| const std::chrono::milliseconds& timeout) override { |
| PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys, timeout); |
| } |
| |
| // Note: this function manually calls the Python-side overload |
| // for this function instead of using the PYBIND11_OVERLOAD_XYZ |
| // macros. This is done so that we can call the Python-side |
| // function with a py::bytes object instead of a std::vector<uint8_t>. |
| void append(const std::string& key, const std::vector<uint8_t>& value) |
| override { |
| pybind11::gil_scoped_acquire gil; |
| pybind11::function fn = pybind11::get_overload( |
| static_cast<const ::c10d::Store*>(this), "append"); |
| if (!fn) { |
| return Store::append(key, value); |
| } |
| // Call function with a py::bytes object for the value. |
| fn(key, |
| py::bytes(reinterpret_cast<const char*>(value.data()), value.size())); |
| } |
| |
| virtual std::vector<std::vector<uint8_t>> multiGet( |
| const std::vector<std::string>& keys) override { |
| pybind11::gil_scoped_acquire gil; |
| pybind11::function fn = pybind11::get_overload( |
| static_cast<const ::c10d::Store*>(this), "multi_get"); |
| if (!fn) { |
| return Store::multiGet(keys); |
| } |
| std::vector<std::string> py_list = |
| pybind11::cast<std::vector<std::string>>(fn(keys)); |
| std::vector<std::vector<uint8_t>> res; |
| |
| for (auto& str : py_list) { |
| res.emplace_back(std::vector<uint8_t>(str.begin(), str.end())); |
| } |
| |
| return res; |
| } |
| |
| virtual void multiSet( |
| const std::vector<std::string>& keys, |
| const std::vector<std::vector<uint8_t>>& values) override { |
| pybind11::gil_scoped_acquire gil; |
| pybind11::function fn = pybind11::get_overload( |
| static_cast<const ::c10d::Store*>(this), "multi_set"); |
| if (!fn) { |
| return Store::multiSet(keys, values); |
| } |
| |
| std::vector<py::bytes> bytes; |
| for (auto& value : values) { |
| bytes.emplace_back( |
| py::bytes(reinterpret_cast<const char*>(value.data()), value.size())); |
| } |
| |
| fn(keys, bytes); |
| } |
| |
| bool hasExtendedApi() const override { |
| PYBIND11_OVERLOAD_NAME( |
| bool, ::c10d::Store, "has_extended_api", hasExtendedApi); |
| } |
| }; |
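| |
| // Illustrative sketch (Python side; the class is purely hypothetical): a |
| // Python subclass of torch.distributed.Store overrides the store primitives |
| // it supports, and the trampoline above routes the C++ virtual calls to |
| // those overrides (falling back to the base-class implementation for the |
| // optional extended API): |
| // |
| //   import torch.distributed as dist |
| // |
| //   class DictStore(dist.Store): |
| //       def __init__(self): |
| //           super().__init__() |
| //           self._data = {} |
| // |
| //       def set(self, key, value): |
| //           self._data[key] = value |
| // |
| //       def get(self, key): |
| //           return self._data[key] |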
| |
| // Called from DDP's Python API to create a c10d Python comm hook object. |
| // The input state and callable comm_hook are Python objects. It then calls |
| // the register_comm_hook function of the given reducer to register the hook. |
| void _register_comm_hook( |
| ::c10d::Reducer& reducer, |
| py::object state, |
| py::object comm_hook) { |
| reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>( |
| std::move(state), std::move(comm_hook))); |
| } |
| |
| // Called from DDP's Python API to create a c10d C++ comm hook. |
| // The input is an enum hook type. It then calls the register_builtin_comm_hook |
| // function of the given reducer to set the hook type. |
| void _register_builtin_comm_hook( |
| ::c10d::Reducer& reducer, |
| ::c10d::BuiltinCommHookType comm_hook_type) { |
| reducer.register_builtin_comm_hook(comm_hook_type); |
| } |
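| |
| // Illustrative sketch of the user-facing Python API that ultimately reaches |
| // the helpers above via DDP (assumes `ddp_model` is a DistributedDataParallel |
| // module and a backend that supports `Work.get_future()`; the hook body is |
| // just one example of the expected `(state, bucket) -> Future` signature): |
| // |
| //   import torch.distributed as dist |
| // |
| //   def allreduce_hook(state, bucket): |
| //       fut = dist.all_reduce(bucket.buffer(), async_op=True).get_future() |
| //       return fut.then(lambda f: f.value()[0] / dist.get_world_size()) |
| // |
| //   ddp_model.register_comm_hook(state=None, hook=allreduce_hook) |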
| |
| // Customize the metaclass of ::c10d::ReduceOp for backward compatibility. |
| // https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp from |
| // an enum to a struct, sacrificing support for some Python built-ins such as |
| // `isinstance` (see https://github.com/pytorch/pytorch/issues/87191) |
| // and `copy` (see |
| // https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700). Below, |
| // we define a custom `__instancecheck__` in CPython/pybind11 |
| // (`reduceopmeta___instancecheck__`) and modify the default metaclass of |
| // pybind11 (`GetReduceOpMetaclass`) so that |
| // `isinstance(torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp)` |
| // returns :obj:`True` as if `ReduceOp` were an enum. |
| // Ref: |
| // - https://docs.python.org/3/extending/newtypes_tutorial.html |
| // - https://docs.python.org/3/c-api/typeobj.html?highlight=tp_methods |
| // - https://github.com/pybind/pybind11/issues/2696 |
| static PyObject* reduceopmeta___instancecheck__( |
| PyObject* self, |
| PyObject* args) { |
| if (Py_TYPE(self) == Py_TYPE(args)) { |
| Py_RETURN_TRUE; |
| } |
| if (c10::string_view(args->ob_type->tp_name).find("RedOpType") != |
| c10::string_view::npos) { |
| Py_RETURN_TRUE; |
| } |
| Py_RETURN_FALSE; |
| } |
| static PyMethodDef reduceopmeta_methods[] = { |
| {"__instancecheck__", |
| (PyCFunction)reduceopmeta___instancecheck__, |
| METH_O, |
| "Custom `__instancecheck__` for ReduceOp"}, |
| {nullptr, nullptr}}; |
| PyTypeObject* GetReduceOpMetaclass() { |
| static auto* metaclass = [] { |
| PyTypeObject* base_metaclass = |
| pybind11::detail::get_internals().default_metaclass; |
| PyType_Slot slots[] = { |
| {Py_tp_base, base_metaclass}, |
| {Py_tp_methods, reduceopmeta_methods}, |
| {0}, |
| }; |
| PyType_Spec spec = {}; |
| spec.name = "torch._C._distributed_c10d._ReduceOpMeta"; |
| spec.basicsize = base_metaclass->tp_basicsize; |
| spec.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; |
| spec.slots = slots; |
| PyTypeObject* metaclass = (PyTypeObject*)PyType_FromSpec(&spec); |
| if (!metaclass) |
| throw py::error_already_set(); |
| return metaclass; |
| }(); |
| return metaclass; |
| } |
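| |
| // Illustrative sketch of the Python-side behavior the metaclass above |
| // restores: |
| // |
| //   import torch.distributed as dist |
| //   isinstance(dist.ReduceOp.SUM, dist.ReduceOp)  # True, as when ReduceOp was an enum |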
| |
| PyObject* c10d_init(PyObject* _unused, PyObject* noargs) { |
| C10_LOG_API_USAGE_ONCE("c10d.python.import"); |
| |
| auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed")); |
| if (!c10d_module) { |
| throw python_error(); |
| } |
| |
| auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C")); |
| if (!torch_C_module) { |
| throw python_error(); |
| } |
| |
| auto torch_C_m = py::handle(torch_C_module).cast<py::module>(); |
| auto m = |
| torch_C_m.def_submodule("_distributed_c10d", "distributed c10d bindings"); |
| |
| auto module = py::handle(m).cast<py::module>(); |
| |
| module |
| .def( |
| "_register_comm_hook", |
| &_register_comm_hook, |
| py::arg("reducer"), |
| py::arg("state"), |
| py::arg("comm_hook"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_register_builtin_comm_hook", |
| &_register_builtin_comm_hook, |
| py::arg("reducer"), |
| py::arg("comm_hook_type")); |
| |
| shared_ptr_class_<::c10d::GradBucket>( |
| module, |
| "GradBucket", |
| R"( |
| This class mainly passes a flattened gradient tensor |
| (returned by :meth:`~torch.distributed.GradBucket.buffer`) |
| to the DDP communication hook. |
| This tensor can be further decomposed into a list of per-parameter tensors within this bucket |
| (returned by :meth:`~torch.distributed.GradBucket.gradients`) |
| to apply layer-wise operations. |
| )") |
| .def( |
| "index", |
| &::c10d::GradBucket::getIndex, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| .. warning:: |
| Since the buckets are rebuilt after the first iteration, the indices should not be relied upon early in training. |
| |
| Returns: |
| The index of a bucket that stores gradients of a few contiguous layers. |
| All the gradients are bucketized. |
| )") |
| .def( |
| "buffer", |
| &::c10d::GradBucket::getBuffer, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Returns: |
| A flattened 1D ``torch.Tensor`` buffer, |
| which can be further decomposed into a list of per-parameter tensors within this bucket. |
| )") |
| .def( |
| "gradients", |
| &::c10d::GradBucket::getGradients, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Returns: |
| A list of ``torch.Tensor``. Each tensor in the list corresponds to a gradient. |
| )") |
| .def( |
| "parameters", |
| &::c10d::GradBucket::getParameters, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Returns: |
| A list of ``torch.Tensor``. Each tensor in the list corresponds to a model |
| parameter. |
| )") |
| .def( |
| "is_last", |
| &::c10d::GradBucket::isLast, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Returns: |
| Whether this bucket is the last bucket to allreduce in an iteration. |
| This also means that this bucket corresponds to the first few layers in the forward pass. |
| )") |
| .def( |
| "set_buffer", |
| &::c10d::GradBucket::setBuffer, |
| py::arg("buffer"), |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Replaces the tensor in the bucket with the input tensor buffer. |
| )"); |
| |
| py::enum_<::c10d::BuiltinCommHookType>(module, "BuiltinCommHookType", R"( |
| An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.)") |
| .value("ALLREDUCE", ::c10d::BuiltinCommHookType::ALLREDUCE) |
| .value("FP16_COMPRESS", ::c10d::BuiltinCommHookType::FP16_COMPRESS); |
| |
| shared_ptr_class_<::c10d::Reducer>(module, "Reducer") |
| .def( |
| py::init< |
| std::vector<at::Tensor>, |
| std::vector<std::vector<size_t>>, |
| std::vector<size_t>, |
| c10::intrusive_ptr<::c10d::ProcessGroup>, |
| std::vector<bool>, |
| int64_t, |
| bool, |
| bool, |
| std::unordered_map<size_t, std::string>, |
| int64_t>(), |
| py::arg("params"), |
| py::arg("bucket_indices"), |
| py::arg("per_bucket_size_limits"), |
| py::arg("process_group"), |
| py::arg("expect_sparse_gradients") = std::vector<bool>(), |
| py::arg("bucket_bytes_cap") = ::c10d::kDefaultBucketBytesCap, |
| py::arg("find_unused_parameters") = false, |
| py::arg("gradient_as_bucket_view") = false, |
| py::arg("param_to_name_mapping") = |
| std::unordered_map<size_t, std::string>(), |
| py::arg("first_bucket_bytes_cap") = ::c10d::kDefaultFirstBucketBytes, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "prepare_for_forward", |
| &::c10d::Reducer::prepare_for_forward, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "prepare_for_backward", |
| &::c10d::Reducer::prepare_for_backward, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "prepare_for_backward", |
| [](::c10d::Reducer& reducer, const at::Tensor& output) -> void { |
| reducer.prepare_for_backward({output}); |
| }, |
| py::call_guard<py::gil_scoped_release>()) |
| .def("get_backward_stats", &::c10d::Reducer::get_backward_stats) |
| .def( |
| "_install_post_backward_futures", |
| [](::c10d::Reducer& reducer, |
| const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>& |
| futs) { |
| c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futures( |
| c10::FutureType::create(c10::TensorType::get())); |
| for (const auto& fut : futs) { |
| futures.push_back(fut->fut); |
| } |
| reducer.install_futures(std::move(futures)); |
| }, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_rebuild_buckets", |
| &::c10d::Reducer::rebuild_buckets, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_get_zeros_like_grad_buckets", |
| [](::c10d::Reducer& reducer) { |
| return reducer.get_grad_buckets(/* return_zero_tensors */ true); |
| }, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_set_optimizer_in_backward", |
| [](::c10d::Reducer& reducer) { reducer.set_optimizer_in_backward(); }, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_set_sparse_metadata", |
| &::c10d::Reducer::setSparseMetadata, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_set_mixed_precision_param_dtype", |
| [](::c10d::Reducer& reducer, py::object data_type_obj) { |
| auto scalar_type = |
| reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type; |
| reducer.set_mixed_precision_param_dtype(scalar_type); |
| }, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_push_all_rebuilt_params", |
| &::c10d::Reducer::push_rebuilt_params_for_all_indices, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_set_forward_pass_work_handle", |
| &::c10d::Reducer::set_forward_pass_work_handle, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_get_local_used_map", &::c10d::Reducer::get_local_used_map_on_device) |
| .def( |
| "_set_ddp_runtime_logging_sample_rate", |
| &::c10d::Reducer::set_ddp_runtime_logging_sample_rate, |
| py::arg("sample_rate"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_set_static_graph", |
| &::c10d::Reducer::set_static_graph, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_ddp_graph_static", |
| &::c10d::Reducer::ddp_graph_static, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_delay_all_reduce", |
| &::c10d::Reducer::delay_all_reduce, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_run_comm_hook", |
| [](::c10d::Reducer& reducer, ::c10d::GradBucket& bucket) |
| -> std::shared_ptr<jit::PythonFutureWrapper> { |
| c10::intrusive_ptr<c10::ivalue::Future> fut = |
| reducer.run_comm_hook(bucket); |
| return std::make_shared<jit::PythonFutureWrapper>(fut); |
| }, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_run_allreduce_hook", |
| [](::c10d::Reducer& reducer, ::c10d::GradBucket& bucket) |
| -> std::shared_ptr<jit::PythonFutureWrapper> { |
| c10::intrusive_ptr<c10::ivalue::Future> fut = |
| reducer.run_allreduce_hook(bucket); |
| return std::make_shared<jit::PythonFutureWrapper>(fut); |
| }, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_autograd_hook", |
| [](::c10d::Reducer& reducer, int index) -> void { |
| reducer.autograd_hook(index); |
| }, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "set_logger", |
| [](::c10d::Reducer& reducer, |
| const std::shared_ptr<::c10d::Logger> logger) { |
| std::weak_ptr<::c10d::Logger> logger_weakref = logger; |
| reducer.set_logger(logger_weakref); |
| }) |
| .def( |
| "_remove_autograd_hooks", |
| [](::c10d::Reducer& reducer) { reducer.remove_autograd_hooks(); }, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_check_reducer_finalized", |
| [](::c10d::Reducer& reducer) { return reducer.check_finalized(); }, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| shared_ptr_class_<::c10d::Logger>(module, "Logger") |
| .def( |
| py::init<std::shared_ptr<::c10d::Reducer>>(), |
| py::arg("reducer"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "set_construction_data_and_log", |
| &::c10d::Logger::set_construction_data_and_log, |
| py::arg("module_name"), |
| py::arg("device_ids"), |
| py::arg("output_device"), |
| py::arg("broadcast_buffers"), |
| py::arg("has_sync_bn"), |
| py::arg("static_graph"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "set_runtime_stats_and_log", |
| &::c10d::Logger::set_runtime_stats_and_log, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "set_error_and_log", |
| [](::c10d::Logger& logger, const std::string& error) { |
| logger.set_error_and_log(error); |
| }, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_get_ddp_logging_data", |
| &::c10d::Logger::get_ddp_logging_data, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_set_comm_hook_name", |
| &::c10d::Logger::set_comm_hook, |
| py::arg("comm_hook"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_set_uneven_input_join", |
| &::c10d::Logger::set_uneven_input_join, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_set_static_graph", |
| &::c10d::Logger::set_static_graph, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| py::enum_<::c10d::DebugLevel>(module, "DebugLevel", R"( |
| An enum whose values correspond to different debug levels of the |
| torch.distributed package. It currently supports ``OFF``, ``INFO``, and ``DETAIL``, |
| which can be set via the ``TORCH_DISTRIBUTED_DEBUG`` environment variable |
| or via the ``set_debug_level()`` function. |
| )") |
| .value("OFF", ::c10d::DebugLevel::Off) |
| .value("INFO", ::c10d::DebugLevel::Info) |
| .value("DETAIL", ::c10d::DebugLevel::Detail); |
| |
| module |
| .def( |
| "get_debug_level", |
| ::c10d::debug_level, |
| R"(Gets the debug level of the torch.distributed package.)") |
| .def( |
| "set_debug_level", |
| ::c10d::setDebugLevel, |
| R"(Sets the debug level of the torch.distributed package.)") |
| .def( |
| "set_debug_level_from_env", |
| ::c10d::setDebugLevelFromEnvironment, |
| R"(Sets the debug level of the torch.distributed package from the |
| ``TORCH_DISTRIBUTED_DEBUG`` environment variable.)"); |
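| |
| // Illustrative sketch (assuming these functions are re-exported under |
| // torch.distributed, as the docstrings above imply): |
| // |
| //   # before launching:  TORCH_DISTRIBUTED_DEBUG=DETAIL python train.py |
| //   # or programmatically: |
| //   import torch.distributed as dist |
| //   dist.set_debug_level(dist.DebugLevel.DETAIL) |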
| |
| // TODO(crcrpar): Hardening `ReduceOp`. |
| // While keeping most op types as enum values, |
| // making `PREMUL_SUM` callable, i.e., allowing |
| // `ReduceOp.PREMUL_SUM(scale)`, might be better, as per @wanchaol. |
| // https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types |
| py::class_<::c10d::ReduceOp> reduce_op( |
| module, "ReduceOp", py::metaclass((PyObject*)GetReduceOpMetaclass()), R"( |
| An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``, |
| ``MIN``, ``MAX``, ``BAND``, ``BOR``, ``BXOR``, and ``PREMUL_SUM``. |
| |
| ``BAND``, ``BOR``, and ``BXOR`` reductions are not available when |
| using the ``NCCL`` backend. |
| |
| ``AVG`` divides values by the world size before summing across ranks. |
| ``AVG`` is only available with the ``NCCL`` backend, |
| and only for NCCL versions 2.10 or later. |
| |
| ``PREMUL_SUM`` multiplies inputs by a given scalar locally before reduction. |
| ``PREMUL_SUM`` is only available with the ``NCCL`` backend, |
| and only for NCCL versions 2.11 or later. Users should |
| use ``torch.distributed._make_nccl_premul_sum`` to construct it. |
| |
| Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors. |
| |
| The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``. |
| They are used in specifying strategies for reduction collectives, e.g., |
| :func:`reduce`, :func:`all_reduce_multigpu`, etc. |
| |
| This class does not support the ``__members__`` property.)"); |
| |
| reduce_op.def(py::init<::c10d::ReduceOp::RedOpType>()) |
| .def_readwrite("op", &::c10d::ReduceOp::op_); |
| // The following are for backward compatibility. |
| // Since c10d::ReduceOp used to be an `enum class`, users could compare and |
| // take the hash of `::c10d::ReduceOp`. To avoid losing that functionality, we |
| // define some member methods here. |
| reduce_op |
| // todo(crcrpar): Support `RedOpType == ReduceOp`. |
| .def( |
| // This calls `operator==(const ReduceOp::RedOpType)` |
| "__eq__", |
| [](const ::c10d::ReduceOp& self, |
| const ::c10d::ReduceOp::RedOpType& other) { |
| return self == other; |
| }) |
| .def( |
| // This calls `operator==(const ReduceOp)` for the future support of |
| // `PREMUL_SUM` comparison |
| "__eq__", |
| [](const ::c10d::ReduceOp& self, const ::c10d::ReduceOp& other) { |
| return self == other; |
| }) |
| .def( |
| // With the custom `__eq__` overloads above, the remaining types have to |
| // be handled manually. |
| "__eq__", |
| [](const ::c10d::ReduceOp& self, py::object) { return false; }) |
| .def( |
| "__hash__", |
| [](const ::c10d::ReduceOp& self) { |
| return static_cast<uint8_t>(self.op_); |
| }) |
| .def( |
| "__copy__", |
| [](const ::c10d::ReduceOp& self) { return ::c10d::ReduceOp(self); }) |
| .def( |
| "__deepcopy__", |
| [](const ::c10d::ReduceOp& self, const py::dict& memo) { |
| return ::c10d::ReduceOp(self); |
| }) |
| .def(py::pickle( |
| [](const ::c10d::ReduceOp& r) { |
| // __getstate__ |
| if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) { |
| return py::make_tuple(r.op_, py::none()); |
| } |
| TORCH_CHECK(r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp"); |
| const auto* preMulSupplement = |
| reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>( |
| r.supplement_.get()); |
| if (!preMulSupplement->tensor_factor.defined()) { |
| return py::make_tuple(r.op_, preMulSupplement->double_factor); |
| } else { |
| return py::make_tuple(r.op_, preMulSupplement->tensor_factor); |
| } |
| }, |
| [](const py::tuple t) { |
| // __setstate__ |
| TORCH_CHECK(t.size() == 2, "Invalid state"); |
| const auto op = |
| static_cast<::c10d::ReduceOp::RedOpType>(t[0].cast<uint8_t>()); |
| if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) { |
| return ::c10d::ReduceOp(op); |
| } |
| const auto preMulSupplement_factor = t[1]; |
| if (py::isinstance<py::float_>(preMulSupplement_factor)) { |
| return ::c10d::makeNCCLPreMulSum(t[1].cast<double>()); |
| } else { |
| return ::c10d::makeNCCLPreMulSum(t[1].cast<at::Tensor>()); |
| } |
| })); |
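| |
| // Illustrative sketch of the compatibility surface restored above |
| // (comparison, copying, and pickling of ReduceOp from Python): |
| // |
| //   import copy, pickle |
| //   import torch.distributed as dist |
| //   op = dist.ReduceOp(dist.ReduceOp.SUM) |
| //   assert op == dist.ReduceOp.SUM |
| //   assert copy.deepcopy(op) == op |
| //   assert pickle.loads(pickle.dumps(op)) == op |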
| |
| py::enum_<::c10d::ReduceOp::RedOpType>(reduce_op, "RedOpType") |
| .value("SUM", ::c10d::ReduceOp::RedOpType::SUM) |
| .value("AVG", ::c10d::ReduceOp::RedOpType::AVG) |
| .value("PRODUCT", ::c10d::ReduceOp::RedOpType::PRODUCT) |
| .value("MIN", ::c10d::ReduceOp::RedOpType::MIN) |
| .value("MAX", ::c10d::ReduceOp::RedOpType::MAX) |
| .value("BAND", ::c10d::ReduceOp::RedOpType::BAND) |
| .value("BOR", ::c10d::ReduceOp::RedOpType::BOR) |
| .value("BXOR", ::c10d::ReduceOp::RedOpType::BXOR) |
| .value("PREMUL_SUM", ::c10d::ReduceOp::RedOpType::PREMUL_SUM) |
| .export_values(); |
| |
| // note(crcrpar): This could be removed because users will not pass |
| // `RedOpType` to reduction collective ops. Ref: [Implicit |
| // conversions](https://pybind11.readthedocs.io/en/stable/advanced/classes.html#implicit-conversions). |
| // This lets us skip the explicit construction of `c10d::ReduceOp` from |
| // `c10d::ReduceOp::RedOpType` in Python. |
| py::implicitly_convertible<::c10d::ReduceOp::RedOpType, ::c10d::ReduceOp>(); |
| |
| module |
| .def( |
| "_make_nccl_premul_sum", |
| &::c10d::makeNCCLPreMulSum<double>, |
| py::arg("factor").noconvert(), |
| py::return_value_policy::copy, // seems safest |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_make_nccl_premul_sum", |
| &::c10d::makeNCCLPreMulSum<at::Tensor>, |
| py::arg("factor").noconvert(), |
| py::return_value_policy::copy, // seems safest |
| py::call_guard<py::gil_scoped_release>()); |
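| |
| // Illustrative sketch (assumes the NCCL backend with NCCL >= 2.11 and an |
| // already-initialized process group; `t` is a hypothetical CUDA tensor): |
| // |
| //   import torch.distributed as dist |
| //   premul_sum = dist._make_nccl_premul_sum(0.5)  # scale inputs by 0.5, then sum |
| //   dist.all_reduce(t, op=premul_sum) |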
| |
| py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions") |
| .def(py::init<>()) |
| .def_readwrite("rootRank", &::c10d::BroadcastOptions::rootRank) |
| .def_readwrite("rootTensor", &::c10d::BroadcastOptions::rootTensor) |
| .def_readwrite("timeout", &::c10d::BroadcastOptions::timeout); |
| |
| py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions") |
| .def(py::init<>()) |
| .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp) |
| .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout); |
| |
| py::class_<::c10d::AllreduceCoalescedOptions>( |
| module, "AllreduceCoalescedOptions") |
| .def(py::init<>()) |
| .def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp) |
| .def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout); |
| |
| py::class_<::c10d::ReduceOptions>(module, "ReduceOptions") |
| .def(py::init<>()) |
| .def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp) |
| .def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank) |
| .def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor) |
| .def_readwrite("timeout", &::c10d::ReduceOptions::timeout); |
| |
| py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions") |
| .def(py::init<>()) |
| .def_readwrite("timeout", &::c10d::AllgatherOptions::timeout); |
| |
| py::class_<::c10d::GatherOptions>(module, "GatherOptions") |
| .def(py::init<>()) |
| .def_readwrite("rootRank", &::c10d::GatherOptions::rootRank) |
| .def_readwrite("timeout", &::c10d::GatherOptions::timeout); |
| |
| py::class_<::c10d::ScatterOptions>(module, "ScatterOptions") |
| .def(py::init<>()) |
| .def_readwrite("rootRank", &::c10d::ScatterOptions::rootRank) |
| .def_readwrite("timeout", &::c10d::ScatterOptions::timeout); |
| |
| py::class_<::c10d::ReduceScatterOptions>(module, "ReduceScatterOptions") |
| .def(py::init<>()) |
| .def_readwrite("reduceOp", &::c10d::ReduceScatterOptions::reduceOp) |
| .def_readwrite("timeout", &::c10d::ReduceScatterOptions::timeout); |
| |
| py::class_<::c10d::BarrierOptions>(module, "BarrierOptions") |
| .def(py::init<>()) |
| .def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids) |
| .def_readwrite("timeout", &::c10d::BarrierOptions::timeout) |
| .def_readwrite("device", &::c10d::BarrierOptions::device); |
| |
| py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions") |
| .def(py::init<>()) |
| .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout); |
| |
| py::class_<::c10d::DistributedBackendOptions>( |
| module, "_DistributedBackendOptions") |
| .def(py::init<>()) |
| .def_readwrite("store", &::c10d::DistributedBackendOptions::store) |
| .def_readwrite( |
| "group_rank", &::c10d::DistributedBackendOptions::group_rank) |
| .def_readwrite( |
| "group_size", &::c10d::DistributedBackendOptions::group_size) |
| .def_readwrite("timeout", &::c10d::DistributedBackendOptions::timeout) |
| .def_readwrite("group_id", &::c10d::DistributedBackendOptions::group_id) |
| .def_readwrite( |
| "global_ranks_in_group", |
| &::c10d::DistributedBackendOptions::global_ranks_in_group); |
| |
| auto store = |
| py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>( |
| module, |
| "Store", |
| R"( |
| Base class for all store implementations, such as the three provided by PyTorch |
| distributed: :class:`~torch.distributed.TCPStore`, :class:`~torch.distributed.FileStore`, |
| and :class:`~torch.distributed.HashStore`. |
| )") |
| // Default constructor. |
| .def(py::init<>()) |
| // Convert from std::string to std::vector<uint8>. |
| .def( |
| "set", |
| [](::c10d::Store& store, |
| const std::string& key, |
| const std::string& value) { |
| std::vector<uint8_t> value_(value.begin(), value.end()); |
| store.set(key, value_); |
| }, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Inserts the key-value pair into the store based on the supplied ``key`` and |
| ``value``. If ``key`` already exists in the store, it will overwrite the old |
| value with the new supplied ``value``. |
| |
| Arguments: |
| key (str): The key to be added to the store. |
| value (str): The value associated with ``key`` to be added to the store. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
| >>> store.set("first_key", "first_value") |
| >>> # Should return "first_value" |
| >>> store.get("first_key") |
| )") |
| .def( |
| "compare_set", |
| [](::c10d::Store& store, |
| const std::string& key, |
| const std::string& expected_value, |
| const std::string& desired_value) -> py::bytes { |
| std::vector<uint8_t> expectedValue_( |
| expected_value.begin(), expected_value.end()); |
| std::vector<uint8_t> desiredValue_( |
| desired_value.begin(), desired_value.end()); |
| auto value = |
| store.compareSet(key, expectedValue_, desiredValue_); |
| return py::bytes( |
| reinterpret_cast<char*>(value.data()), value.size()); |
| }, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Inserts the key-value pair into the store based on the supplied ``key`` and |
| performs a comparison between ``expected_value`` and the value currently stored for ``key`` before inserting. |
| ``desired_value`` will only be set if ``expected_value`` matches the value already stored for ``key``, |
| or if ``expected_value`` is an empty string and ``key`` is not yet present in the store. |
| |
| Arguments: |
| key (str): The key to be checked in the store. |
| expected_value (str): The value associated with ``key`` to be checked before insertion. |
| desired_value (str): The value associated with ``key`` to be added to the store. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
| >>> store.set("key", "first_value") |
| >>> store.compare_set("key", "first_value", "second_value") |
| >>> # Should return "second_value" |
| >>> store.get("key") |
| )") |
| // Convert from std::vector<uint8_t> to py::bytes. |
| // The returned value is not guaranteed to be valid UTF-8. |
| .def( |
| "get", |
| [](::c10d::Store& store, const std::string& key) -> py::bytes { |
| auto value = [&]() { |
| py::gil_scoped_release guard; |
| return store.get(key); |
| }(); |
| return py::bytes( |
| reinterpret_cast<char*>(value.data()), value.size()); |
| }, |
| R"( |
| Retrieves the value associated with the given ``key`` in the store. If ``key`` is not |
| present in the store, the function will wait for ``timeout``, which is defined |
| when initializing the store, before throwing an exception. |
| |
| Arguments: |
| key (str): The function will return the value associated with this key. |
| |
| Returns: |
| Value associated with ``key`` if ``key`` is in the store. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
| >>> store.set("first_key", "first_value") |
| >>> # Should return "first_value" |
| >>> store.get("first_key") |
| )") |
| .def( |
| "add", |
| &::c10d::Store::add, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| The first call to add for a given ``key`` creates a counter associated |
| with ``key`` in the store, initialized to ``amount``. Subsequent calls to add |
| with the same ``key`` increment the counter by the specified ``amount``. |
| Calling :meth:`~torch.distributed.store.add` with a key that has already |
| been set in the store by :meth:`~torch.distributed.store.set` will result |
| in an exception. |
| |
| Arguments: |
| key (str): The key in the store whose counter will be incremented. |
| amount (int): The quantity by which the counter will be incremented. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> # Using TCPStore as an example, other store types can also be used |
| >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
| >>> store.add("first_key", 1) |
| >>> store.add("first_key", 6) |
| >>> # Should return 7 |
| >>> store.get("first_key") |
| )") |
| .def( |
| "delete_key", |
| &::c10d::Store::deleteKey, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Deletes the key-value pair associated with ``key`` from the store. Returns |
| `true` if the key was successfully deleted, and `false` if it was not. |
| |
| .. warning:: |
| The ``delete_key`` API is only supported by the :class:`~torch.distributed.TCPStore` and :class:`~torch.distributed.HashStore`. Using this API |
| with the :class:`~torch.distributed.FileStore` will result in an exception. |
| |
| Arguments: |
| key (str): The key to be deleted from the store |
| |
| Returns: |
| `True` if ``key`` was deleted, otherwise `False`. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> # Using TCPStore as an example, HashStore can also be used |
| >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
| >>> store.set("first_key") |
| >>> # This should return true |
| >>> store.delete_key("first_key") |
| >>> # This should return false |
| >>> store.delete_key("bad_key") |
| )") |
| .def( |
| "num_keys", |
| &::c10d::Store::getNumKeys, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Returns the number of keys set in the store. Note that this number will typically |
| be one greater than the number of keys added by :meth:`~torch.distributed.store.set` |
| and :meth:`~torch.distributed.store.add` since one key is used to coordinate all |
| the workers using the store. |
| |
| .. warning:: |
| When used with the :class:`~torch.distributed.FileStore`, ``num_keys`` returns the number of keys written to the underlying file. If the store is destructed and another store is created with the same file, the original keys will be retained. |
| |
| Returns: |
| The number of keys present in the store. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> # Using TCPStore as an example, other store types can also be used |
| >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
| >>> store.set("first_key", "first_value") |
| >>> # This should return 2 |
| >>> store.num_keys() |
| )") |
| .def( |
| "set_timeout", |
| &::c10d::Store::setTimeout, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Sets the store's default timeout. This timeout is used during initialization and in |
| :meth:`~torch.distributed.store.wait` and :meth:`~torch.distributed.store.get`. |
| |
| Arguments: |
| timeout (timedelta): timeout to be set in the store. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> # Using TCPStore as an example, other store types can also be used |
| >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
| >>> store.set_timeout(timedelta(seconds=10)) |
| >>> # This will throw an exception after 10 seconds |
| >>> store.wait(["bad_key"]) |
| )") |
| .def( |
| "wait", |
| [](::c10d::Store& store, const std::vector<std::string>& keys) { |
| store.wait(keys); |
| }, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Waits for each key in ``keys`` to be added to the store. If not all keys are |
| set before the ``timeout`` (set during store initialization), then ``wait`` |
| will throw an exception. |
| |
| Arguments: |
| keys (list): List of keys on which to wait until they are set in the store. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> # Using TCPStore as an example, other store types can also be used |
| >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
| >>> # This will throw an exception after 30 seconds |
| >>> store.wait(["bad_key"]) |
| )") |
| .def( |
| "wait", |
| [](::c10d::Store& store, |
| const std::vector<std::string>& keys, |
| const std::chrono::milliseconds& timeout) { |
| store.wait(keys, timeout); |
| }, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Waits for each key in ``keys`` to be added to the store, and throws an exception |
| if the keys have not been set by the supplied ``timeout``. |
| |
| Arguments: |
| keys (list): List of keys on which to wait until they are set in the store. |
| timeout (timedelta): Time to wait for the keys to be added before throwing an exception. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> # Using TCPStore as an example, other store types can also be used |
| >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
| >>> # This will throw an exception after 10 seconds |
| >>> store.wait(["bad_key"], timedelta(seconds=10)) |
| )") |
| .def_property_readonly( |
| "timeout", |
| &::c10d::Store::getTimeout, |
| R"(Gets the timeout of the store.)") |
| .def( |
| "append", |
| [](::c10d::Store& store, |
| const std::string& key, |
| const std::string& value) { |
| std::vector<uint8_t> value_(value.begin(), value.end()); |
| store.append(key, value_); |
| }, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Appends the supplied ``value`` to the value already associated with ``key`` in the store. |
| If ``key`` does not exist in the store, it will be created. |
| |
| Arguments: |
| key (str): The key to be appended to the store. |
| value (str): The value associated with ``key`` to be added to the store. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
| >>> store.append("first_key", "po") |
| >>> store.append("first_key", "tato") |
| >>> # Should return "potato" |
| >>> store.get("first_key") |
| )") |
| .def( |
| "multi_get", |
| [](::c10d::Store& store, const std::vector<std::string>& keys) { |
| auto values = [&]() { |
| py::gil_scoped_release guard; |
| return store.multiGet(keys); |
| }(); |
| std::vector<py::bytes> res; |
| for (auto& value : values) { |
| auto bytes = py::bytes( |
| reinterpret_cast<const char*>(value.data()), |
| value.size()); |
| res.push_back(bytes); |
| } |
| return res; |
| }, |
| R"( |
| Retrieves all values for the given ``keys``. If any key in ``keys`` is not |
| present in the store, the function will wait for ``timeout`` before throwing an exception. |
| |
| Arguments: |
| keys (List[str]): The keys to be retrieved from the store. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
| >>> store.set("first_key", "po") |
| >>> store.set("second_key", "tato") |
| >>> # Should return [b"po", b"tato"] |
| >>> store.multi_get(["first_key", "second_key"]) |
| )") |
| .def( |
| "multi_set", |
| [](::c10d::Store& store, |
| const std::vector<std::string>& keys, |
| const std::vector<std::string>& values) { |
| std::vector<std::vector<uint8_t>> vals; |
| for (auto& value : values) { |
| vals.push_back( |
| std::vector<uint8_t>(value.begin(), value.end())); |
| } |
| store.multiSet(keys, vals); |
| }, |
| py::call_guard<py::gil_scoped_release>(), |
| R"( |
| Inserts a list of key-value pairs into the store based on the supplied ``keys`` and ``values``. |
| |
| Arguments: |
| keys (List[str]): The keys to insert. |
| values (List[str]): The values to insert. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30)) |
| >>> store.multi_set(["first_key", "second_key"], ["po", "tato"]) |
| >>> # Should return b"po" |
| >>> store.get("first_key") |
| )") |
| .def( |
| "has_extended_api", |
| &::c10d::Store::hasExtendedApi, |
| R"(Returns true if the store supports extended operations.)"); |
| |
| intrusive_ptr_class_<::c10d::FileStore>( |
| module, |
| "FileStore", |
| store, |
| R"( |
| A store implementation that uses a file to store the underlying key-value pairs. |
| |
| Arguments: |
| file_name (str): path of the file in which to store the key-value pairs |
| world_size (int, optional): The total number of processes using the store. Default is -1 (a negative value indicates a non-fixed number of store users). |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> store1 = dist.FileStore("/tmp/filestore", 2) |
| >>> store2 = dist.FileStore("/tmp/filestore", 2) |
| >>> # Use any of the store methods from either the client or server after initialization |
| >>> store1.set("first_key", "first_value") |
| >>> store2.get("first_key") |
| |
| )") |
| .def( |
| py::init<const std::string&, int>(), |
| py::arg("file_name"), |
| py::arg("world_size") = -1) |
| .def_property_readonly( |
| "path", |
| &::c10d::FileStore::getPath, |
| R"(Gets the path of the file used by FileStore to store key-value pairs.)"); |
| |
| #ifndef _WIN32 |
| intrusive_ptr_class_<::c10d::HashStore>( |
| module, |
| "HashStore", |
| store, |
| R"( |
| A thread-safe store implementation based on an underlying hashmap. This store can be used |
| within the same process (for example, by other threads), but cannot be used across processes. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> store = dist.HashStore() |
| >>> # store can be used from other threads |
| >>> # Use any of the store methods after initialization |
| >>> store.set("first_key", "first_value") |
| )") |
| .def(py::init<>()); |
| #endif |
| |
| intrusive_ptr_class_<::c10d::TCPStore>( |
| module, |
| "TCPStore", |
| store, |
| R"( |
| A TCP-based distributed key-value store implementation. The server store holds |
| the data, while the client stores can connect to the server store over TCP and |
| perform actions such as :meth:`~torch.distributed.store.set` to insert a key-value |
| pair, :meth:`~torch.distributed.store.get` to retrieve a key-value pair, etc. There |
| should always be one server store initialized because the client store(s) will wait for |
| the server to establish a connection. |
| |
| Arguments: |
| host_name (str): The hostname or IP Address the server store should run on. |
| port (int): The port on which the server store should listen for incoming requests. |
| world_size (int, optional): The total number of store users (number of clients + 1 for the server). Default is None (None indicates a non-fixed number of store users). |
| is_master (bool, optional): True when initializing the server store and False for client stores. Default is False. |
| timeout (timedelta, optional): Timeout used by the store during initialization and for methods such as :meth:`~torch.distributed.store.get` and :meth:`~torch.distributed.store.wait`. Default is timedelta(seconds=300) |
| wait_for_workers (bool, optional): Whether to wait for all the workers to connect with the server store. This is only applicable when world_size is a fixed value. Default is True. |
| multi_tenant (bool, optional): If True, all ``TCPStore`` instances in the current process with the same host/port will use the same underlying ``TCPServer``. Default is False. |
| master_listen_fd (int, optional): If specified, the underlying ``TCPServer`` will listen on this file descriptor, which must be a socket already bound to ``port``. Useful to avoid port assignment races in some scenarios. Default is None (meaning the server creates a new socket and attempts to bind it to ``port``). |
| Example:: |
| >>> import torch.distributed as dist |
| >>> from datetime import timedelta |
| >>> # Run on process 1 (server) |
| >>> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30)) |
| >>> # Run on process 2 (client) |
| >>> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False) |
| >>> # Use any of the store methods from either the client or server after initialization |
| >>> server_store.set("first_key", "first_value") |
| >>> client_store.get("first_key") |
| )") |
| .def( |
| py::init([](const std::string& host, |
| uint16_t port, |
| c10::optional<int> worldSize, |
| bool isServer, |
| std::chrono::milliseconds timeout, |
| bool waitWorkers, |
| bool multiTenant, |
| c10::optional<int> masterListenFd, |
| bool useLibUV) { |
| c10::optional<std::size_t> numWorkers = c10::nullopt; |
| if (worldSize.has_value() && worldSize.value() > -1) { |
| numWorkers = static_cast<std::size_t>(worldSize.value()); |
| } |
| |
| ::c10d::TCPStoreOptions opts{ |
| port, |
| isServer, |
| numWorkers, |
| waitWorkers, |
| timeout, |
| multiTenant, |
| masterListenFd, |
| useLibUV}; |
| |
| return c10::make_intrusive<::c10d::TCPStore>(host, opts); |
| }), |
| py::arg("host_name"), |
| py::arg("port"), |
| py::arg("world_size") = py::none(), |
| // using noconvert() requires this argument to be True or False |
| // prevents accidental implicit conversion to bool |
| py::arg("is_master").noconvert() = false, |
| py::arg("timeout") = |
| std::chrono::milliseconds(::c10d::Store::kDefaultTimeout), |
| py::arg("wait_for_workers") = true, |
| py::arg("multi_tenant") = false, |
| py::arg("master_listen_fd") = py::none(), |
| py::arg("use_libuv") = false, |
| py::call_guard<py::gil_scoped_release>()) |
| .def_property_readonly( |
| "host", |
| &::c10d::TCPStore::getHost, |
| R"(Gets the hostname on which the store listens for requests.)") |
| |
| .def_property_readonly( |
| "port", |
| &::c10d::TCPStore::getPort, |
| R"(Gets the port number on which the store listens for requests.)"); |
| |
| intrusive_ptr_class_<::c10d::PrefixStore>( |
| module, |
| "PrefixStore", |
| store, |
| R"( |
| A wrapper around any of the 3 key-value stores (:class:`~torch.distributed.TCPStore`, |
| :class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`) |
| that adds a prefix to each key inserted into the store. |
| |
| Arguments: |
| prefix (str): The prefix string that is prepended to each key before being inserted into the store. |
| store (torch.distributed.store): A store object that forms the underlying key-value store. |
| )") |
| .def(py::init<const std::string&, c10::intrusive_ptr<::c10d::Store>>()) |
| .def_property_readonly( |
| "underlying_store", |
| &::c10d::PrefixStore::getUnderlyingStore, |
| R"(Gets the underlying store object that PrefixStore wraps around.)"); |
| |
| auto processGroup = |
| py::class_< |
| ::c10d::ProcessGroup, |
| c10::intrusive_ptr<::c10d::ProcessGroup>, |
| ::c10d::PyProcessGroup>(module, "ProcessGroup") |
| .def(py::init<int, int>()) |
| .def( |
| py::init< |
| const c10::intrusive_ptr<::c10d::Store>&, |
| int, |
| int, |
| c10::intrusive_ptr<::c10d::ProcessGroup::Options>>(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def("rank", &::c10d::ProcessGroup::getRank) |
| .def("size", &::c10d::ProcessGroup::getSize) |
| .def("name", &::c10d::ProcessGroup::getBackendName) |
| .def("_id", &::c10d::ProcessGroup::getID) |
| .def( |
| "_backend_id", |
| &::c10d::ProcessGroup::getBackendID, |
| py::arg("backend_type")) |
| .def_property_readonly("options", &::c10d::ProcessGroup::getOptions) |
| .def( |
| "broadcast", |
| &::c10d::ProcessGroup::broadcast, |
| py::arg("tensors"), |
| py::arg("opts") = ::c10d::BroadcastOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "broadcast", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| at::Tensor& x, |
| int rootRank) { |
| ::c10d::BroadcastOptions opts; |
| opts.rootRank = rootRank; |
| std::vector<at::Tensor> tensors = {x}; |
| return self->broadcast(tensors, opts); |
| }, |
| py::arg("tensor"), |
| py::arg("root"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allreduce", |
| &::c10d::ProcessGroup::allreduce, |
| py::arg("tensors"), |
| py::arg("opts") = ::c10d::AllreduceOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allreduce", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| std::vector<at::Tensor>& xs, |
| ::c10d::ReduceOp op) { |
| ::c10d::AllreduceOptions opts; |
| opts.reduceOp = op; |
| return self->allreduce(xs, opts); |
| }, |
| py::arg("tensors"), |
| py::arg("op") = ::c10d::ReduceOp::SUM, |
| py::call_guard<py::gil_scoped_release>()) |
| |
| .def( |
| "allreduce", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| at::Tensor& x, |
| ::c10d::ReduceOp op) { |
| ::c10d::AllreduceOptions opts; |
| opts.reduceOp = op; |
| std::vector<at::Tensor> xs = {x}; |
| return self->allreduce(xs, opts); |
| }, |
| py::arg("tensor"), |
| py::arg("op") = ::c10d::ReduceOp::SUM, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allreduce_coalesced", |
| &::c10d::ProcessGroup::allreduce_coalesced, |
| py::arg("tensors"), |
| py::arg("opts") = ::c10d::AllreduceCoalescedOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| |
| .def( |
| "reduce", |
| &::c10d::ProcessGroup::reduce, |
| py::arg("tensors"), |
| py::arg("opts") = ::c10d::ReduceOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| |
| .def( |
| "reduce", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| at::Tensor& x, |
| int rootRank, |
| ::c10d::ReduceOp op) { |
| ::c10d::ReduceOptions opts; |
| opts.reduceOp = op; |
| opts.rootRank = rootRank; |
| std::vector<at::Tensor> xs = {x}; |
| return self->reduce(xs, opts); |
| }, |
| py::arg("tensor"), |
| py::arg("root"), |
| py::arg("op") = ::c10d::ReduceOp::SUM, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allgather", |
| &::c10d::ProcessGroup::allgather, |
| py::arg("output_tensors"), |
| py::arg("input_tensors"), |
| py::arg("opts") = ::c10d::AllgatherOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allgather", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| std::vector<at::Tensor>& output, |
| at::Tensor& input) { |
| std::vector<std::vector<at::Tensor>> outputs = {output}; |
| std::vector<at::Tensor> inputs = {input}; |
| return self->allgather( |
| outputs, inputs, ::c10d::AllgatherOptions()); |
| }, |
| py::arg("output_tensors"), |
| py::arg("input_tensor"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_allgather_base", |
| &::c10d::ProcessGroup::_allgather_base, |
| py::arg("output"), |
| py::arg("input"), |
| py::arg("opts") = ::c10d::AllgatherOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allgather_coalesced", |
| &::c10d::ProcessGroup::allgather_coalesced, |
| py::arg("output_lists"), |
| py::arg("input_list"), |
| py::arg("opts") = ::c10d::AllgatherOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allgather_into_tensor_coalesced", |
| &::c10d::ProcessGroup::allgather_into_tensor_coalesced, |
| py::arg("outputs"), |
| py::arg("inputs"), |
| py::arg("opts") = ::c10d::AllgatherOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "gather", |
| &::c10d::ProcessGroup::gather, |
| py::arg("output_tensors"), |
| py::arg("input_tensors"), |
| py::arg("opts") = ::c10d::GatherOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| |
| .def( |
| "gather", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| std::vector<at::Tensor>& output, |
| at::Tensor& input, |
| int rootRank) { |
| ::c10d::GatherOptions opts; |
| opts.rootRank = rootRank; |
| std::vector<std::vector<at::Tensor>> outputs = {output}; |
| std::vector<at::Tensor> inputs = {input}; |
| return self->gather(outputs, inputs, opts); |
| }, |
| py::arg("output_tensors"), |
| py::arg("input_tensor"), |
| py::arg("root"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "scatter", |
| &::c10d::ProcessGroup::scatter, |
| py::arg("output_tensors"), |
| py::arg("input_tensors"), |
| py::arg("opts") = ::c10d::ScatterOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "scatter", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| at::Tensor& output, |
| std::vector<at::Tensor>& input, |
| int rootRank) { |
| ::c10d::ScatterOptions opts; |
| opts.rootRank = rootRank; |
| std::vector<std::vector<at::Tensor>> inputs = {input}; |
| std::vector<at::Tensor> outputs = {output}; |
| return self->scatter(outputs, inputs, opts); |
| }, |
| py::arg("output_tensor"), |
| py::arg("input_tensors"), |
| py::arg("root"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "reduce_scatter", |
| &::c10d::ProcessGroup::reduce_scatter, |
| py::arg("output_tensors"), |
| py::arg("input_tensors"), |
| py::arg("opts") = ::c10d::ReduceScatterOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "reduce_scatter", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| at::Tensor& output, |
| std::vector<at::Tensor>& input, |
| ::c10d::ReduceOp op) { |
| std::vector<at::Tensor> outputs = {output}; |
| std::vector<std::vector<at::Tensor>> inputs = {input}; |
| ::c10d::ReduceScatterOptions opts; |
| opts.reduceOp = op; |
| return self->reduce_scatter(outputs, inputs, opts); |
| }, |
| py::arg("output"), |
| py::arg("input"), |
| py::arg("op") = ::c10d::ReduceOp::SUM, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_reduce_scatter_base", |
| &::c10d::ProcessGroup::_reduce_scatter_base, |
| py::arg("outputTensor"), |
| py::arg("inputTensor"), |
| py::arg("opts") = ::c10d::ReduceScatterOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "reduce_scatter_tensor_coalesced", |
| &::c10d::ProcessGroup::reduce_scatter_tensor_coalesced, |
| py::arg("outputs"), |
| py::arg("inputs"), |
| py::arg("opts") = ::c10d::ReduceScatterOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "alltoall_base", |
| &::c10d::ProcessGroup::alltoall_base, |
| py::arg("output"), |
| py::arg("input"), |
| py::arg("output_split_sizes"), |
| py::arg("input_split_sizes"), |
| py::arg("opts") = ::c10d::AllToAllOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "alltoall", |
| &::c10d::ProcessGroup::alltoall, |
| py::arg("output_tensors"), |
| py::arg("input_tensors"), |
| py::arg("opts") = ::c10d::AllToAllOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "send", |
| &::c10d::ProcessGroup::send, |
| py::arg("tensors"), |
| py::arg("dstRank"), |
| py::arg("tag"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "recv", |
| &::c10d::ProcessGroup::recv, |
| py::arg("tensors"), |
| py::arg("srcRank"), |
| py::arg("tag"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "recv_anysource", |
| &::c10d::ProcessGroup::recvAnysource, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "barrier", |
| &::c10d::ProcessGroup::barrier, |
| py::arg("opts") = ::c10d::BarrierOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_set_sequence_number_for_group", |
| &::c10d::ProcessGroup::setSequenceNumberForGroup, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_get_sequence_number_for_group", |
| &::c10d::ProcessGroup::getSequenceNumberForGroup, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "monitored_barrier", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| const std::chrono::milliseconds& timeout, |
| bool waitAllRanks) { |
| ::c10d::BarrierOptions opts; |
| opts.timeout = timeout; |
| return self->monitoredBarrier(opts, waitAllRanks); |
| }, |
| py::arg("timeout") = ::c10d::kUnsetTimeout, |
| py::arg("wait_all_ranks") = false, |
| py::call_guard<py::gil_scoped_release>()) |
| .def_property_readonly( |
| "_device_types", &::c10d::ProcessGroup::getDeviceTypes) |
| .def( |
| "_get_backend_name", |
| &::c10d::ProcessGroup::getBackendName, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_start_coalescing", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| const c10::Device& device) { |
| self->startCoalescing(device.type()); |
| }, |
| py::arg("device_type"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_end_coalescing", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| const c10::Device& device) { |
| return self->endCoalescing(device.type()); |
| }, |
| py::arg("device_type"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_register_backend", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| const c10::Device& device, |
| const ::c10d::ProcessGroup::BackendType& backendType, |
| const c10::optional<c10::intrusive_ptr<::c10d::Backend>>& |
| backend) { |
| self->setBackend(device.type(), backendType, backend); |
| }, |
| py::arg("device"), |
| py::arg("backend_type"), |
| py::arg("backend") = |
| c10::optional<c10::intrusive_ptr<::c10d::Backend>>(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_get_backend", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, |
| const c10::Device& device) { |
| return self->getBackend(device.type()); |
| }, |
| py::arg("device"), |
| py::call_guard<py::gil_scoped_release>()); |
| |
| py::enum_<::c10d::ProcessGroup::BackendType>(processGroup, "BackendType") |
| .value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED) |
| .value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO) |
| .value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL) |
| .value("UCC", ::c10d::ProcessGroup::BackendType::UCC) |
| .value("MPI", ::c10d::ProcessGroup::BackendType::MPI) |
| .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM) |
| .export_values(); |
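| |
| // Illustrative (not a public API): from Python, the enum above is used |
| // together with the private _register_backend/_get_backend bindings, e.g. |
| //   >>> pg._register_backend(torch.device("cuda"), |
| //   ...                      dist.ProcessGroup.BackendType.NCCL, |
| //   ...                      nccl_backend)  # `nccl_backend` is hypothetical |
| //   >>> pg._get_backend(torch.device("cuda")) |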
| |
| // base ProcessGroup::Options binding |
| auto processGroupOptions = |
| intrusive_ptr_class_<::c10d::ProcessGroup::Options>( |
| processGroup, |
| "Options", |
| R"( |
| Base class for all ProcessGroup options implementations, such as the NCCL |
| options :class:`~torch.distributed.ProcessGroupNCCL.Options`. |
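| |
| The example below is illustrative only; normally a backend-specific subclass |
| such as :class:`~torch.distributed.ProcessGroupNCCL.Options` is constructed |
| instead of this base class. |
| |
| Example:: |
| >>> from datetime import timedelta |
| >>> import torch.distributed as dist |
| >>> opts = dist.ProcessGroup.Options(backend="gloo", timeout=timedelta(seconds=30)) |
| >>> opts.backend |
| 'gloo' |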
| )") |
| .def( |
| py::init([](const std::string& backend, |
| const std::chrono::milliseconds& timeout) { |
| return c10::make_intrusive<::c10d::ProcessGroup::Options>( |
| backend, timeout); |
| }), |
| py::arg("backend"), |
| py::arg("timeout") = kProcessGroupDefaultTimeout, |
| py::call_guard<py::gil_scoped_release>()) |
| .def_readonly("backend", &::c10d::ProcessGroup::Options::backend) |
| .def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout); |
| |
| #ifndef _WIN32 |
| module.def( |
| "_round_robin_process_groups", |
| [](std::vector<c10::intrusive_ptr<::c10d::ProcessGroup>> processGroups) |
| -> c10::intrusive_ptr<::c10d::ProcessGroup> { |
| if (processGroups.empty()) { |
| throw std::invalid_argument("Specify at least 1 process group"); |
| } |
| const auto& first = processGroups.front(); |
| return c10::make_intrusive<::c10d::ProcessGroupRoundRobin>( |
| first->getRank(), first->getSize(), std::move(processGroups)); |
| }, |
| py::arg("process_groups"), |
| py::call_guard<py::gil_scoped_release>()); |
| #endif |
| |
| // TODO: The collection of definitions below handles direct instantiation of |
| // ProcessGroup subclasses (e.g. dist.ProcessGroupGloo). This is not supported |
| // and should be removed once all tests have transitioned. |
| auto backend = |
| py::class_<::c10d::Backend, c10::intrusive_ptr<::c10d::Backend>>( |
| module, "Backend") |
| .def("rank", &::c10d::Backend::getRank) |
| .def("size", &::c10d::Backend::getSize) |
| .def("name", &::c10d::Backend::getBackendName) |
| .def( |
| "broadcast", |
| &::c10d::Backend::broadcast, |
| py::arg("tensors"), |
| py::arg("opts") = ::c10d::BroadcastOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "broadcast", |
| [](const c10::intrusive_ptr<::c10d::Backend>& self, |
| at::Tensor& x, |
| int rootRank) { |
| ::c10d::BroadcastOptions opts; |
| opts.rootRank = rootRank; |
| std::vector<at::Tensor> xs = {x}; |
| return self->broadcast(xs, opts); |
| }, |
| py::arg("tensor"), |
| py::arg("root"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allreduce", |
| &::c10d::Backend::allreduce, |
| py::arg("tensors"), |
| py::arg("opts") = ::c10d::AllreduceOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allreduce", |
| [](const c10::intrusive_ptr<::c10d::Backend>& self, |
| std::vector<at::Tensor>& xs, |
| ::c10d::ReduceOp op) { |
| ::c10d::AllreduceOptions opts; |
| opts.reduceOp = op; |
| return self->allreduce(xs, opts); |
| }, |
| py::arg("tensors"), |
| py::arg("op") = ::c10d::ReduceOp::SUM, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allreduce", |
| [](const c10::intrusive_ptr<::c10d::Backend>& self, |
| at::Tensor& x, |
| ::c10d::ReduceOp op) { |
| ::c10d::AllreduceOptions opts; |
| opts.reduceOp = op; |
| std::vector<at::Tensor> xs = {x}; |
| return self->allreduce(xs, opts); |
| }, |
| py::arg("tensor"), |
| py::arg("op") = ::c10d::ReduceOp::SUM, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allreduce_coalesced", |
| &::c10d::Backend::allreduce_coalesced, |
| py::arg("tensors"), |
| py::arg("opts") = ::c10d::AllreduceCoalescedOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "reduce", |
| &::c10d::Backend::reduce, |
| py::arg("tensors"), |
| py::arg("opts") = ::c10d::ReduceOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "reduce", |
| [](const c10::intrusive_ptr<::c10d::Backend>& self, |
| at::Tensor& x, |
| int rootRank, |
| ::c10d::ReduceOp op) { |
| ::c10d::ReduceOptions opts; |
| opts.reduceOp = op; |
| opts.rootRank = rootRank; |
| std::vector<at::Tensor> xs = {x}; |
| return self->reduce(xs, opts); |
| }, |
| py::arg("tensor"), |
| py::arg("root"), |
| py::arg("op") = ::c10d::ReduceOp::SUM, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allgather", |
| &::c10d::Backend::allgather, |
| py::arg("output_tensors"), |
| py::arg("input_tensors"), |
| py::arg("opts") = ::c10d::AllgatherOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_allgather_base", |
| &::c10d::Backend::_allgather_base, |
| py::arg("output"), |
| py::arg("input"), |
| py::arg("opts") = ::c10d::AllgatherOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allgather", |
| [](const c10::intrusive_ptr<::c10d::Backend>& self, |
| std::vector<at::Tensor>& output, |
| at::Tensor& input) { |
| std::vector<std::vector<at::Tensor>> outputs = {output}; |
| std::vector<at::Tensor> inputs = {input}; |
| return self->allgather( |
| outputs, inputs, ::c10d::AllgatherOptions()); |
| }, |
| py::arg("output_tensors"), |
| py::arg("input_tensor"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "allgather_coalesced", |
| &::c10d::Backend::allgather_coalesced, |
| py::arg("output_lists"), |
| py::arg("input_list"), |
| py::arg("opts") = ::c10d::AllgatherOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "gather", |
| &::c10d::Backend::gather, |
| py::arg("output_tensors"), |
| py::arg("input_tensors"), |
| py::arg("opts") = ::c10d::GatherOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "gather", |
| [](const c10::intrusive_ptr<::c10d::Backend>& self, |
| std::vector<at::Tensor>& output, |
| at::Tensor& input, |
| int rootRank) { |
| ::c10d::GatherOptions opts; |
| opts.rootRank = rootRank; |
| std::vector<std::vector<at::Tensor>> outputs = {output}; |
| std::vector<at::Tensor> inputs = {input}; |
| return self->gather(outputs, inputs, opts); |
| }, |
| py::arg("output_tensors"), |
| py::arg("input_tensor"), |
| py::arg("root"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "scatter", |
| &::c10d::Backend::scatter, |
| py::arg("output_tensors"), |
| py::arg("input_tensors"), |
| py::arg("opts") = ::c10d::ScatterOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "scatter", |
| [](const c10::intrusive_ptr<::c10d::Backend>& self, |
| at::Tensor& output, |
| std::vector<at::Tensor>& input, |
| int rootRank) { |
| ::c10d::ScatterOptions opts; |
| opts.rootRank = rootRank; |
| std::vector<std::vector<at::Tensor>> inputs = {input}; |
| std::vector<at::Tensor> outputs = {output}; |
| return self->scatter(outputs, inputs, opts); |
| }, |
| py::arg("output_tensor"), |
| py::arg("input_tensors"), |
| py::arg("root"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "reduce_scatter", |
| &::c10d::Backend::reduce_scatter, |
| py::arg("output_tensors"), |
| py::arg("input_tensors"), |
| py::arg("opts") = ::c10d::ReduceScatterOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "reduce_scatter", |
| [](const c10::intrusive_ptr<::c10d::Backend>& self, |
| at::Tensor& output, |
| std::vector<at::Tensor>& input, |
| ::c10d::ReduceOp op) { |
| std::vector<at::Tensor> outputs = {output}; |
| std::vector<std::vector<at::Tensor>> inputs = {input}; |
| ::c10d::ReduceScatterOptions opts; |
| opts.reduceOp = op; |
| return self->reduce_scatter(outputs, inputs, opts); |
| }, |
| py::arg("output_tensors"), |
| py::arg("input_tensor"), |
| py::arg("op") = ::c10d::ReduceOp::SUM, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_reduce_scatter_base", |
| &::c10d::Backend::_reduce_scatter_base, |
| py::arg("outputTensor"), |
| py::arg("inputTensor"), |
| py::arg("opts") = ::c10d::ReduceScatterOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "alltoall_base", |
| &::c10d::Backend::alltoall_base, |
| py::arg("output_tensor"), |
| py::arg("input_tensor"), |
| py::arg("output_split_sizes"), |
| py::arg("input_split_sizes"), |
| py::arg("opts") = ::c10d::AllToAllOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "alltoall_base", |
| [](::c10d::Backend& self, |
| at::Tensor& output, |
| at::Tensor& input, |
| std::vector<int64_t> outputSplitSizes, |
| std::vector<int64_t> inputSplitSizes) { |
| return self.alltoall_base( |
| output, |
| input, |
| outputSplitSizes, |
| inputSplitSizes, |
| ::c10d::AllToAllOptions()); |
| }, |
| py::arg("output"), |
| py::arg("input"), |
| py::arg("output_split_sizes"), |
| py::arg("input_split_sizes"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "alltoall", |
| &::c10d::Backend::alltoall, |
| py::arg("output_tensor"), |
| py::arg("input_tensor"), |
| py::arg("opts") = ::c10d::AllToAllOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "send", |
| &::c10d::Backend::send, |
| py::arg("tensors"), |
| py::arg("dstRank"), |
| py::arg("tag"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "recv", |
| &::c10d::Backend::recv, |
| py::arg("tensors"), |
| py::arg("srcRank"), |
| py::arg("tag"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "recv_anysource", |
| &::c10d::Backend::recvAnysource, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "barrier", |
| [](const c10::intrusive_ptr<::c10d::Backend>& self, |
| const ::c10d::BarrierOptions& opts) { |
| return self->barrier(opts); |
| }, |
| py::arg("opts") = ::c10d::BarrierOptions(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_set_sequence_number_for_group", |
| &::c10d::Backend::setSequenceNumberForGroup, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_get_sequence_number_for_group", |
| &::c10d::Backend::getSequenceNumberForGroup, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "monitored_barrier", |
| [](const c10::intrusive_ptr<::c10d::Backend>& self, |
| const std::chrono::milliseconds& timeout, |
| bool waitAllRanks) { |
| ::c10d::BarrierOptions opts; |
| opts.timeout = timeout; |
| return self->monitoredBarrier(opts, waitAllRanks); |
| }, |
| py::arg("timeout") = ::c10d::kUnsetTimeout, |
| py::arg("wait_all_ranks") = false, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_get_backend_name", |
| &::c10d::Backend::getBackendName, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_start_coalescing", |
| &::c10d::Backend::startCoalescing, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_end_coalescing", |
| &::c10d::Backend::endCoalescing, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| #ifdef USE_C10D_GLOO |
| static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; |
| |
| auto processGroupGloo = |
| intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupGloo>( |
| module, "ProcessGroupGloo", backend); |
| |
| shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device"); |
| |
| intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>( |
| processGroupGloo, "_Options", processGroupOptions) |
| .def(py::init<>()) |
| .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices) |
| .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads); |
| |
| processGroupGloo |
| .def_static( |
| "create_device", |
| [](const std::string& hostname, const std::string& interface) |
| -> std::shared_ptr<::gloo::transport::Device> { |
| if (!hostname.empty()) { |
| return ::c10d::ProcessGroupGloo::createDeviceForHostname( |
| hostname); |
| } |
| if (!interface.empty()) { |
| return ::c10d::ProcessGroupGloo::createDeviceForInterface( |
| interface); |
| } |
| throw std::invalid_argument( |
| "Specify either `hostname` or `interface` argument."); |
| }, |
| py::arg("hostname") = "", |
| py::arg("interface") = "") |
| .def_static( |
| "create_default_device", |
| &::c10d::ProcessGroupGloo::createDefaultDevice); |
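| |
| // Illustrative Python usage of the static helpers above (`hostname` takes |
| // precedence; at least one of the two arguments must be non-empty): |
| //   >>> dev = dist.ProcessGroupGloo.create_device(interface="eth0") |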
| |
| processGroupGloo |
| .def( |
| py::init< |
| const c10::intrusive_ptr<::c10d::Store>&, |
| int, |
| int, |
| c10::intrusive_ptr<::c10d::ProcessGroupGloo::Options>>(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| py::init([](const c10::intrusive_ptr<::c10d::Store>& store, |
| int rank, |
| int size, |
| std::chrono::milliseconds timeout) { |
| auto options = ::c10d::ProcessGroupGloo::Options::create(); |
| |
| // Use interfaces listed in "GLOO_SOCKET_IFNAME", if set. |
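| // The variable is expected to hold a comma-separated list of interface |
| // names, e.g. GLOO_SOCKET_IFNAME=eth0,eth1; one Gloo device is created |
| // per listed interface. |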
| char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV.c_str()); |
| if (ifnameEnv && strlen(ifnameEnv) > 1) { |
| for (const auto& iface : ::c10d::split(',', ifnameEnv)) { |
| options->devices.push_back( |
| ::c10d::ProcessGroupGloo::createDeviceForInterface(iface)); |
| } |
| } else { |
| // Otherwise, fall back to the default device: createDefaultDevice looks |
| // up the machine's hostname and returns a device instance associated |
| // with the address that the hostname resolves to. |
| options->devices.push_back( |
| ::c10d::ProcessGroupGloo::createDefaultDevice()); |
| } |
| |
| options->timeout = timeout; |
| // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) |
| options->threads = options->devices.size() * 2; |
| return c10::make_intrusive<::c10d::ProcessGroupGloo>( |
| store, rank, size, options); |
| }), |
| py::arg("store"), |
| py::arg("rank"), |
| py::arg("size"), |
| py::arg("timeout") = kProcessGroupDefaultTimeout, |
| py::call_guard<py::gil_scoped_release>()) |
| .def_property_readonly("options", &::c10d::ProcessGroupGloo::getOptions); |
| |
| // ProcessGroupWrapper is a wrapper process group that includes a helper gloo |
| // process group. It can be used to validate collective calls across processes |
| // by checking the op type and input tensor shapes. |
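| // It is typically not instantiated by users directly; for example, it is |
| // created automatically when TORCH_DISTRIBUTED_DEBUG=DETAIL is set while a |
| // process group is being initialized. |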
| auto processGroupWrapper = |
| intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupWrapper>( |
| module, "_ProcessGroupWrapper", backend) |
| .def( |
| py::init( |
| [](const c10::intrusive_ptr<::c10d::Backend>& backend, |
| const c10::intrusive_ptr<::c10d::Backend>& gloo_backend) { |
| return c10::make_intrusive<::c10d::ProcessGroupWrapper>( |
| backend, gloo_backend); |
| }), |
| py::arg("backend"), |
| py::arg("gloo_backend"), |
| py::call_guard<py::gil_scoped_release>()) |
| .def_property_readonly( |
| "wrapped_pg", &::c10d::ProcessGroupWrapper::getWrappedPg); |
| #endif |
| |
| #ifdef USE_C10D_NCCL |
| auto processGroupNCCL = |
| intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupNCCL>( |
| module, "ProcessGroupNCCL", backend) |
| .def( |
| py::init< |
| const c10::intrusive_ptr<::c10d::Store>&, |
| int, |
| int, |
| c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options>>(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| py::init([](const c10::intrusive_ptr<::c10d::Store>& store, |
| int rank, |
| int size, |
| const std::chrono::milliseconds& timeout) { |
| auto options = ::c10d::ProcessGroupNCCL::Options::create(); |
| options->is_high_priority_stream = false; |
| options->timeout = timeout; |
| return c10::make_intrusive<::c10d::ProcessGroupNCCL>( |
| store, rank, size, options); |
| }), |
| py::arg("store"), |
| py::arg("rank"), |
| py::arg("size"), |
| py::arg("timeout") = kProcessGroupDefaultTimeout, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "_abort", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self, |
| const c10::optional<std::string>& abortReason) { |
| return self->abort(abortReason); |
| }, |
| py::arg("abort_reason") = py::none(), |
| py::call_guard<py::gil_scoped_release>()) |
| .def("_group_start", &::c10d::ProcessGroupNCCL::groupStart) |
| .def("_group_end", &::c10d::ProcessGroupNCCL::groupEnd) |
| .def_property_readonly( |
| "options", &::c10d::ProcessGroupNCCL::getOptions) |
| .def_property_readonly( |
| "is_ucc_available", &::c10d::ProcessGroupNCCL::isUCCAvailable); |
| |
| #ifdef NCCL_HAS_COMM_CTA_CGA |
| py::class_<ncclConfig_t>( |
| processGroupNCCL, |
| "NCCLConfig", |
| R"( |
| ncclConfig_t data type for configuring NCCL communicators. |
| See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t |
| for details. |
| )") |
| .def(py::init<>()) |
| .def_readwrite("blocking", &ncclConfig_t::blocking) |
| .def_readwrite("cga_cluster_size", &ncclConfig_t::cgaClusterSize) |
| .def_readwrite("min_ctas", &ncclConfig_t::minCTAs) |
| .def_readwrite("max_ctas", &ncclConfig_t::maxCTAs) |
| .def_property( |
| "net_name", |
| [](const ncclConfig_t& self) { return self.netName; }, |
| // Note: NCCL calls free() on the netName pointer when destroying |
| // the communicator, so the memory allocated by strdup() here |
| // should not leak. |
| [](ncclConfig_t& self, const char* tmp) { |
| self.netName = strdup(tmp); |
| }); |
| #endif |
| |
| intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>( |
| processGroupNCCL, |
| "Options", |
| processGroupOptions, |
| R"( |
| ProcessGroup options for the NCCL backend |
| |
| Arguments: |
| is_high_priority_stream (bool, optional): flag to enable/disable the process |
| group picking up high priority CUDA streams. It lets the CUDA driver |
| prioritize NCCL kernels when there are compute kernels waiting. |
| Default is False. |
| |
| Attributes: |
| config (NCCLConfig): configures NCCL communicators (only available for |
| builds using NCCL 2.17+). This can be used to improve |
| communication-computation overlap for NCCL kernels by tuning |
| available parameters in the config. See |
| https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t |
| for details. |
| |
| Example:: |
| >>> import torch.distributed as dist |
| >>> |
| >>> nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True) |
| >>> # For builds using NCCL 2.17+, configure communicators |
| >>> nccl_options.config.cga_cluster_size = 2 |
| >>> nccl_options.config.max_ctas = 4 |
| >>> nccl_options.config.min_ctas = 2 |
| >>> # initialize an NCCL process group with the options just created |
| >>> dist.init_process_group("nccl", pg_options=nccl_options) |
| )") |
| .def(py::init<bool>(), py::arg("is_high_priority_stream") = false) |
| #ifdef NCCL_HAS_COMM_CTA_CGA |
| .def_readwrite( |
| "is_high_priority_stream", |
| &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream) |
| .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config); |
| #else |
| .def_readwrite( |
| "is_high_priority_stream", |
| &::c10d::ProcessGroupNCCL::Options::is_high_priority_stream); |
| #endif |
| |
| #endif |
| |
| #ifdef USE_C10D_MPI |
| auto processGroupMPI = |
| intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupMPI>( |
| module, "ProcessGroupMPI", backend); |
| |
| // Define a static create function instead of a constructor, because this |
| // function may return null. This happens if this process is not part of |
| // the sub group that is to be created. |
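| // Illustrative Python usage (the call returns None on ranks that are not |
| // part of the given group): |
| //   >>> pg = dist.ProcessGroupMPI.create([0, 1, 2, 3]) |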
| processGroupMPI.def_static( |
| "create", |
| [](std::vector<int> ranks) { |
| return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks); |
| }, |
| py::call_guard<py::gil_scoped_release>()); |
| #endif |
| |
| #ifdef USE_C10D_UCC |
| auto processGroupUCC = |
| intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupUCC>( |
| module, "ProcessGroupUCC", backend) |
| .def( |
| py::init([](const c10::intrusive_ptr<::c10d::Store>& store, |
| int rank, |
| int size, |
| const std::chrono::milliseconds& timeout) { |
| return c10::make_intrusive<::c10d::ProcessGroupUCC>( |
| store, rank, size, timeout); |
| }), |
| py::arg("store"), |
| py::arg("rank"), |
| py::arg("size"), |
| py::arg("timeout") = kProcessGroupDefaultTimeout, |
| py::call_guard<py::gil_scoped_release>()); |
| #endif |
| |
| py::class_< |
| ::c10d::Work, |
| c10::intrusive_ptr<::c10d::Work>, |
| ::c10d::PyProcessGroup::PyWork>(module, "Work") |
| .def(py::init<>()) |
| .def("is_completed", &::c10d::Work::isCompleted) |
| .def( |
| "is_success", |
| [](::c10d::Work& work) -> bool { |
| TORCH_WARN_ONCE( |
| fmt::format(kDeprecationWarning, "Work::is_success")); |
| return work.isSuccess(); |
| }) |
| .def( |
| "exception", |
| [](::c10d::Work& work) -> std::exception_ptr { |
| TORCH_WARN_ONCE( |
| fmt::format(kDeprecationWarning, "Work::exception")); |
| return work.exception(); |
| }) |
| .def( |
| "source_rank", |
| [](::c10d::Work& work) -> int { |
| TORCH_WARN_ONCE( |
| fmt::format(kDeprecationWarning, "Work::source_rank")); |
| return work.sourceRank(); |
| }) |
| .def("_source_rank", &::c10d::Work::sourceRank) |
| .def( |
| "result", |
| [](::c10d::Work& work) -> std::vector<at::Tensor> { |
| return work.result(); |
| }) |
| .def( |
| "synchronize", |
| [](::c10d::Work& work) -> void { |
| TORCH_WARN_ONCE( |
| fmt::format(kDeprecationWarning, "Work::synchronize")); |
| work.synchronize(); |
| }) |
| .def( |
| "wait", |
| &::c10d::Work::wait, |
| py::arg("timeout") = kNoTimeout, |
| py::call_guard<py::gil_scoped_release>()) |
| .def( |
| "get_future", |
| [](::c10d::Work& work) -> std::shared_ptr<jit::PythonFutureWrapper> { |
| return std::make_shared<jit::PythonFutureWrapper>(work.getFuture()); |
| }, |
| R"( |
| Returns: |
| A ``torch.futures.Future`` object which is associated with the completion of |
| the ``Work``. As an example, a future object can be retrieved |
| by ``fut = process_group.allreduce(tensors).get_future()``. |
| |
| Example:: |
| Below is an example of a simple allreduce DDP communication hook that uses the |
| ``get_future`` API to retrieve a Future associated with the completion of |
| ``allreduce``. |
| |
| >>> def allreduce(process_group: dist.ProcessGroup, bucket: dist.GradBucket) -> torch.futures.Future: |
| >>> group_to_use = process_group if process_group is not None else torch.distributed.group.WORLD |
| >>> tensor = bucket.buffer().div_(group_to_use.size()) |
| >>> return torch.distributed.all_reduce(tensor, group=group_to_use, async_op=True).get_future() |
| >>> ddp_model.register_comm_hook(state=None, hook=allreduce) |
| |
| .. warning :: |
| The ``get_future`` API supports NCCL, and partially supports GLOO and MPI backends |
| (no support for peer-to-peer operations like send/recv), and will return a ``torch.futures.Future``. |
| |
| In the example above, the ``allreduce`` work will be done on the GPU using the NCCL |
| backend, and ``fut.wait()`` will return after synchronizing the appropriate NCCL streams |
| with PyTorch's current device streams. This preserves asynchronous CUDA execution and |
| does not wait for the entire operation to complete on the GPU. Note that |
| ``CUDAFuture`` does not support the ``NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``. |
| In addition, if a callback function was added by ``fut.then()``, it will wait until |
| ``WorkNCCL``'s NCCL streams synchronize with ``ProcessGroupNCCL``'s dedicated callback |
| stream, and then invoke the callback on that callback stream. |
| ``fut.then()`` will return another ``CUDAFuture`` that holds the return value of the |
| callback and a ``CUDAEvent`` that recorded the callback stream. |
| |
| 1. For CPU work, ``fut.done()`` returns true when the work has been completed and the |
| ``value()`` tensors are ready. |
| 2. For GPU work, ``fut.done()`` returns true only when the operation has been enqueued. |
| 3. For mixed CPU-GPU work (e.g. sending GPU tensors with GLOO), ``fut.done()`` returns |
| true when the tensors have arrived on the respective nodes, but not yet necessarily |
| synchronized on the respective GPUs (similarly to GPU work). |
| )"); |
| |
| py::class_<c10::DDPLoggingData>(module, "DDPLoggingData") |
| .def(py::init<>()) |
| .def_readwrite("strs_map", &c10::DDPLoggingData::strs_map) |
| .def_readwrite("ints_map", &c10::DDPLoggingData::ints_map); |
| |
| module.def( |
| "_compute_bucket_assignment_by_size", |
| [](const std::vector<at::Tensor>& tensors, |
| const std::vector<size_t>& bucket_size_limits, |
| const std::vector<bool>& expect_sparse_gradient, |
| const std::vector<int64_t>& tensor_indices, |
| const c10::optional<std::shared_ptr<::c10d::Logger>>& logger) { |
| if (logger.has_value()) { |
| std::weak_ptr<::c10d::Logger> logger_weakref = logger.value(); |
| return ::c10d::compute_bucket_assignment_by_size( |
| tensors, |
| bucket_size_limits, |
| expect_sparse_gradient, |
| tensor_indices, |
| {logger_weakref}); |
| } else { |
| return ::c10d::compute_bucket_assignment_by_size( |
| tensors, |
| bucket_size_limits, |
| expect_sparse_gradient, |
| tensor_indices, |
| {}); |
| } |
| }, |
| py::arg("tensors"), |
| py::arg("bucket_size"), |
| py::arg("expect_sparse_gradient") = std::vector<bool>(), |
| py::arg("tensor_indices") = std::vector<int64_t>(), |
| py::arg("logger") = c10::optional<std::shared_ptr<::c10d::Logger>>{}, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| module.def( |
| "_verify_params_across_processes", |
| [](const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group, |
| const std::vector<at::Tensor>& params, |
| const c10::optional<std::shared_ptr<::c10d::Logger>>& logger) { |
| if (logger.has_value()) { |
| std::weak_ptr<::c10d::Logger> logger_weakref = logger.value(); |
| verify_params_across_processes( |
| process_group, params, {logger_weakref}); |
| } else { |
| verify_params_across_processes(process_group, params, {}); |
| } |
| }, |
| py::arg("process_group"), |
| py::arg("params"), |
| py::arg("logger") = c10::optional<std::shared_ptr<::c10d::Logger>>{}, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| module.def( |
| "_broadcast_coalesced", |
| // Define a lambda such that the pybind11 prototype can take a std::vector |
| // for the tensor list argument, but still pass it to the underlying |
| // function as a c10::ArrayRef. |
| [](c10::intrusive_ptr<::c10d::ProcessGroup> process_group, |
| std::vector<at::Tensor> tensors, // NOLINT |
| size_t buffer_size, |
| int rank) { |
| broadcast_coalesced( |
| std::move(process_group), tensors, buffer_size, rank); |
| }, |
| py::arg("process_group"), |
| py::arg("tensors"), |
| py::arg("buffer_size"), |
| // The rank that serves as the source of truth to broadcast the tensors from. |
| py::arg("src") = 0, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| module.def( |
| "_test_python_store", |
| // Define a function that takes a c10d store and runs a few tests. |
| // This is used by the PythonStore tests, which we cannot exercise from the |
| // Python side alone: calling Python functions on a Python object |
| // completely bypasses pybind11. We need to test that the overridden |
| // functions call into Python and behave like we expect. |
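| // For illustration, the Python side passes a dist.Store subclass along these |
| // lines (hypothetical sketch; module path shown for illustration): |
| //   class MyPythonStore(dist.Store): |
| //       def set(self, key, value): ... |
| //       def get(self, key): ... |
| //       def add(self, key, amount): ... |
| //   torch._C._distributed_c10d._test_python_store(MyPythonStore()) |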
| [](c10::intrusive_ptr<::c10d::Store> store) { |
| auto add = [&store](const std::string& key, int64_t value) { |
| store->add(key, value); |
| }; |
| |
| auto set = [&store](const std::string& key, const std::string& value) { |
| store->set(key, value); |
| }; |
| |
| auto get = [&store](const std::string& key) { |
| auto value = store->get(key); |
| return std::string(value.begin(), value.end()); |
| }; |
| |
| add("key", 1); |
| add("key", 2); |
| add("key", 3); |
| set("key0", "value0"); |
| add("key3", 1); |
| set("key1", "value1"); |
| add("key3", 2); |
| set("key2", "value2"); |
| add("key3", 3); |
| add("key3", 4); |
| add("key3", 3); |
| add("key3", 2); |
| if (get("key") != "6") { |
| TORCH_CHECK(false, "assertion failed"); |
| } |
| if (get("key0") != "value0") { |
| TORCH_CHECK(false, "assertion failed"); |
| } |
| if (get("key1") != "value1") { |
| TORCH_CHECK(false, "assertion failed"); |
| } |
| if (get("key2") != "value2") { |
| TORCH_CHECK(false, "assertion failed"); |
| } |
| if (get("key3") != "15") { |
| TORCH_CHECK(false, "assertion failed"); |
| } |
| }, |
| py::call_guard<py::gil_scoped_release>()); |
| |
| module.attr("_DEFAULT_FIRST_BUCKET_BYTES") = ::c10d::kDefaultFirstBucketBytes; |
| module.attr("_DEFAULT_PG_TIMEOUT") = py::cast(kProcessGroupDefaultTimeout); |
| module.attr("_DEFAULT_NO_TIMEOUT") = py::cast(kNoTimeout); |
| |
| module.def( |
| "_create_work_from_future", |
| [](std::shared_ptr<jit::PythonFutureWrapper> future) { |
| return ::c10d::Work::create_from_future(future->fut); |
| }, |
| py::arg("future"), |
| R"( |
| Arguments: |
| future (torch.futures.Future): The future to wrap. |
| Returns: |
| A ``Work`` object which is associated with the completion of |
| the ``torch.futures.Future``. |
| This is the preferred way of constructing Work objects when writing a custom |
| ProcessGroup in Python. |
| Example:: |
| >>> class SingleRankProcessGroup(torch.distributed.ProcessGroup): |
| >>> def broadcast(self, tensor_list, opts): |
| >>> fut = torch.futures.Future() |
| >>> fut.set_result(tensor_list) |
| >>> return torch._C._distributed_c10d._create_work_from_future(fut) |
| .. warning :: |
| This API is experimental and subject to change. |
| The returned Work object has multiple limitations: |
| - synchronize() does nothing. Use ``torch.futures.Future`` based synchronization. |
| - wait() ignores the timeout argument. |
| - sourceRank() raises. |
| - abort() raises. |
| The provided Future object result must be a Tensor or a list of Tensors. |
| )"); |
| |
| Py_RETURN_TRUE; |
| } |
| |
| #undef PROCESS_GROUP_DEPRECATION_WARNING |
| |
| } // namespace |
| |
| // c10d methods on torch._C |
| static PyMethodDef methods[] = { // NOLINT |
| {"_c10d_init", c10d_init, METH_NOARGS, nullptr}, |
| {nullptr, nullptr, 0, nullptr}}; |
| |
| PyMethodDef* python_functions() { |
| return methods; |
| } |
| |
| } // namespace c10d |
| } // namespace distributed |
| } // namespace torch |