| #include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp> |
| |
| #ifdef USE_C10D_MPI |
| |
| #include <array> |
| #include <cstring> |
| #include <iostream> |
| #include <limits> |
| #include <map> |
| |
| #include <c10/core/DeviceGuard.h> |
| #include <c10/util/irange.h> |
| |
| #if defined(OPEN_MPI) && OPEN_MPI |
| #include <mpi-ext.h> // Needed for CUDA-aware check |
| #endif |
| |
| namespace c10d { |
| |
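| // Wraps an MPI call and raises a fatal TORCH_CHECK error (with file and line |
| // information) if the call does not return MPI_SUCCESS, e.g. |
| // MPI_CHECK(MPI_Barrier(comm)). |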
| #define MPI_CHECK(cmd) \ |
| do { \ |
| int mpiStatus = cmd; \ |
| if (mpiStatus != MPI_SUCCESS) { \ |
| std::string err = "MPI error in: " + std::string(__FILE__) + ":" + \ |
| std::to_string(__LINE__) + \ |
| ", with error code: " + std::to_string(mpiStatus); \ |
| TORCH_CHECK(false, err); \ |
| } \ |
| } while (0) |
| |
| namespace { |
| |
| // Op mapping |
| std::map<ReduceOp::RedOpType, MPI_Op> mpiOp = { |
| {ReduceOp::MIN, MPI_MIN}, |
| {ReduceOp::MAX, MPI_MAX}, |
| {ReduceOp::SUM, MPI_SUM}, |
| {ReduceOp::PRODUCT, MPI_PROD}, |
| }; |
| // Type mapping |
| std::map<at::ScalarType, MPI_Datatype> mpiDatatype = { |
| {at::kByte, MPI_UNSIGNED_CHAR}, |
| {at::kChar, MPI_CHAR}, |
| {at::kDouble, MPI_DOUBLE}, |
| {at::kFloat, MPI_FLOAT}, |
| {at::kInt, MPI_INT}, |
| {at::kLong, MPI_LONG}, |
| {at::kShort, MPI_SHORT}, |
| }; |
| |
| // Check for CUDA-aware MPI support. Currently, CUDA-aware MPI ops are only |
| // supported through Open MPI. |
| bool cudaAwareMpiCheck() { |
| // Run time check |
| #if defined(MPIX_CUDA_AWARE_SUPPORT) |
| if (MPIX_Query_cuda_support() == 1) { |
| return true; |
| } else { |
| return false; |
| } |
| #else // !defined(MPIX_CUDA_AWARE_SUPPORT) |
| return false; |
| #endif // MPIX_CUDA_AWARE_SUPPORT |
| } |
| |
| // Check the validity of a single input tensor: it must be contiguous and |
| // dense, and CUDA tensors require a CUDA-aware MPI implementation. |
| void checkSingleTensorHelper(const at::Tensor& tensor) { |
| if (!tensor.is_contiguous()) { |
| TORCH_CHECK(false, "input tensor has to be contiguous"); |
| } |
| if (tensor.is_sparse()) { |
| TORCH_CHECK(false, "input tensor has to be dense"); |
| } |
| if (tensor.is_cuda() && !cudaAwareMpiCheck()) { |
|     TORCH_CHECK( |
|         false, |
|         "CUDA tensor detected, but the MPI implementation in use " |
|         "does not have CUDA-aware support"); |
| } |
| } |
| |
| void checkSingleTensor(const std::vector<at::Tensor>& tensors) { |
| if (tensors.size() != 1) { |
| TORCH_CHECK( |
| false, "MPI process group does not support multi-GPU collectives"); |
| } |
| checkSingleTensorHelper(tensors[0]); |
| } |
| |
| void checkSameSizeAndType( |
| const at::Tensor& t_in, |
| const std::vector<at::Tensor>& tensors) { |
| for (const auto& tensor : tensors) { |
| if ((tensor.numel() != t_in.numel()) || |
| (tensor.scalar_type() != t_in.scalar_type())) { |
| TORCH_CHECK(false, "Tensors are not equal in size or data type"); |
| } |
| checkSingleTensorHelper(tensor); |
| } |
| } |
| |
| } // namespace |
| |
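| // Returns the output tensors associated with this work. |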
| std::vector<at::Tensor> ProcessGroupMPI::WorkMPI::result() { |
| return outputTensors_; |
| } |
| |
| c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupMPI::WorkMPI::getFuture() { |
| return future_; |
| } |
| |
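| // Called by the worker thread when the MPI call throws: fails both the future |
| // and the work with the captured exception. |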
| void ProcessGroupMPI::WorkMPI::finishWorkMPIError(std::exception_ptr eptr) { |
| future_->setError(eptr); |
| finish(eptr); |
| } |
| |
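| // Called by the worker thread on success: fulfills the future with the output |
| // tensors and marks the work as completed. |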
| void ProcessGroupMPI::WorkMPI::finishWorkMPI() { |
| future_->markCompleted(at::IValue(outputTensors_)); |
| finish(); |
| } |
| |
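| // AsyncWork tracks a non-blocking MPI request (e.g. from MPI_Isend or |
| // MPI_Irecv); completion is polled with MPI_Test and awaited with MPI_Wait. |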
| ProcessGroupMPI::AsyncWork::AsyncWork( |
| MPI_Request request, |
| std::vector<at::Tensor> outputTensors, |
| const char* profilingTitle, |
| const c10::optional<std::vector<at::Tensor>>& inputTensors) |
| : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors), |
| outputTensors_(std::move(outputTensors)), |
| request_(request) { |
| memset(&status_, 0, sizeof(status_)); |
| } |
| |
| ProcessGroupMPI::AsyncWork::~AsyncWork() { |
| if (request_ != MPI_REQUEST_NULL) { |
| std::cerr |
| << "Attempted destruction of AsyncWork before work has completed, " |
| << "terminating the program." << std::endl; |
| std::terminate(); |
| } |
| } |
| |
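| // Non-blocking completion check: MPI_Test resets request_ to MPI_REQUEST_NULL |
| // once the underlying operation has finished. |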
| bool ProcessGroupMPI::AsyncWork::isCompleted() { |
| if (request_ == MPI_REQUEST_NULL) { |
| return true; |
| } |
| |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| int flag = 0; |
| MPI_CHECK(MPI_Test(&request_, &flag, &status_)); |
| if (request_ != MPI_REQUEST_NULL) { |
| return false; |
| } |
| |
| // request_ == MPI_REQUEST_NULL; the work has completed |
| // Populate exception if request was not successful |
| if (status_.MPI_ERROR != MPI_SUCCESS) { |
| populateException(); |
| } |
| |
| return true; |
| } |
| |
| bool ProcessGroupMPI::AsyncWork::isSuccess() const { |
| if (request_ != MPI_REQUEST_NULL) { |
| TORCH_CHECK( |
| false, |
| "Invalid call to AsyncWork::isSuccess before work has completed"); |
| } |
| |
| return status_.MPI_ERROR == MPI_SUCCESS; |
| } |
| |
| int ProcessGroupMPI::AsyncWork::sourceRank() const { |
| return status_.MPI_SOURCE; |
| } |
| |
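| // Blocks on the outstanding request with MPI_Wait; the timeout argument is |
| // ignored. |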
| bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) { |
| if (request_ == MPI_REQUEST_NULL) { |
| // AsyncWork needs to manually call profiling end callbacks if they are set, |
| // since it does not call ProcessGroup::finish(). |
| if (Work::recordFunctionEndCallback_) { |
| Work::recordFunctionEndCallback_(); |
| Work::recordFunctionEndCallback_ = nullptr; |
| } |
| return true; |
| } |
| |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Wait(&request_, &status_)); |
| auto ok = (status_.MPI_ERROR == MPI_SUCCESS); |
| |
| // AsyncWork needs to manually call profiling end callbacks if they are set, |
| // since it does not call ProcessGroup::finish(). |
| if (Work::recordFunctionEndCallback_) { |
| Work::recordFunctionEndCallback_(); |
| Work::recordFunctionEndCallback_ = nullptr; |
| } |
| |
| if (!ok) { |
| populateException(); |
| std::rethrow_exception(exception_); |
| } |
| // Always return true, because abort API is not implemented. |
| return true; |
| } |
| |
| void ProcessGroupMPI::AsyncWork::abort() { |
|   TORCH_CHECK(false, "ProcessGroupMPI::AsyncWork::abort not implemented."); |
| } |
| |
| std::vector<at::Tensor> ProcessGroupMPI::AsyncWork::result() { |
| return outputTensors_; |
| } |
| |
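| // Converts the MPI error code recorded in status_ into an exception_ptr |
| // holding a std::runtime_error with the MPI error string. |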
| void ProcessGroupMPI::AsyncWork::populateException() { |
| std::array<char, MPI_MAX_ERROR_STRING> buf; |
|   int len = static_cast<int>(buf.size()); |
| MPI_CHECK(MPI_Error_string(status_.MPI_ERROR, buf.data(), &len)); |
| exception_ = |
| std::make_exception_ptr(std::runtime_error(std::string(buf.data(), len))); |
| } |
| |
| // Static global states |
| int ProcessGroupMPI::mpiThreadSupport_ = 0; |
| std::mutex ProcessGroupMPI::pgGlobalMutex_; |
| // We only want to initialize once |
| c10::once_flag ProcessGroupMPI::onceFlagInitMPI; |
| |
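| // Registered with std::atexit in initMPIOnce(); finalizes the MPI environment |
| // at program exit. |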
| void ProcessGroupMPI::mpiExit() { |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Finalize()); |
| } |
| |
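| // Initializes MPI exactly once per process (requesting MPI_THREAD_SERIALIZED) |
| // and registers mpiExit() with std::atexit; if MPI was already initialized |
| // elsewhere, only a warning is emitted. |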
| void ProcessGroupMPI::initMPIOnce() { |
| // Initialize MPI environment |
| c10::call_once(onceFlagInitMPI, []() { |
| int mpi_was_initialized = 0; |
| MPI_CHECK(MPI_Initialized(&mpi_was_initialized)); |
| if (mpi_was_initialized == 0) { |
| MPI_CHECK(MPI_Init_thread( |
| nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpiThreadSupport_)); |
| if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) { |
|         TORCH_CHECK( |
|             false, |
|             "The MPI implementation in use does not provide the " |
|             "minimum level of threading support required by the " |
|             "c10d package: MPI_THREAD_SERIALIZED"); |
| } |
| if (std::atexit(ProcessGroupMPI::mpiExit)) { |
|         TORCH_CHECK(false, "Failed to register the MPI exit handler"); |
| } |
| } else { |
| TORCH_WARN_ONCE("MPI was previously initialized."); |
| } |
| }); |
| } |
| |
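| // Creates a ProcessGroupMPI over MPI_COMM_WORLD, or over a sub-communicator |
| // containing only the given world ranks. Ranks that are not part of the |
| // group receive a null intrusive_ptr. |
| // |
| // A minimal usage sketch (assuming the process was launched under mpirun and |
| // that a std::vector<at::Tensor> named `tensors` holding one tensor is in |
| // scope): |
| // |
| //   auto pg = ProcessGroupMPI::createProcessGroupMPI({}); |
| //   if (pg) { |
| //     pg->allreduce(tensors)->wait(); |
| //   } |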
| c10::intrusive_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI( |
| std::vector<int> ranks) { |
|   // One-time MPI initialization |
| initMPIOnce(); |
| |
| MPI_Comm groupComm = MPI_COMM_WORLD; |
| int rank = -1; |
| int size = -1; |
| |
| { |
| std::lock_guard<std::mutex> globalLock(pgGlobalMutex_); |
| |
|     // If no ranks are specified, use MPI_COMM_WORLD; otherwise create a |
|     // communicator containing only the given ranks. |
| if (!ranks.empty()) { |
| MPI_Group worldGroup; |
| MPI_Group ranksGroup; |
| MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); |
| MPI_CHECK( |
| MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup)); |
| // `MPI_Comm_create` can be flaky in certain cases. |
| // See: https://github.com/pytorch/pytorch/issues/53899 |
| constexpr int kMaxNumRetries = 3; |
| bool groupComm_updated = false; |
| MPI_Barrier(MPI_COMM_WORLD); |
| for (const auto i : c10::irange(kMaxNumRetries)) { |
| (void)i; |
|         if (MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm) == |
|             MPI_SUCCESS) { |
| groupComm_updated = true; |
| break; |
| } |
| } |
|       TORCH_CHECK( |
|           groupComm_updated, |
|           "MPI_Comm_create failed to create the group communicator after ", |
|           kMaxNumRetries, |
|           " retries"); |
| MPI_CHECK(MPI_Group_free(&worldGroup)); |
| MPI_CHECK(MPI_Group_free(&ranksGroup)); |
| } |
| |
| // Fetch rank and world size for this group (MPI_COMM_WORLD or new) |
| if (groupComm != MPI_COMM_NULL) { |
| MPI_CHECK(MPI_Comm_rank(groupComm, &rank)); |
| MPI_CHECK(MPI_Comm_size(groupComm, &size)); |
| |
| if (rank < 0 || size < 0) { |
| TORCH_CHECK(false, "Failed to get the world_size / rank"); |
| } |
| } |
| } |
| |
| // If this process is not part of the group, we don't construct a |
| // process group instance. This is in line with the semantics of the |
| // other process group types. |
| if (groupComm == MPI_COMM_NULL) { |
| return c10::intrusive_ptr<ProcessGroupMPI>(); |
| } |
| |
| return c10::make_intrusive<ProcessGroupMPI>(rank, size, groupComm); |
| } |
| |
| ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm) |
| : Backend(rank, size), stop_(false), pgComm_(pgComm) { |
| if (pgComm_ == MPI_COMM_NULL) { |
| TORCH_CHECK(false, "pgComm_ must not be MPI_COMM_NULL"); |
| } |
| |
| // Start the worker thread accepting MPI calls |
| workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this); |
| |
| init(); |
| } |
| |
| ProcessGroupMPI::~ProcessGroupMPI() { |
| destroy(); |
| } |
| |
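| // Waits for the work queue to drain, then signals the worker thread to stop |
| // and joins it. |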
| void ProcessGroupMPI::destroy() { |
| std::unique_lock<std::mutex> lock(pgMutex_); |
| queueConsumeCV_.wait(lock, [&] { return queue_.empty(); }); |
| |
| // Queue is empty, signal stop |
| stop_ = true; |
| |
| // Release lock to allow threads to terminate |
| lock.unlock(); |
| queueProduceCV_.notify_all(); |
| |
| // Join the single worker thread |
| workerThread_.join(); |
| } |
| |
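| // Drains the work queue, then aborts all processes attached to this |
| // communicator via MPI_Abort. |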
| void ProcessGroupMPI::abort() { |
| destroy(); |
| MPI_Abort(pgComm_, EXIT_FAILURE); |
| } |
| |
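| // Worker thread main loop: pops WorkEntry/WorkMPI pairs off the queue, runs |
| // the entry's MPI call, and completes (or fails) the associated work object. |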
| void ProcessGroupMPI::runLoop() { |
| std::unique_lock<std::mutex> lock(pgMutex_); |
| |
| while (!stop_) { |
| if (queue_.empty()) { |
| queueProduceCV_.wait(lock); |
| continue; |
| } |
| |
| auto workTuple = std::move(queue_.front()); |
| |
| queue_.pop_front(); |
| |
| auto& workEntry = std::get<0>(workTuple); |
| auto& work = std::get<1>(workTuple); |
| |
| lock.unlock(); |
| queueConsumeCV_.notify_one(); |
| |
| try { |
| workEntry->run(workEntry); |
| work->finishWorkMPI(); |
| } catch (...) { |
| work->finishWorkMPIError(std::current_exception()); |
| } |
| |
| lock.lock(); |
| } |
| } |
| |
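| // Wraps the entry in a WorkMPI, appends both to the queue, and wakes the |
| // worker thread. Returns the work object so callers can wait on it. |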
| c10::intrusive_ptr<Work> ProcessGroupMPI::enqueue( |
| std::unique_ptr<WorkEntry> entry, |
| const char* profilingTitle, |
| const c10::optional<std::vector<at::Tensor>>& inputTensors) { |
| auto work = |
| c10::make_intrusive<WorkMPI>(entry->dst, profilingTitle, inputTensors); |
| std::unique_lock<std::mutex> lock(pgMutex_); |
| queue_.push_back(std::make_tuple(std::move(entry), work)); |
| lock.unlock(); |
| queueProduceCV_.notify_one(); |
| return work; |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupMPI::broadcast( |
| std::vector<at::Tensor>& tensors, |
| const BroadcastOptions& opts) { |
| checkSingleTensor(tensors); |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this](std::unique_ptr<WorkEntry>& entry) { |
| auto data = (entry->src)[0]; |
| c10::DeviceGuard guard(data.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Bcast( |
| data.data_ptr(), |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| opts.rootRank, |
| pgComm_)); |
| }; |
| auto entry = |
| std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc)); |
| return enqueue( |
| std::move(entry), |
| "mpi:broadcast", |
| c10::optional<std::vector<at::Tensor>>(tensors)); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce( |
| std::vector<at::Tensor>& tensors, |
| const AllreduceOptions& opts) { |
| checkSingleTensor(tensors); |
| |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this](std::unique_ptr<WorkEntry>& entry) { |
| auto data = (entry->src)[0]; |
| c10::DeviceGuard guard(data.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Allreduce( |
| MPI_IN_PLACE, |
| data.data_ptr(), |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| mpiOp.at(opts.reduceOp), |
| pgComm_)); |
| }; |
| auto entry = |
| std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc)); |
| return enqueue( |
| std::move(entry), |
| "mpi:all_reduce", |
| c10::optional<std::vector<at::Tensor>>(tensors)); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce_coalesced( |
| std::vector<at::Tensor>& tensors, |
| const AllreduceCoalescedOptions& opts) { |
| TORCH_CHECK(false, "allreduce_coalesced is currently not supported with MPI"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupMPI::reduce( |
| std::vector<at::Tensor>& tensors, |
| const ReduceOptions& opts) { |
| checkSingleTensor(tensors); |
| |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this](std::unique_ptr<WorkEntry>& entry) { |
| auto data = (entry->src)[0]; |
| auto dataPtr = (entry->src)[0].data_ptr(); |
| void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr; |
| void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr; |
| |
| c10::DeviceGuard guard(data.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Reduce( |
| sendbuf, |
| recvbuf, |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| mpiOp.at(opts.reduceOp), |
| opts.rootRank, |
| pgComm_)); |
| }; |
| auto entry = |
| std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc)); |
| return enqueue( |
| std::move(entry), |
| "mpi:reduce", |
| c10::optional<std::vector<at::Tensor>>(tensors)); |
| } |
| |
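| // Allgather is implemented by gathering into a flat temporary tensor with |
| // MPI_Allgather and then copying each chunk into the per-rank output tensors. |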
| c10::intrusive_ptr<Work> ProcessGroupMPI::allgather( |
| std::vector<std::vector<at::Tensor>>& outputTensors, |
| std::vector<at::Tensor>& inputTensors, |
| const AllgatherOptions& opts) { |
| checkSingleTensor(inputTensors); |
| if (outputTensors.size() != 1) { |
| TORCH_CHECK( |
| false, |
| "MPI process group only supports a single " |
| "tensor op"); |
| } |
| if (static_cast<size_t>(size_) != outputTensors[0].size()) { |
| TORCH_CHECK( |
| false, |
|         "All gather: number of output tensors should equal " |
|         "the world size"); |
| } |
| |
| checkSameSizeAndType(inputTensors[0], outputTensors[0]); |
| |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [this](std::unique_ptr<WorkEntry>& entry) { |
| auto data = (entry->src)[0]; |
| std::vector<at::Tensor> outputDataVec = entry->dst; |
| auto flatOutputTensor = newLikeFlat(outputDataVec); |
| |
| c10::DeviceGuard guard(data.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Allgather( |
| data.data_ptr(), |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| flatOutputTensor.data_ptr(), |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| pgComm_)); |
| |
| for (const auto i : c10::irange(outputDataVec.size())) { |
| outputDataVec[i].copy_(flatOutputTensor[i]); |
| } |
| }; |
| auto entry = std::make_unique<WorkEntry>( |
| &inputTensors, &outputTensors[0], std::move(runFunc)); |
| return enqueue( |
| std::move(entry), |
| "mpi:all_gather", |
| c10::optional<std::vector<at::Tensor>>(inputTensors)); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupMPI::allgather_coalesced( |
| std::vector<std::vector<at::Tensor>>& /* unused */, |
| std::vector<at::Tensor>& /* unused */, |
| const AllgatherOptions& /* unused */) { |
| TORCH_CHECK(false, "ProcessGroupMPI does not support allgather_coalesced"); |
| } |
| |
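| // Gather: the root flattens the outputs into a single receive buffer for |
| // MPI_Gather and then copies the result back into the individual output |
| // tensors; non-root ranks only contribute their input tensor. |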
| c10::intrusive_ptr<Work> ProcessGroupMPI::gather( |
| std::vector<std::vector<at::Tensor>>& outputTensors, |
| std::vector<at::Tensor>& inputTensors, |
| const GatherOptions& opts) { |
| checkSingleTensor(inputTensors); |
| |
| if (rank_ != opts.rootRank) { |
| if (outputTensors.size() > 0) { |
| TORCH_CHECK( |
| false, |
| "Gather: number of output tensors should be 0 " |
| "for non-root"); |
| } |
| } else { |
| if (outputTensors.size() != 1) { |
| TORCH_CHECK(false, "Gather: multi-GPU collective is not supported"); |
| } |
| if (static_cast<size_t>(size_) != outputTensors[0].size()) { |
| TORCH_CHECK( |
| false, |
|           "Gather: number of output tensors should equal " |
|           "the world size"); |
| } |
| checkSameSizeAndType(inputTensors[0], outputTensors[0]); |
| } |
| |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this](std::unique_ptr<WorkEntry>& entry) { |
| auto data = (entry->src)[0]; |
| void* recvbuf = nullptr; |
| at::Tensor flatOutputTensor; |
| |
| std::vector<at::Tensor> dstdata = entry->dst; |
| if (rank_ == opts.rootRank) { |
| flatOutputTensor = newLikeFlat(dstdata); |
| recvbuf = flatOutputTensor.data_ptr(); |
| } |
| |
| c10::DeviceGuard guard(data.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Gather( |
| data.data_ptr(), |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| recvbuf, |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| opts.rootRank, |
| pgComm_)); |
| |
| if (rank_ == opts.rootRank) { |
| const std::vector<at::Tensor>& outputDataVec = entry->dst; |
|           // copy the flattened output tensor back into the individual outputs |
| for (const auto i : c10::irange(outputDataVec.size())) { |
| outputDataVec.at(i).copy_(flatOutputTensor[i]); |
| } |
| } |
| }; |
| |
| if (rank_ == opts.rootRank) { |
| auto entry = std::make_unique<WorkEntry>( |
| &inputTensors, &outputTensors[0], std::move(runFunc)); |
| return enqueue( |
| std::move(entry), |
| "mpi:gather", |
| c10::optional<std::vector<at::Tensor>>(inputTensors)); |
| } else { |
| auto entry = |
| std::make_unique<WorkEntry>(&inputTensors, nullptr, std::move(runFunc)); |
| return enqueue( |
| std::move(entry), |
| "mpi:gather", |
| c10::optional<std::vector<at::Tensor>>(inputTensors)); |
| } |
| } |
| |
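| // Scatter: the root packs its input tensors into a single flat send buffer |
| // for MPI_Scatter; every rank receives its chunk directly into the output |
| // tensor. |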
| c10::intrusive_ptr<Work> ProcessGroupMPI::scatter( |
| std::vector<at::Tensor>& outputTensors, |
| std::vector<std::vector<at::Tensor>>& inputTensors, |
| const ScatterOptions& opts) { |
| checkSingleTensor(outputTensors); |
| |
| if (rank_ != opts.rootRank) { |
| if (inputTensors.size() > 0) { |
| TORCH_CHECK( |
| false, |
| "Scatter: number of input tensors should be 0 " |
| "for non-root"); |
| } |
| } else { |
| if (inputTensors.size() != 1) { |
| TORCH_CHECK(false, "Scatter: multi-GPU collective is not supported"); |
| } |
| if (static_cast<size_t>(size_) != inputTensors[0].size()) { |
| TORCH_CHECK( |
| false, |
|           "Scatter: number of input tensors should equal " |
|           "the world size"); |
| } |
| checkSameSizeAndType(outputTensors[0], inputTensors[0]); |
| } |
| |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this](std::unique_ptr<WorkEntry>& entry) { |
| auto data = (entry->dst)[0]; |
| void* sendbuf = nullptr; |
| at::Tensor flatInputTensor; |
| |
| if (rank_ == opts.rootRank) { |
| std::vector<at::Tensor>& inputDataVec = entry->src; |
| flatInputTensor = newLikeFlat(inputDataVec); |
| sendbuf = flatInputTensor.data_ptr(); |
| |
|           // copy each input tensor into the flattened send buffer |
| for (const auto i : c10::irange(inputDataVec.size())) { |
| flatInputTensor[i].copy_(inputDataVec.at(i)); |
| } |
| } |
| |
| c10::DeviceGuard guard(data.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Scatter( |
| sendbuf, |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| data.data_ptr(), |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| opts.rootRank, |
| pgComm_)); |
| }; |
| |
| if (rank_ == opts.rootRank) { |
| auto entry = std::make_unique<WorkEntry>( |
| &inputTensors[0], &outputTensors, std::move(runFunc)); |
| return enqueue( |
| std::move(entry), |
| "mpi:scatter", |
| inputTensors.size() > 0 |
| ? c10::optional<std::vector<at::Tensor>>(inputTensors[0]) |
| : c10::nullopt); |
| } else { |
| auto entry = std::make_unique<WorkEntry>( |
| nullptr, &outputTensors, std::move(runFunc)); |
| return enqueue( |
| std::move(entry), |
| "mpi:scatter", |
| inputTensors.size() > 0 |
| ? c10::optional<std::vector<at::Tensor>>(inputTensors[0]) |
| : c10::nullopt); |
| } |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupMPI::reduce_scatter( |
| std::vector<at::Tensor>& outputTensors, |
| std::vector<std::vector<at::Tensor>>& inputTensors, |
| const ReduceScatterOptions& opts) { |
| TORCH_CHECK(false, "ProcessGroupMPI does not support reduce_scatter"); |
| } |
| |
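| // alltoall_base: with no split sizes, the input is divided evenly across |
| // ranks and MPI_Alltoall is used; otherwise per-rank lengths and offsets are |
| // computed from the split sizes and MPI_Alltoallv is used. |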
| c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall_base( |
| at::Tensor& outputTensor, |
| at::Tensor& inputTensor, |
| std::vector<int64_t>& outputSplitSizes, |
| std::vector<int64_t>& inputSplitSizes, |
| const AllToAllOptions& opts) { |
| checkSingleTensorHelper(inputTensor); |
| checkSingleTensorHelper(outputTensor); |
| |
| if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { |
| // We can use alltoall |
| TORCH_CHECK( |
| outputTensor.numel() == inputTensor.numel() && |
|             outputTensor.scalar_type() == inputTensor.scalar_type(), |
| "Tensors are not equal in size or data type"); |
| TORCH_CHECK( |
| outputTensor.size(0) % size_ == 0, |
| "Tensor's dim 0 does not divide equally across group size"); |
| |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [this](std::unique_ptr<WorkEntry>& entry) { |
| auto srcdata = (entry->src)[0]; |
| auto dstdata = (entry->dst)[0]; |
| c10::DeviceGuard guard(srcdata.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Alltoall( |
| srcdata.data_ptr(), |
| srcdata.numel() / size_, |
| mpiDatatype.at(srcdata.scalar_type()), |
| dstdata.data_ptr(), |
| dstdata.numel() / size_, |
| mpiDatatype.at(dstdata.scalar_type()), |
| pgComm_)); |
| }; |
| std::vector<at::Tensor> inputTensors = {inputTensor}; |
| std::vector<at::Tensor> outputTensors = {outputTensor}; |
| auto entry = std::make_unique<WorkEntry>( |
| &inputTensors, &outputTensors, std::move(runFunc)); |
| return enqueue( |
| std::move(entry), |
| "mpi:all_to_all", |
| c10::optional<std::vector<at::Tensor>>(inputTensors)); |
| } else { |
| // Need alltoallv |
| c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); |
| c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [this, inputSplitSizes, outputSplitSizes]( |
| std::unique_ptr<WorkEntry>& entry) { |
| auto srcdata = (entry->src)[0]; |
| auto dstdata = (entry->dst)[0]; |
| std::vector<int> send_lengths(size_); |
| std::vector<int> recv_lengths(size_); |
| std::vector<int> send_offsets(size_); |
| std::vector<int> recv_offsets(size_); |
| c10d::computeLengthsAndOffsets( |
| inputSplitSizes, srcdata, &send_lengths, &send_offsets); |
| c10d::computeLengthsAndOffsets( |
| outputSplitSizes, dstdata, &recv_lengths, &recv_offsets); |
| c10::DeviceGuard guard(srcdata.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Alltoallv( |
| srcdata.data_ptr(), |
| send_lengths.data(), |
| send_offsets.data(), |
| mpiDatatype.at(srcdata.scalar_type()), |
| dstdata.data_ptr(), |
| recv_lengths.data(), |
| recv_offsets.data(), |
| mpiDatatype.at(dstdata.scalar_type()), |
| pgComm_)); |
| }; |
| std::vector<at::Tensor> inputTensors = {inputTensor}; |
| std::vector<at::Tensor> outputTensors = {outputTensor}; |
| auto entry = std::make_unique<WorkEntry>( |
| &inputTensors, &outputTensors, std::move(runFunc)); |
| return enqueue( |
| std::move(entry), |
| "mpi:all_to_all", |
| c10::optional<std::vector<at::Tensor>>(inputTensors)); |
| } |
| } |
| |
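| // alltoall with per-rank tensor lists: the inputs are packed into a flat send |
| // buffer, exchanged with MPI_Alltoallv, and the received chunks are copied |
| // back into the output tensors. |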
| c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall( |
| std::vector<at::Tensor>& outputTensors, |
| std::vector<at::Tensor>& inputTensors, |
| const AllToAllOptions& opts) { |
| TORCH_CHECK( |
| inputTensors.size() == static_cast<size_t>(size_), |
|       "Number of input tensors is not equal to group size"); |
| TORCH_CHECK( |
| outputTensors.size() == static_cast<size_t>(size_), |
|       "Number of output tensors is not equal to group size"); |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [this](std::unique_ptr<WorkEntry>& entry) { |
| std::vector<int> send_lengths(size_); |
| std::vector<int> recv_lengths(size_); |
| std::vector<int> send_offsets(size_); |
| std::vector<int> recv_offsets(size_); |
| auto srcdata = entry->src; |
| auto dstdata = entry->dst; |
| int64_t src_len = c10d::computeLengthsAndOffsets( |
| srcdata, &send_lengths, &send_offsets); |
| int64_t dst_len = c10d::computeLengthsAndOffsets( |
| dstdata, &recv_lengths, &recv_offsets); |
| std::vector<int64_t> send_lengthsL( |
| send_lengths.begin(), send_lengths.end()); |
| std::vector<int64_t> recv_lengthsL( |
| recv_lengths.begin(), recv_lengths.end()); |
| at::Tensor srcFlatData = at::empty({src_len}, srcdata[0].options()); |
| at::Tensor dstFlatData = at::empty({dst_len}, dstdata[0].options()); |
| auto srcFlatDataSplits = |
| srcFlatData.split_with_sizes(c10::IntArrayRef(send_lengthsL), 0); |
| for (const auto i : c10::irange(size_)) { |
| srcFlatDataSplits[i].copy_(srcdata[i].view({-1})); |
| } |
| c10::DeviceGuard guard1(srcdata[0].device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Alltoallv( |
| srcFlatData.data_ptr(), |
| send_lengths.data(), |
| send_offsets.data(), |
| mpiDatatype.at(srcdata[0].scalar_type()), |
| dstFlatData.data_ptr(), |
| recv_lengths.data(), |
| recv_offsets.data(), |
| mpiDatatype.at(dstdata[0].scalar_type()), |
| pgComm_)); |
| |
| auto dstFlatDataSplits = |
| dstFlatData.split_with_sizes(c10::IntArrayRef(recv_lengthsL), 0); |
| for (const auto i : c10::irange(size_)) { |
| dstdata[i].view({-1}).copy_(dstFlatDataSplits[i]); |
| } |
| }; |
| auto entry = std::make_unique<WorkEntry>( |
| &inputTensors, &outputTensors, std::move(runFunc)); |
| return enqueue( |
| std::move(entry), |
| "mpi:all_to_all", |
| c10::optional<std::vector<at::Tensor>>(inputTensors)); |
| } |
| |
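| // Point-to-point ops bypass the worker queue: they post a non-blocking MPI |
| // request directly and return an AsyncWork that tracks it. |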
| c10::intrusive_ptr<Work> ProcessGroupMPI::send( |
| std::vector<at::Tensor>& tensors, |
| int dstRank, |
| int tag) { |
| checkSingleTensor(tensors); |
| |
| auto& tensor = tensors[0]; |
| MPI_Request request = MPI_REQUEST_NULL; |
| |
| { |
| c10::DeviceGuard guard(tensor.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Isend( |
| tensor.data_ptr(), |
| tensor.numel(), |
| mpiDatatype.at(tensor.scalar_type()), |
| dstRank, |
| tag, |
| pgComm_, |
| &request)); |
| } |
| |
| return c10::make_intrusive<AsyncWork>( |
| request, |
| std::vector<at::Tensor>(), |
| "mpi:send", |
| c10::optional<std::vector<at::Tensor>>(tensors)); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupMPI::recv( |
| std::vector<at::Tensor>& tensors, |
| int srcRank, |
| int tag) { |
| checkSingleTensor(tensors); |
| |
| auto& tensor = tensors[0]; |
| MPI_Request request = MPI_REQUEST_NULL; |
| |
| { |
| c10::DeviceGuard guard(tensor.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Irecv( |
| tensor.data_ptr(), |
| tensor.numel(), |
| mpiDatatype.at(tensor.scalar_type()), |
| srcRank, |
| tag, |
| pgComm_, |
| &request)); |
| } |
| |
| return c10::make_intrusive<AsyncWork>( |
| request, |
| tensors, |
| "mpi:recv", |
| c10::optional<std::vector<at::Tensor>>(tensors)); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupMPI::recvAnysource( |
| std::vector<at::Tensor>& tensors, |
| int tag) { |
| checkSingleTensor(tensors); |
| |
| auto& tensor = tensors[0]; |
| MPI_Request request = MPI_REQUEST_NULL; |
| |
| { |
| c10::DeviceGuard guard(tensor.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Irecv( |
| tensor.data_ptr(), |
| tensor.numel(), |
| mpiDatatype.at(tensor.scalar_type()), |
| MPI_ANY_SOURCE, |
| tag, |
| pgComm_, |
| &request)); |
| } |
| |
| return c10::make_intrusive<AsyncWork>( |
| request, |
| tensors, |
| "mpi:recvAnySource", |
| c10::optional<std::vector<at::Tensor>>(tensors)); |
| } |
| |
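| // Enqueues an MPI_Barrier on the worker thread. |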
| c10::intrusive_ptr<Work> ProcessGroupMPI::barrier(const BarrierOptions& opts) { |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [this](std::unique_ptr<WorkEntry>& entry) { |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Barrier(pgComm_)); |
| }; |
| auto entry = |
| std::make_unique<WorkEntry>(nullptr, nullptr, std::move(runFunc)); |
| return enqueue(std::move(entry), "mpi:barrier", c10::nullopt); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupMPI::_allgather_base( |
| at::Tensor& /*unused */, |
| at::Tensor& /*unused */, |
| const AllgatherOptions& /*unused */) { |
| TORCH_CHECK(false, "no support for _allgather_base in MPI process group"); |
| } |
| |
| } // namespace c10d |
| |
| #endif // USE_C10D_MPI |