| #ifdef USE_C10D_UCC |
| |
| #include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp> |
| #include <torch/csrc/distributed/c10d/UCCTracing.hpp> |
| #include <torch/csrc/distributed/c10d/UCCUtils.hpp> |
| #include <list> |
| #include <memory> |
| #include <unordered_map> |
| #include <unordered_set> |
| |
| namespace c10d { |
| |
| namespace { |
| constexpr int64_t kBusyWaitMillis = 10; |
| |
| const std::map<c10::DeviceType, ucc_memory_type_t> ucc_mtype_map = { |
| {c10::kCPU, UCC_MEMORY_TYPE_HOST}, |
| {c10::kCUDA, UCC_MEMORY_TYPE_CUDA}, |
| }; |
| |
| ucc_memory_type_t to_ucc_memType(c10::DeviceType _c10_type) { |
| if (ucc_mtype_map.find(_c10_type) != ucc_mtype_map.end()) |
| return ucc_mtype_map.at(_c10_type); |
| else |
| return UCC_MEMORY_TYPE_UNKNOWN; |
| } |
| |
| const std::map<at::ScalarType, ucc_datatype_t> ucc_dtype_map = { |
| {at::kByte, UCC_DT_UINT8}, |
| {at::kChar, UCC_DT_INT8}, |
| {at::kHalf, UCC_DT_FLOAT16}, |
| {at::kBFloat16, UCC_DT_BFLOAT16}, |
| {at::kDouble, UCC_DT_FLOAT64}, |
| {at::kFloat, UCC_DT_FLOAT32}, |
| {at::kInt, UCC_DT_INT32}, |
| {at::kLong, UCC_DT_INT64}, |
| {at::kBool, UCC_DT_UINT8}, |
| }; |
| |
| ucc_datatype_t to_ucc_dType(at::Tensor _tensor) { |
| if (_tensor.scalar_type() == at::kBool && _tensor.element_size() != 1) { |
| TORCH_CHECK( |
| false, "Size of Boolean type larger than 1 is not supported in UCC"); |
| } |
| try { |
| return ucc_dtype_map.at(_tensor.scalar_type()); |
| } catch (const std::out_of_range& e) { |
| TORCH_CHECK(false, "Not supported data type for UCC"); |
| } |
| } |
| |
| const std::map<ReduceOp, ucc_reduction_op_t> ucc_op_map = { |
| {ReduceOp::SUM, UCC_OP_SUM}, |
| {ReduceOp::PRODUCT, UCC_OP_PROD}, |
| {ReduceOp::MIN, UCC_OP_MIN}, |
| {ReduceOp::MAX, UCC_OP_MAX}, |
| {ReduceOp::BAND, UCC_OP_BAND}, |
| {ReduceOp::BOR, UCC_OP_BOR}, |
| {ReduceOp::BXOR, UCC_OP_BXOR}, |
| {ReduceOp::AVG, UCC_OP_AVG}, |
| }; |
| |
| ucc_reduction_op_t to_ucc_reduceOp( |
| const ReduceOp _op, |
| const at::ScalarType _dt) { |
| if (_dt == at::kBool) { |
| if (_op == ReduceOp::SUM) { |
| // bitwise or |
| return UCC_OP_MAX; |
| } else if (_op == ReduceOp::PRODUCT) { |
| // bitwise and |
| return UCC_OP_MIN; |
| } else if (_op == ReduceOp::AVG) { |
| TORCH_CHECK(false, "Cannot use ReduceOp.AVG with boolean inputs"); |
| } |
| } |
| |
| try { |
| return ucc_op_map.at(_op); |
| } catch (const std::out_of_range& e) { |
| TORCH_CHECK(false, "Not supported ReduceOp for UCC"); |
| } |
| } |
| |
| struct torch_ucc_config_t { |
| c10::once_flag flag; |
| std::array<bool, 32> blocking_wait; |
| bool enable_comms_logger; |
| bool use_future; |
|   // Share the UCC communicator among multiple PGs to save resources. |
| bool shared_comm; |
|   // Use allgatherv to implement allgather, without flattening the list of |
|   // (potentially non-contiguous) output tensors. |
| bool use_allgatherv; |
| bool enable_health_check; |
| } torch_ucc_config; |
| |
| std::unordered_map<std::string, std::string> torch_ucc_envs_map = { |
| // TORCH_UCC_BLOCKING_WAIT allowed syntax: |
| // - TORCH_UCC_BLOCKING_WAIT=none --> blocking wait completely disabled |
| // - TORCH_UCC_BLOCKING_WAIT=all --> blocking wait completely enabled |
| // - TORCH_UCC_BLOCKING_WAIT=allreduce,send,recv --> blocking wait enabled |
| // on selected operations |
| // Supported operations: |
|     // [allgather,allgather_base,allreduce,alltoall_base,broadcast, |
| // gather,reduce,reduce_scatter,scatter,send,recv] |
| {"TORCH_UCC_BLOCKING_WAIT", "none"}, |
| |
| {"TORCH_UCC_USE_FUTURE", "1"}, |
| {"TORCH_UCC_PROFILING_ENABLE", "0"}, |
| {"TORCH_UCC_SHARED_COMM", "1"}, |
| {"TORCH_UCC_USE_ALLGATHERV", "0"}, |
| {"TORCH_UCC_ENABLE_HEALTH_CHECK", "0"}, |
| {"TORCH_UCC_ENABLE_COMMS_LOGGER", "0"}, |
| }; |
| |
| std::vector<OpType> parse_blocking_wait(std::string op_list_string) { |
| const static std::unordered_map<std::string, OpType> str2op = { |
| {"allgather", OpType::ALLGATHER}, |
| {"allgather_base", OpType::_ALLGATHER_BASE}, |
| {"allreduce", OpType::ALLREDUCE}, |
| {"alltoall_base", OpType::ALLTOALL_BASE}, |
| {"broadcast", OpType::BROADCAST}, |
| {"gather", OpType::GATHER}, |
| {"reduce", OpType::REDUCE}, |
| {"reduce_scatter", OpType::REDUCE_SCATTER}, |
| {"scatter", OpType::SCATTER}, |
| {"send", OpType::SEND}, |
| {"recv", OpType::RECV}, |
| }; |
| auto op_list = parse_list(op_list_string); |
| if (op_list == std::vector<std::string>{"none"}) { |
| return {}; |
| } |
| std::vector<OpType> result; |
| if (op_list == std::vector<std::string>{"all"}) { |
| for (auto entry : str2op) { |
| result.push_back(entry.second); |
| } |
| } else { |
| for (auto op_string : op_list) { |
| result.push_back(str2op.at(op_string)); |
| } |
| } |
| return result; |
| } |
| |
| } // namespace |
| |
| void read_config() { |
| // default configuration |
| torch_ucc_config.blocking_wait.fill(false); |
| torch_ucc_config.use_future = true; |
| torch_ucc_config.shared_comm = false; |
| torch_ucc_config.use_allgatherv = false; |
| torch_ucc_config.enable_health_check = false; |
| torch_ucc_config.enable_comms_logger = false; |
| |
| // read all torch_ucc env. variables and update the map |
| char* env; |
| for (auto& torch_ucc_env : torch_ucc_envs_map) { |
| env = std::getenv(torch_ucc_env.first.c_str()); |
| if (env) { |
| torch_ucc_envs_map[torch_ucc_env.first] = std::string(env); |
| } |
| } |
| |
| auto blocking_wait_str = torch_ucc_envs_map.at("TORCH_UCC_BLOCKING_WAIT"); |
| for (auto op : parse_blocking_wait(blocking_wait_str)) { |
| torch_ucc_config.blocking_wait[(std::uint8_t)op] = true; |
| } |
| // barrier is always blocking |
| torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BARRIER] = true; |
| |
| torch_ucc_config.use_future = |
| std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE")); |
| torch_ucc_config.shared_comm = |
| std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM")); |
| torch_ucc_config.use_allgatherv = |
| std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_ALLGATHERV")); |
| torch_ucc_config.enable_health_check = |
| std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_HEALTH_CHECK")); |
| torch_ucc_config.enable_comms_logger = |
| std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_COMMS_LOGGER")); |
| } |
| |
| void check_device(c10::Device dev1, c10::Device dev2) { |
| if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) { |
| throw std::runtime_error("ProcessGroupUCC multidevice is not supported"); |
| } |
| } |
| |
| void check_tensor(const std::vector<at::Tensor>& tensors) { |
| if (tensors.size() != 1) { |
| throw std::runtime_error( |
| "ProcessGroupUCC takes 1 tensor. Got " + |
| std::to_string(tensors.size()) + ". "); |
| } |
| if (!tensors[0].is_contiguous()) { |
| throw std::runtime_error( |
| "ProcessGroupUCC input tensor has to be contiguous"); |
| } |
| if (tensors[0].is_sparse()) { |
| throw std::runtime_error("ProcessGroupUCC input tensor has to be dense"); |
| } |
| // TODO: check cuda case |
| } |
| |
| ProcessGroupUCC::WorkUCC::~WorkUCC() { |
| #ifdef USE_CUDA |
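|   // Return the fence event to the communicator-wide pool so it can be reused |
|   // by later collectives (see getPooledEvent). |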
| if (fence && ep) { |
| std::lock_guard<std::mutex> lock(ep->event_pool_mutex); |
| ep->event_pool.push(std::move(fence)); |
| } |
| #endif |
| } |
| |
| void ProcessGroupUCC::WorkUCC::setException() { |
| if (exception() || !entry_) { |
| return; |
| } |
| exception_ = entry_->eptr_; |
| } |
| |
| void ProcessGroupUCC::WorkUCC::setAndThrowException() { |
| setException(); |
| if (exception()) { |
| std::rethrow_exception(exception()); |
| } |
| } |
| |
| bool ProcessGroupUCC::WorkUCC::isCompleted() { |
| if (!entry_) { |
| return true; |
| } |
| setException(); |
| // status_ <= 0 to avoid listing all possible status codes. The main thread |
| // needs to be unblocked when UCC (in progress thread) returns success (== 0) |
| // or any error code (< 0). |
| return exception() || entry_->status_ <= 0; |
| } |
| |
| bool ProcessGroupUCC::WorkUCC::isSuccess() const { |
| if (!entry_) { |
| return true; |
| } |
| return !exception() && entry_->status_ == 0; |
| } |
| |
| bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) { |
| if (torch_ucc_config.enable_comms_logger && logger_) { |
| logger_->trace_generator->recordComms("wait", (uintptr_t)this, rank_); |
| } |
| #ifdef USE_CUDA |
| if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) { |
|     // Non-blocking path: rather than spinning here, throw any pending |
|     // exception and make the caller's current stream wait on the fence event |
|     // recorded after the collective. |
| setAndThrowException(); |
| fence->block(at::cuda::getCurrentCUDAStream()); |
| return true; |
| } |
| #endif |
|   // Wait for completion. In the blocking case, the main thread spins in this |
|   // loop until the progress thread updates the status of this request. If a |
|   // timeout occurs, UCC returns UCC_ERR_TIMEOUT as the status and the main |
|   // thread then throws the exception. There is currently no "abort" function |
|   // in UCC. |
| while (!isCompleted()) |
| ; |
| setAndThrowException(); |
| // manually call profiling end callbacks if they are set, |
| // since progress thread does not own WorkUCC |
| if (Work::recordFunctionEndCallback_) { |
| Work::recordFunctionEndCallback_(); |
| Work::recordFunctionEndCallback_ = nullptr; |
| } |
| return true; |
| } |
| |
| c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupUCC::WorkUCC::getFuture() { |
| return future_; |
| } |
| |
| int ProcessGroupUCC::WorkUCC::sourceRank() const { |
| if (opType_ != OpType::RECV && opType_ != OpType::RECVANYSOURCE) { |
|     // The base class implementation throws for ops other than recv / |
|     // recv-from-any-source. |
| return Work::sourceRank(); |
| } |
| return sourceRank_; |
| } |
| |
| std::vector<at::Tensor> ProcessGroupUCC::WorkUCC::result() { |
| return *outputs_; |
| } |
| |
| void ProcessGroupUCC::ProgressEntry::finalize(std::exception_ptr eptr) { |
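|   // Runs on the progress thread once the request has finished: release the |
|   // UCC request, record either the final status or the exception, and |
|   // resolve the attached future accordingly. |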
| ucc_status_t status = UCC_OK; |
| |
| if (request_ != nullptr) { |
| status = request_->status; |
| comm_->free_request(request_); |
| } |
| if (eptr) { |
| eptr_ = eptr; |
| } else { |
| status_ = status; |
| } |
| if (future_) { |
| if (eptr) { |
| future_->setError(eptr); |
| } else { |
| future_->markCompleted( |
| c10::IValue(data ? data->dst : std::vector<at::Tensor>())); |
| } |
| } |
| } |
| |
| Comm::Comm( |
| const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_, |
| std::shared_ptr<torch_ucc_oob_coll_info_t> oob_, |
| c10::Device dev, |
| bool is_health_check) |
| : logger(logger_), |
| oob(oob_), |
| ucc_comm(oob, logger), |
| finalize_phase( |
| is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE), |
| cuda_device_index(TORCH_UCC_DEVICE_NOT_SET) { |
| if (dev.is_cuda()) { |
| cuda_device_index = dev.index(); |
| } |
| stop_progress_loop = false; |
| collective_inprogress = false; |
| progress_thread = std::thread(&Comm::progress_loop, this); |
| #ifdef _GNU_SOURCE |
| pthread_setname_np(progress_thread.native_handle(), "ucc-progress"); |
| #endif |
| } |
| |
| Comm::~Comm() { |
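|   // Wait for the progress thread to drain all queued work, then ask it to |
|   // stop and join it. |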
| std::unique_lock<std::mutex> lock(mutex); |
| queue_consume_cv.wait( |
| lock, [&] { return progress_queue.empty() && !collective_inprogress; }); |
| stop_progress_loop = true; |
| lock.unlock(); |
| queue_produce_cv.notify_all(); |
| progress_thread.join(); |
| } |
| |
| std::shared_ptr<Comm> Comm::get_comm( |
| uint32_t& id, |
| c10::Device dev, |
| std::shared_ptr<torch_ucc_oob_coll_info_t> oob, |
| const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger, |
| bool is_health_check) { |
| static std::mutex m; |
| static std::weak_ptr<Comm> comm; |
| static uint32_t comm_id; |
| |
| std::lock_guard<std::mutex> lock(m); |
| id = comm_id; |
| |
| std::string group_id = "group_id"; |
| if (is_health_check) { |
| group_id = c10::str(dev.type()) + "/" + group_id; |
| } |
| |
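|   // Agree on a single communicator id across ranks via the store: non-zero |
|   // ranks publish their candidate ids, rank 0 takes the maximum and |
|   // publishes it under key 0, and every rank adopts that value. comm_id is |
|   // then advanced for the next communicator. |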
| std::vector<uint8_t> remote_comm_id; |
| oob->store->deleteKey(group_id + std::to_string(0)); |
| if (oob->rank != 0) { |
| std::vector<uint8_t> val = std::vector<uint8_t>( |
| reinterpret_cast<uint8_t*>(&id), |
| reinterpret_cast<uint8_t*>(&id) + sizeof(id)); |
| oob->store->set(group_id + std::to_string(oob->rank), val); |
| } else { |
| for (int i = 1; i < oob->size; i++) { |
| remote_comm_id = oob->store->get(group_id + std::to_string(i)); |
| oob->store->deleteKey(group_id + std::to_string(i)); |
| // Find the highest id. |
| id = std::max(id, *(reinterpret_cast<uint32_t*>(remote_comm_id.data()))); |
| } |
| std::vector<uint8_t> val = std::vector<uint8_t>( |
| reinterpret_cast<uint8_t*>(&id), |
| reinterpret_cast<uint8_t*>(&id) + sizeof(id)); |
| oob->store->set(group_id + std::to_string(oob->rank), val); |
| } |
| remote_comm_id = oob->store->get(group_id + std::to_string(0)); |
| oob->comm_id = *(reinterpret_cast<uint32_t*>(remote_comm_id.data())); |
|   // Advance comm_id (static variable) to the next id. |
| comm_id = oob->comm_id + 1; |
| |
| if (torch_ucc_config.shared_comm) { |
| std::shared_ptr<Comm> shared_comm = comm.lock(); |
| if (!shared_comm) { |
| shared_comm = std::make_shared<Comm>(logger, oob, dev, is_health_check); |
| comm = shared_comm; |
| } else { |
| if (dev.is_cuda() && !is_health_check) { |
| if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) && |
| (shared_comm->cuda_device_index != dev.index())) { |
| TORCH_UCC_LOG_ERROR( |
| is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT, |
|               "ucc communicator was initialized with a different cuda device, " |
|               "multi device is not supported"); |
| throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); |
| } |
| shared_comm->cuda_device_index = dev.index(); |
| } |
| } |
| return shared_comm; |
| } else { |
| return std::make_shared<Comm>(logger, oob, dev, is_health_check); |
| } |
| } |
| |
| void Comm::ucc_create_team( |
| ucc_team_h& team, |
| std::shared_ptr<torch_ucc_oob_coll_info_t> oob) { |
| ucc_status_t st; |
| ucc_team_params_t team_params; |
| team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE | |
| UCC_TEAM_PARAM_FIELD_OOB; |
| team_params.oob.allgather = oob_allgather; |
| team_params.oob.req_test = oob_allgather_test; |
| team_params.oob.req_free = oob_allgather_free; |
| team_params.oob.coll_info = oob.get(); |
| team_params.oob.n_oob_eps = oob->size; |
| team_params.oob.oob_ep = oob->rank; |
| team_params.ep = oob->rank; |
| team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG; |
| TORCH_UCC_CHECK( |
| ucc_team_create_post(&ucc_comm.context, 1, &team_params, &team), |
| "failed to post team create"); |
| do { |
| st = ucc_team_create_test(team); |
| ucc_context_progress(ucc_comm.context); |
| } while (st == UCC_INPROGRESS); |
| TORCH_UCC_CHECK(st, "failed to create UCC team"); |
| } |
| |
| void Comm::ucc_destroy_team(ucc_team_h& team) { |
| std::unique_lock<std::mutex> lock(mutex); |
| queue_consume_cv.wait( |
| lock, [&] { return progress_queue.empty() && !collective_inprogress; }); |
| |
|   ucc_status_t status; |
|   // Team destroy may take multiple iterations; progress the context while it |
|   // completes, then report any terminal error. |
|   do { |
|     status = ucc_team_destroy(team); |
|     if (UCC_INPROGRESS == status) { |
|       ucc_context_progress(ucc_comm.context); |
|     } |
|   } while (UCC_INPROGRESS == status); |
|   if (UCC_OK != status) { |
|     TORCH_UCC_LOG_ERROR( |
|         finalize_phase, |
|         c10::str("ucc team destroy error: ", ucc_status_string(status))); |
|   } |
| |
| lock.unlock(); |
| } |
| |
| void Comm::enqueue_collective( |
| std::unique_ptr<ProcessGroupUCC::WorkData> data, |
| c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work, |
| ucc_coll_args_t& coll, |
| ucc_team_h team) { |
| ucc_coll_req_h request; |
| TORCH_UCC_CHECK( |
| ucc_collective_init(&coll, &request, team), "failed to init collective"); |
| TORCH_UCC_CHECK_REQUEST( |
| request, ucc_collective_post(request), "failed to post collective"); |
| |
| auto entry = |
| std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request); |
| entry->data = std::move(data); |
| entry->future_ = work->getFuture(); |
| work->entry_ = entry; |
| std::unique_lock<std::mutex> lock(mutex); |
| progress_queue.push_back(entry); |
| lock.unlock(); |
| queue_produce_cv.notify_one(); |
| } |
| |
| #ifdef USE_CUDA |
| void Comm::enqueue_cuda_collective( |
| std::unique_ptr<ProcessGroupUCC::WorkData> data, |
| c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work, |
| ucc_coll_args_t& coll, |
| ucc_team_h team, |
| ucc_ee_h ee) { |
| ucc_coll_req_h request; |
| TORCH_UCC_CHECK( |
| ucc_collective_init(&coll, &request, team), |
| "failed to init cuda collective"); |
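|   // Trigger the collective through the CUDA-stream execution engine: submit |
|   // a COMPUTE_COMPLETE event carrying the request so UCC posts the |
|   // collective in stream order, then fetch and ack the resulting |
|   // COLLECTIVE_POST event. |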
| ucc_ev_t comp_ev, *post_ev; |
| comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE; |
| comp_ev.ev_context = nullptr; |
| comp_ev.ev_context_size = 0; |
| comp_ev.req = request; |
| TORCH_UCC_CHECK_REQUEST( |
| request, |
| ucc_collective_triggered_post(ee, &comp_ev), |
| "failed to post triggered collective"); |
| ucc_status_t st = ucc_ee_get_event(ee, &post_ev); |
| TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST); |
| ucc_ee_ack_event(ee, post_ev); |
| auto entry = |
| std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request); |
| entry->data = std::move(data); |
| work->entry_ = entry; |
| std::unique_lock<std::mutex> lock(mutex); |
| progress_queue.push_back(entry); |
| lock.unlock(); |
| queue_produce_cv.notify_one(); |
| } |
| #endif |
| |
| void Comm::progress_loop() { |
| std::unique_lock<std::mutex> lock(mutex); |
| #ifdef USE_CUDA |
| bool device_set = false; |
| #endif |
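|   // Progress-thread main loop: pop one queued request at a time, progress |
|   // the UCC context until the request completes or fails, then finalize it |
|   // and wake any waiters. |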
| while (!stop_progress_loop) { |
| if (progress_queue.empty()) { |
| queue_produce_cv.wait(lock); |
| continue; |
| } |
| collective_inprogress = true; |
| auto work = progress_queue.front(); |
| progress_queue.pop_front(); |
| lock.unlock(); |
| #ifdef USE_CUDA |
| if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) { |
| c10::cuda::set_device(cuda_device_index); |
| device_set = true; |
| } |
| #endif |
| std::exception_ptr eptr; |
| try { |
| while (work->request_->status > 0) { |
| ucc_comm.progress(); |
| } |
| if (work->request_->status < 0) { |
| eptr = std::make_exception_ptr( |
| std::runtime_error(ucc_status_string(work->request_->status))); |
| std::string err_log = c10::str( |
|             "Failed to progress communication: ", // TODO: report exact op |
|                                                   // type or id? |
| ucc_status_string(work->request_->status)); |
| TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_PROGRESS, err_log); |
| } |
| } catch (...) { |
| eptr = std::current_exception(); |
| } |
| work->finalize(eptr); |
| work = nullptr; |
| collective_inprogress = false; |
| queue_consume_cv.notify_one(); |
| lock.lock(); |
| } |
| } |
| |
| ProcessGroupUCC::ProcessGroupUCC( |
| const c10::intrusive_ptr<Store>& store, |
| int rank, |
| int size, |
| std::chrono::duration<float> timeout) |
| : Backend(rank, size), timeout_(timeout) { |
| c10::call_once(torch_ucc_config.flag, read_config); |
| oob = std::make_shared<torch_ucc_oob_coll_info_t>(); |
| oob->rank = rank; |
| oob->size = size; |
| oob->store = store; |
| comm = nullptr; |
| cuda_ee = nullptr; |
| static uint32_t id = 0; |
| uint32_t pg_id = id++; |
| |
| logger = c10::make_intrusive<ProcessGroupUCCLogger>( |
| c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"), |
| TORCH_UCC_INIT); |
| TORCH_UCC_LOG_INFO( |
| TORCH_UCC_INIT, |
| c10::str( |
| "Created ProcessGroupUCC with ", |
| size, |
| " ranks, with timeout ", |
| timeout_.count(), |
| " secs")); |
| std::string envs = ""; |
| for (auto& torch_ucc_env : torch_ucc_envs_map) { |
| envs += ("\n\t" + torch_ucc_env.first + "=" + torch_ucc_env.second); |
| } |
| TORCH_UCC_LOG_INFO( |
| TORCH_UCC_INIT, |
| c10::str( |
|           "Successfully read and set ProcessGroupUCC env. variables as follows:", |
| envs)); |
| |
| if (torch_ucc_config.enable_health_check) { |
| // Perform health check by initializing dummy communicators and destroying |
| // them. This will help indicate any UCC/UCX-related issues prior to the |
| // first collective. Run it in a separate thread and wait on CV to handle |
| // timeouts so that if there are hangs, the main thread can still run |
| // correctly. |
| runHealthCheck(); |
| } |
| if (torch_ucc_config.enable_comms_logger) { |
| logger->initCommsTracer(); |
| } |
| } |
| |
| ProcessGroupUCC::~ProcessGroupUCC() { |
| if (torch_ucc_config.enable_comms_logger) { |
| logger->flushComms(this->getRank(), this->getSize()); |
| } |
| if (comm) { |
| logger->setPhase(TORCH_UCC_FINALIZE); |
| comm->ucc_destroy_team(team); |
| TORCH_UCC_LOG_INFO( |
| TORCH_UCC_FINALIZE, "Successfully destroyed UCC library"); |
| try { |
| if (cuda_ee) { |
| ucc_ee_destroy(cuda_ee); |
| } |
| } catch (std::exception& ex) { |
| TORCH_UCC_LOG_INFO( |
| TORCH_UCC_FINALIZE, |
| c10::str( |
| "(~ProcessGroupUCC) Caught error in Store Operation .. ", |
| "[", |
| ex.what(), |
| "]")); |
| } |
| comm = nullptr; |
| } |
| } |
| |
| #ifdef USE_CUDA |
| // Return CUDA device with ordinal given by input rank. |
| c10::Device getCUDADeviceForRank(int rank) { |
| TORCH_CHECK(rank >= 0, "Invalid rank ", rank); |
| auto numGPUs = at::cuda::getNumGPUs(); |
| auto deviceIdx = static_cast<c10::DeviceIndex>(rank % numGPUs); |
| return c10::Device(c10::DeviceType::CUDA, deviceIdx); |
| } |
| #endif |
| |
| void ProcessGroupUCC::runHealthCheck() { |
| // Run health check in a separate thread and wait on CV to handle timeouts. |
| // This design allows us to handle hangs. |
| |
| // When size_ is 1, there is no need to do any communication at all. |
| if (size_ == 1) |
| return; |
| |
| struct HealthCheckData { |
| std::mutex healthCheckMutex; |
| std::condition_variable healthCheckCv; |
| bool uccHealthCheckSuccess = false; |
| std::exception_ptr healthCheckException; |
| } healthCheckData; |
| |
| auto t = std::thread([&healthCheckData, this]() { |
| std::list<c10::Device> devices{c10::kCPU}; |
| #ifdef USE_CUDA |
| c10::cuda::OptionalCUDAGuard gpuGuard; |
| if (at::cuda::is_available()) { |
| devices.emplace_front(getCUDADeviceForRank(rank_)); |
| } |
| #endif |
| for (auto device : devices) { |
| bool is_last_device = (device == devices.back()); |
| try { |
| auto oob = std::make_shared<torch_ucc_oob_coll_info_t>(); |
| oob->rank = this->oob->rank; |
| oob->size = this->oob->size; |
| oob->store = this->oob->store; |
| ucc_team_h team = nullptr; |
| uint32_t comm_id; |
| #ifdef USE_CUDA |
| if (device.is_cuda()) { |
| gpuGuard.set_index(device.index()); |
| } |
| #endif |
| auto comm = Comm::get_comm(comm_id, device, oob, logger, true); |
| comm->ucc_create_team(team, oob); |
| comm->ucc_destroy_team(team); |
| TORCH_UCC_LOG_INFO( |
| TORCH_UCC_HEALTH_CHECK, |
| c10::str( |
|                 "UCC library health check succeeded for device ", |
| c10::DeviceTypeName(device.type()))); |
| // Mark ucc health check as complete. |
| if (is_last_device) { |
| std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex); |
| healthCheckData.uccHealthCheckSuccess = true; |
| } |
| |
| comm = nullptr; |
| oob = nullptr; |
| // Notify main thread the health check is complete. |
| if (is_last_device) { |
| healthCheckData.healthCheckCv.notify_one(); |
| } |
| } catch (const std::exception& e) { |
| // Populate exception ptr. |
| healthCheckData.healthCheckException = std::current_exception(); |
| // Unblock waiting main thread which will report exception. |
| healthCheckData.healthCheckCv.notify_one(); |
| } // Unknown exceptions will just cause the program to terminate. |
| } |
| }); |
| // We don't need to join the thread, just need to verify health check via the |
| // CV. Hence we detach the thread here. |
| t.detach(); // NOLINT |
| TORCH_UCC_LOG_INFO( |
| TORCH_UCC_HEALTH_CHECK, |
| c10::str( |
| "will wait up to ", |
| timeout_.count(), |
|           " sec for UCC health check to complete.")); |
| std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex); |
| healthCheckData.healthCheckCv.wait_for(lock, timeout_, [&healthCheckData]() { |
| return healthCheckData.uccHealthCheckSuccess; |
| }); |
| |
| if (healthCheckData.healthCheckException) { |
| std::rethrow_exception(healthCheckData.healthCheckException); |
| } |
| // If there is no exception, the likely culprit is a timeout/hang |
| TORCH_CHECK( |
| healthCheckData.uccHealthCheckSuccess, |
| "ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank ", |
| rank_); |
| } |
| |
| void ProcessGroupUCC::set_timeout(ucc_coll_args_t& args) { |
| args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; |
| args.flags |= UCC_COLL_ARGS_FLAG_TIMEOUT; |
| args.timeout = timeout_.count(); |
| } |
| |
| #ifdef USE_CUDA |
| std::unique_ptr<at::cuda::CUDAEvent> ProcessGroupUCC::getPooledEvent() { |
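|   // Reuse CUDA events across collectives: take one from the pool when |
|   // available, otherwise allocate a new one. Events are returned to the pool |
|   // in ~WorkUCC. |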
| std::unique_ptr<at::cuda::CUDAEvent> ev; |
| std::lock_guard<std::mutex> lock(ep.event_pool_mutex); |
| if (ep.event_pool.empty()) { |
| ev = std::make_unique<at::cuda::CUDAEvent>(); |
| } else { |
| ev = std::move(ep.event_pool.front()); |
| ep.event_pool.pop(); |
| } |
| return ev; |
| } |
| #endif |
| |
| template <typename PreProcess, typename PostProcess> |
| c10::intrusive_ptr<Work> ProcessGroupUCC::collective_post( |
| OpType opType, |
| PreProcess preproc, |
| PostProcess postproc, |
| ucc_coll_args_t& coll, |
| std::unique_ptr<ProcessGroupUCC::WorkData> data, |
| c10::Device dev, |
| std::vector<at::Tensor>& inputTensors, |
| std::vector<at::Tensor>& outputTensors, |
| const char* prof_title) { |
| seq_++; |
| set_timeout(coll); |
| auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>( |
| opType, seq_, prof_title, inputTensors, logger); |
| |
| if (opType == OpType::RECV) { |
| work->sourceRank_ = coll.root; |
| } |
| |
| RECORD_COMMS_TRACE( |
| logger->trace_generator, |
| work, |
| opType, |
| this->getRank(), |
| this->getSize(), |
| inputTensors, |
| outputTensors); |
| |
| // Store references to outputs to be used by result |
| work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputTensors); |
| switch (dev.type()) { |
| case c10::DeviceType::CPU: { |
| if (torch_ucc_config.use_future) { |
| work->future_ = c10::make_intrusive<at::ivalue::Future>( |
| c10::ListType::create(c10::TensorType::get())); |
| } |
| preproc(); |
| comm->enqueue_collective(std::move(data), work, coll, team); |
| postproc(); |
| return work; |
| } |
| #ifdef USE_CUDA |
| case c10::DeviceType::CUDA: { |
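|       // Order the collective after work already queued on the caller's |
|       // current stream: record an event there, make the internal UCC stream |
|       // wait on it, enqueue the collective on that stream, then record the |
|       // event again as a fence the caller can synchronize with. |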
| auto cuda_ev = getPooledEvent(); |
| cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index())); |
| cuda_ev->block(*stream); |
| at::cuda::CUDAStreamGuard guard(*stream); |
| preproc(); |
| comm->enqueue_cuda_collective(std::move(data), work, coll, team, cuda_ee); |
| postproc(); |
| cuda_ev->record(*stream); |
| work->fence = std::move(cuda_ev); |
| work->ep = &ep; |
| if (torch_ucc_config.use_future) { |
| c10::cuda::CUDAMultiStreamGuard streamGuard(*stream); |
| std::vector<c10::Device> devList{dev}; |
| work->future_ = c10::make_intrusive<at::ivalue::Future>( |
| c10::ListType::create(c10::TensorType::get()), devList); |
| // Add a callback that runs profiling end callbacks |
| if (work->recordFunctionEndCallback_) { |
| work->future_->addCallback([work](at::ivalue::Future& /* unused */) { |
| work->recordFunctionEndCallback_(); |
| }); |
| } |
| |
| work->future_->markCompleted(c10::IValue(outputTensors)); |
| } |
| return work; |
| } |
| #endif // #ifdef USE_CUDA |
| default: { |
| TORCH_UCC_LOG_ERROR( |
| TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str())); |
| throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); |
| } |
| } |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::allgather( |
| std::vector<std::vector<at::Tensor>>& outputTensors, |
| std::vector<at::Tensor>& inputTensors, |
| const AllgatherOptions& /* unused */) { |
| auto& tensor = inputTensors[0]; |
| check_device(tensor.device(), outputTensors[0][0].device()); |
| initComm(tensor.device()); |
| |
| if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) { |
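|     // CPU (or TORCH_UCC_USE_ALLGATHERV=1) path: receive directly into each |
|     // output tensor by leaving the destination buffer null and passing the |
|     // per-rank destination addresses as displacements with byte (UINT8) |
|     // counts. |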
| AllgathervWorkData* data = new AllgathervWorkData(size_); |
| for (int i = 0; i < size_; i++) { |
| data->recv_lengths[i] = tensor.element_size() * tensor.numel(); |
| data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr(); |
| } |
| ucc_coll_args_t coll; |
| coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
| coll.flags = |
| UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; |
| coll.coll_type = UCC_COLL_TYPE_ALLGATHERV; |
| coll.src.info.buffer = tensor.data_ptr(); |
| coll.src.info.count = tensor.element_size() * tensor.numel(); |
| coll.src.info.datatype = UCC_DT_UINT8; |
| coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
| coll.dst.info_v.buffer = nullptr; |
| coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); |
| coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); |
| coll.dst.info_v.datatype = UCC_DT_UINT8; |
| coll.dst.info_v.mem_type = |
| to_ucc_memType(outputTensors[0][0].device().type()); |
| SAVE_TENSORS(inputTensors, data->src); |
| SAVE_TENSORS(outputTensors[0], data->dst); |
| |
| return collective_post( |
| OpType::ALLGATHER, |
| []() {}, |
| []() {}, |
| coll, |
| std::unique_ptr<WorkData>(data), |
| tensor.device(), |
| inputTensors, |
| outputTensors[0], |
| "ucc:all_gather"); |
| } else { |
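|     // Default CUDA path: allgather into contiguous flat staging tensors, |
|     // then copy the results back into the user-provided (possibly |
|     // non-contiguous) output tensors in copy_from_flat. |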
| WorkData* data = new WorkData(); |
| std::vector<at::Tensor> flat_output(outputTensors.size()); |
| for (size_t i = 0; i < outputTensors.size(); i++) { |
| TORCH_CHECK( |
| outputTensors[i].size() == outputTensors.size() * size_, |
| "Tensor output list is not valid for the number of participants"); |
| flat_output[i] = c10d::newLikeFlat(outputTensors, i); |
| } |
| SAVE_TENSORS(flat_output, data->flat); |
| ucc_coll_args_t coll; |
| coll.mask = 0; |
| coll.flags = 0; |
| coll.coll_type = UCC_COLL_TYPE_ALLGATHER; |
| coll.src.info.buffer = tensor.data_ptr(); |
| coll.src.info.count = tensor.numel(); |
| coll.src.info.datatype = to_ucc_dType(tensor); |
| coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
| coll.dst.info.buffer = flat_output[0].data_ptr(); |
| coll.dst.info.count = flat_output[0].numel(); |
| coll.dst.info.datatype = to_ucc_dType(flat_output[0]); |
| coll.dst.info.mem_type = |
| to_ucc_memType(outputTensors[0][0].device().type()); |
| |
| auto copy_from_flat = [&] { |
| bool asyncCopy = false; |
| #ifdef USE_CUDA |
|       bool isCuda = outputTensors[0][0].device().is_cuda(); |
| #endif |
| for (size_t i = 0; i < outputTensors.size(); i++) { |
| auto inumel = inputTensors[i].numel(); |
| for (size_t j = 0; j < outputTensors[i].size(); j++) { |
| TORCH_CHECK( |
| (outputTensors[i][j].numel() == inumel), |
| "Tensor operand counts must be same"); |
| #ifdef USE_CUDA |
| if (isCuda) { |
| c10::cuda::CUDACachingAllocator::recordStream( |
| outputTensors[i][j].storage().data_ptr(), (*stream)); |
| asyncCopy = true; |
| } |
| #endif |
| outputTensors[i][j].copy_(flat_output[i][j], asyncCopy); |
| } |
| } |
| }; |
| return collective_post( |
| OpType::ALLGATHER, |
| []() {}, |
| copy_from_flat, |
| coll, |
| std::unique_ptr<WorkData>(data), |
| tensor.device(), |
| inputTensors, |
| outputTensors[0], |
| "ucc:all_gather"); |
| } |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::_allgather_base( |
| at::Tensor& outputTensor, |
| at::Tensor& inputTensor, |
| const AllgatherOptions& opts) { |
| check_tensor({outputTensor}); |
| check_tensor({inputTensor}); |
| initComm(outputTensor.device()); |
| |
| WorkData* data = new WorkData(); |
| |
| ucc_coll_args_t coll; |
| coll.mask = 0; |
| coll.flags = 0; |
| coll.coll_type = UCC_COLL_TYPE_ALLGATHER; |
| coll.src.info.buffer = inputTensor.data_ptr(); |
| coll.src.info.count = inputTensor.numel(); |
| coll.src.info.datatype = ucc_dtype_map.at(inputTensor.scalar_type()); |
| coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type()); |
| coll.dst.info.buffer = outputTensor.data_ptr(); |
| coll.dst.info.count = outputTensor.numel(); |
| coll.dst.info.datatype = ucc_dtype_map.at(outputTensor.scalar_type()); |
| coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type()); |
| |
| std::vector<at::Tensor> inputTensors = {inputTensor}; |
| std::vector<at::Tensor> outputTensors = {outputTensor}; |
| SAVE_TENSORS(inputTensors, data->src); |
| SAVE_TENSORS(outputTensors, data->dst); |
| |
| return collective_post( |
| OpType::_ALLGATHER_BASE, |
| []() {}, |
| []() {}, |
| coll, |
| std::unique_ptr<WorkData>(data), |
| outputTensor.device(), |
| inputTensors, |
| outputTensors, |
| "ucc:allgather_base"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce( |
| std::vector<at::Tensor>& tensors, |
| const AllreduceOptions& opts) { |
| check_tensor(tensors); |
| auto& tensor = tensors[0]; |
| initComm(tensor.device()); |
| WorkData* data = new WorkData(); |
| |
| ucc_coll_args_t coll; |
| coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
| coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; |
| coll.coll_type = UCC_COLL_TYPE_ALLREDUCE; |
| coll.op = to_ucc_reduceOp(opts.reduceOp, tensor.scalar_type()); |
| coll.src.info.buffer = nullptr; |
| coll.src.info.count = tensor.numel(); |
| coll.src.info.datatype = to_ucc_dType(tensor); |
| coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
| coll.dst.info.buffer = tensor.data_ptr(); |
| coll.dst.info.count = tensor.numel(); |
| coll.dst.info.datatype = to_ucc_dType(tensor); |
| coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); |
| SAVE_TENSORS(tensors, data->dst); |
| return collective_post( |
| OpType::ALLREDUCE, |
| []() {}, |
| []() {}, |
| coll, |
| std::unique_ptr<WorkData>(data), |
| tensor.device(), |
| tensors, |
| tensors, |
| "ucc:all_reduce"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce_coalesced( |
| std::vector<at::Tensor>& /* unused */, |
| const AllreduceCoalescedOptions& /* unused */) { |
| throw std::runtime_error( |
| "ProcessGroupUCC does not support allreduce_coalesced"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall( |
| std::vector<at::Tensor>& outputTensors, |
| std::vector<at::Tensor>& inputTensors, |
| const AllToAllOptions& /* unused */) { |
| auto device = outputTensors[0].device(); |
| for (const auto r : c10::irange(outputTensors.size())) { |
| TORCH_CHECK( |
| device == outputTensors[r].device() && |
| device == inputTensors[r].device(), |
|         "Tensors must be on the same device"); |
| } |
| |
| initComm(device); |
| ucc_coll_args_t coll; |
| AlltoallWorkData* data; |
| data = new AlltoallWorkData(size_); |
| |
|   /* To avoid flattening the tensors, we use alltoallv to implement Alltoall |
|      as follows: |
|       1. store the address of each tensor directly in the displacements and |
|          keep the buffer at nullptr, i.e., 0 |
|       2. convert the datatype to UINT8, which is always 1 byte, to avoid |
|          wrong size calculation in the UCC layer |
|       3. post the Alltoallv |
|   */ |
| for (const auto i : c10::irange(size_)) { |
| data->send_lengths[i] = |
| (uint64_t)(inputTensors[i].element_size() * inputTensors[i].numel()); |
| data->send_offsets[i] = (uint64_t)inputTensors[i].data_ptr(); |
| data->recv_lengths[i] = |
| (uint64_t)(outputTensors[i].element_size() * outputTensors[i].numel()); |
| data->recv_offsets[i] = (uint64_t)outputTensors[i].data_ptr(); |
| } |
| |
| coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
| coll.flags = |
| UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; |
| coll.coll_type = UCC_COLL_TYPE_ALLTOALLV; |
| coll.src.info_v.buffer = 0; |
| coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); |
| coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); |
| coll.src.info_v.datatype = UCC_DT_UINT8; |
| coll.src.info_v.mem_type = to_ucc_memType(inputTensors[0].device().type()); |
| coll.dst.info_v.buffer = 0; |
| coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); |
| coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); |
| coll.dst.info_v.datatype = UCC_DT_UINT8; |
| coll.dst.info_v.mem_type = to_ucc_memType(outputTensors[0].device().type()); |
| |
| SAVE_TENSORS(inputTensors, data->src); |
| SAVE_TENSORS(outputTensors, data->dst); |
| |
| return collective_post( |
| OpType::ALLTOALL, |
| []() {}, |
| []() {}, |
| coll, |
| std::unique_ptr<WorkData>(data), |
| device, |
| inputTensors, |
| outputTensors, |
| "ucc:alltoall"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall_base( |
| at::Tensor& outputTensor, |
| at::Tensor& inputTensor, |
| std::vector<int64_t>& outputSplitSizes, |
| std::vector<int64_t>& inputSplitSizes, |
| const AllToAllOptions& /* unused */) { |
| check_device(inputTensor.device(), outputTensor.device()); |
| initComm(inputTensor.device()); |
| ucc_coll_args_t coll; |
| AlltoallWorkData* data; |
| |
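|   // Empty split lists mean an equal split and map to a plain byte-count |
|   // alltoall; explicit split sizes map to alltoallv with per-rank lengths |
|   // and offsets computed from the splits. |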
| if ((outputSplitSizes.size() == 0) && (inputSplitSizes.size() == 0)) { |
| data = new AlltoallWorkData(0); |
| TORCH_CHECK( |
| (outputTensor.size(0) % size_ == 0) && |
| (inputTensor.size(0) % size_ == 0), |
| "Tensor's dim 0 does not divide equally across group size"); |
| coll.mask = 0; |
| coll.flags = 0; |
| coll.coll_type = UCC_COLL_TYPE_ALLTOALL; |
| coll.src.info.buffer = inputTensor.data_ptr(); |
| coll.src.info.count = inputTensor.element_size() * inputTensor.numel(); |
| coll.src.info.datatype = UCC_DT_UINT8; |
| coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type()); |
| coll.dst.info.buffer = outputTensor.data_ptr(); |
| coll.dst.info.count = outputTensor.element_size() * outputTensor.numel(); |
| coll.dst.info.datatype = UCC_DT_UINT8; |
| coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type()); |
| } else { |
| data = new AlltoallWorkData(size_); |
| c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); |
| c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); |
| computeLengthsAndOffsets( |
| outputSplitSizes, |
| outputTensor, |
| &data->recv_lengths, |
| &data->recv_offsets); |
| computeLengthsAndOffsets( |
| inputSplitSizes, inputTensor, &data->send_lengths, &data->send_offsets); |
| coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
| coll.coll_type = UCC_COLL_TYPE_ALLTOALLV; |
| coll.src.info_v.buffer = inputTensor.data_ptr(); |
| coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); |
| coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); |
| coll.src.info_v.datatype = to_ucc_dType(inputTensor); |
| coll.src.info_v.mem_type = to_ucc_memType(inputTensor.device().type()); |
| coll.dst.info_v.buffer = outputTensor.data_ptr(); |
| coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); |
| coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); |
| coll.dst.info_v.datatype = to_ucc_dType(outputTensor); |
| coll.dst.info_v.mem_type = to_ucc_memType(outputTensor.device().type()); |
| coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER | |
| UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER | UCC_COLL_ARGS_FLAG_COUNT_64BIT | |
| UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; |
| |
| if (torch_ucc_config.enable_comms_logger) { |
| logger->trace_generator->recordOptionalInfo( |
| outputSplitSizes, inputSplitSizes); |
| } |
| } |
| std::vector<at::Tensor> inputTensors = {inputTensor}; |
| std::vector<at::Tensor> outputTensors = {outputTensor}; |
| SAVE_TENSORS(inputTensors, data->src); |
| SAVE_TENSORS(outputTensors, data->dst); |
| |
| return collective_post( |
| OpType::ALLTOALL_BASE, |
| []() {}, |
| []() {}, |
| coll, |
| std::unique_ptr<WorkData>(data), |
| inputTensor.device(), |
| inputTensors, |
| outputTensors, |
| "ucc:alltoall"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::barrier(const BarrierOptions& opts) { |
| c10::Device device = c10::Device(c10::DeviceType::CPU); |
| #ifdef USE_CUDA |
| auto numGPUs = c10::cuda::device_count(); |
| if (!opts.device_ids.empty()) { |
| device = c10::Device(c10::DeviceType::CUDA, opts.device_ids.front()); |
| } else if (comm && comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) { |
| device = c10::Device(c10::DeviceType::CUDA, comm->cuda_device_index); |
| } else if (numGPUs > 0) { |
| int8_t deviceIdx = static_cast<int8_t>(c10::cuda::current_device()); |
| // if current device is 0, likely the device is not set, use the best guess |
| if (0 == (int)deviceIdx) { |
| deviceIdx = static_cast<int8_t>(this->getRank() % numGPUs); |
| } |
| TORCH_UCC_LOG_INFO( |
| TORCH_UCC_COLL_POST, |
| c10::str( |
| "post barrier before specifying any GPU while there are ", |
| numGPUs, |
| " GPUs available. ", |
| "Not clear if GPU barrier is required, using GPU ", |
| (int)deviceIdx, |
| " to perform barrier. ", |
| "Specify device_ids option in barrier() to force ", |
| "use of a particular device")); |
| device = c10::Device(c10::DeviceType::CUDA, deviceIdx); |
| } |
| #endif |
| initComm(device); |
| |
| ucc_coll_args_t coll; |
| coll.mask = 0; |
| coll.flags = 0; |
| coll.coll_type = UCC_COLL_TYPE_BARRIER; |
| auto dummy_tensor = std::vector<at::Tensor>(); |
| return collective_post( |
| OpType::BARRIER, |
| []() {}, |
| []() {}, |
| coll, |
| nullptr, |
| device, |
| dummy_tensor, |
| dummy_tensor, |
| "ucc:barrier"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::broadcast( |
| std::vector<at::Tensor>& tensors, |
| const BroadcastOptions& opts) { |
| check_tensor(tensors); |
| auto& tensor = tensors[0]; |
| initComm(tensor.device()); |
| WorkData* data = new WorkData(); |
| |
| ucc_coll_args_t coll; |
| coll.mask = 0; |
| coll.flags = 0; |
| coll.coll_type = UCC_COLL_TYPE_BCAST; |
| coll.src.info.buffer = tensor.data_ptr(); |
| coll.src.info.count = tensor.numel(); |
| coll.src.info.datatype = to_ucc_dType(tensor); |
| coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
| coll.root = opts.rootRank; |
| SAVE_TENSORS(tensors, data->dst); |
| |
| if (torch_ucc_config.enable_comms_logger) { |
| logger->trace_generator->recordOptionalInfo(opts.rootRank); |
| } |
| |
| return collective_post( |
| OpType::BROADCAST, |
| []() {}, |
| []() {}, |
| coll, |
| std::unique_ptr<WorkData>(data), |
| tensor.device(), |
| tensors, |
| tensors, |
| "ucc:broadcast"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::gather( |
| std::vector<std::vector<at::Tensor>>& outputTensors, |
| std::vector<at::Tensor>& inputTensors, |
| const GatherOptions& opts) { |
| std::vector<at::Tensor> outputs; |
| auto& input = inputTensors[0]; |
| initComm(input.device()); |
| |
| AllgathervWorkData* data = new AllgathervWorkData(size_); |
| ucc_coll_args_t coll; |
| coll.root = opts.rootRank; |
| coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
| coll.flags = |
| UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; |
| coll.coll_type = UCC_COLL_TYPE_GATHERV; |
| |
| /* for non-root ranks, only src is valid */ |
| coll.src.info.buffer = input.data_ptr(); |
| coll.src.info.count = (uint64_t)(input.element_size() * input.numel()); |
| coll.src.info.datatype = UCC_DT_UINT8; |
| coll.src.info.mem_type = to_ucc_memType(input.device().type()); |
| |
| if (getRank() == opts.rootRank) { |
| if (outputTensors.size() != 1) { |
| TORCH_UCC_LOG_ERROR( |
| TORCH_UCC_COLL_POST, |
| c10::str( |
| "gather requires a single-element output list containing a list with ", |
| getSize(), |
| " tensors.")); |
| } else if (outputTensors[0].size() != static_cast<size_t>(getSize())) { |
| TORCH_UCC_LOG_ERROR( |
| TORCH_UCC_COLL_POST, |
| c10::str( |
| "Incorrect output list size ", |
| outputTensors[0].size(), |
| ". Output list size should be ", |
| getSize(), |
| ", same as size of the process group.")); |
| } |
| outputs = outputTensors[0]; |
| |
| for (int i = 0; i < size_; i++) { |
| data->recv_lengths[i] = |
| (uint64_t)(outputs[i].element_size() * outputs[i].numel()); |
| data->recv_offsets[i] = (uint64_t)outputs[i].data_ptr(); |
| } |
|     /* use gatherv and store the non-contiguous addresses in displacements to |
|      * avoid flattening outputTensors */ |
| coll.dst.info_v.buffer = nullptr; |
| coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); |
| coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); |
| coll.dst.info_v.datatype = UCC_DT_UINT8; |
| coll.dst.info_v.mem_type = to_ucc_memType(outputs[0].device().type()); |
| |
| SAVE_TENSORS(outputs, data->dst); |
| } else { |
| // for non-root ranks, outputTensors should be an empty list |
| if (outputTensors.size() != 0) { |
| TORCH_UCC_LOG_ERROR( |
| TORCH_UCC_COLL_POST, "requires empty output on non-root"); |
| } |
| outputs = {}; |
|     // append an empty tensor to the list, to be used by the future mark |
| outputs.emplace_back(); |
| } |
| |
| SAVE_TENSORS(inputTensors, data->src); |
| |
| return collective_post( |
| OpType::GATHER, |
| []() {}, |
| []() {}, |
| coll, |
| std::unique_ptr<WorkData>(data), |
| input.device(), |
| inputTensors, |
| outputs, |
| "ucc:gather"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::reduce( |
| std::vector<at::Tensor>& tensors, |
| const ReduceOptions& opts) { |
| check_tensor(tensors); |
| auto& tensor = tensors[0]; |
| initComm(tensor.device()); |
| WorkData* data = new WorkData(); |
| |
| ucc_coll_args_t coll; |
| coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
| coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; |
| coll.coll_type = UCC_COLL_TYPE_REDUCE; |
| coll.op = ucc_op_map.at(opts.reduceOp); |
| coll.root = opts.rootRank; |
| coll.src.info.buffer = tensor.data_ptr(); |
| coll.src.info.count = tensor.numel(); |
| coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type()); |
| coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
| coll.dst.info.buffer = tensor.data_ptr(); |
| coll.dst.info.count = tensor.numel(); |
| coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type()); |
| coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); |
| SAVE_TENSORS(tensors, data->dst); |
| return collective_post( |
| OpType::REDUCE, |
| []() {}, |
| []() {}, |
| coll, |
| std::unique_ptr<WorkData>(data), |
| tensor.device(), |
| tensors, |
| tensors, |
| "ucc:reduce"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::reduce_scatter( |
| std::vector<at::Tensor>& outputTensors, |
| std::vector<std::vector<at::Tensor>>& inputTensors, |
| const ReduceScatterOptions& opts) { |
| TORCH_CHECK( |
| (outputTensors.size() == inputTensors.size()), |
| "Tensor input/output list for reduce_scatter must have same size"); |
| check_tensor(outputTensors); |
| check_device(inputTensors[0][0].device(), outputTensors[0].device()); |
| initComm(inputTensors[0][0].device()); |
| auto data = std::make_unique<WorkData>(); |
| std::vector<at::Tensor> flat_input(inputTensors.size()); |
| for (size_t i = 0; i < inputTensors.size(); i++) { |
| TORCH_CHECK( |
| inputTensors[i].size() == inputTensors.size() * size_, |
| "Tensor input list is not valid for the number of participants"); |
| flat_input[i] = c10d::newLikeFlat(inputTensors, i); |
| } |
| SAVE_TENSORS(flat_input, data->flat); |
| check_tensor(flat_input); |
| ucc_coll_args_t coll; |
| coll.mask = 0; |
| coll.flags = 0; |
| coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER; |
| coll.op = to_ucc_reduceOp(opts.reduceOp, flat_input[0].scalar_type()); |
| |
| coll.src.info.buffer = flat_input[0].data_ptr(); |
| coll.src.info.count = flat_input[0].numel(); |
| coll.src.info.datatype = to_ucc_dType(flat_input[0]); |
| coll.src.info.mem_type = to_ucc_memType(flat_input[0].device().type()); |
| coll.dst.info.buffer = outputTensors[0].data_ptr(); |
| coll.dst.info.count = outputTensors[0].numel(); |
| coll.dst.info.datatype = to_ucc_dType(outputTensors[0]); |
| coll.dst.info.mem_type = to_ucc_memType(outputTensors[0].device().type()); |
| |
| SAVE_TENSORS(inputTensors[0], data->src); |
| SAVE_TENSORS(outputTensors, data->dst); |
| |
| auto copy_to_flat = [&] { |
| bool asyncCopy = false; |
| auto isize = inputTensors.size(); |
| #ifdef USE_CUDA |
| bool isCuda = inputTensors[0][0].device().is_cuda(); |
| #endif |
| for (size_t i = 0; i < isize; i++) { |
| auto onumel = outputTensors[i].numel(); |
| for (size_t j = 0; j < inputTensors[i].size(); j++) { |
| TORCH_CHECK( |
| (inputTensors[i][j].numel() == onumel), |
| "Tensor operand counts must be same"); |
| #ifdef USE_CUDA |
| if (isCuda) { |
| c10::cuda::CUDACachingAllocator::recordStream( |
| inputTensors[i][j].storage().data_ptr(), (*stream)); |
| asyncCopy = true; |
| } |
| #endif |
| flat_input[i][j].copy_(inputTensors[i][j], asyncCopy); |
| } |
| } |
| }; |
| |
| return collective_post( |
| OpType::REDUCE_SCATTER, |
| copy_to_flat, |
| []() {}, |
| coll, |
| std::move(data), |
| inputTensors[0][0].device(), |
| inputTensors[0], |
| outputTensors, |
| "ucc:reduce_scatter"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::scatter( |
| std::vector<at::Tensor>& outputTensors, |
| std::vector<std::vector<at::Tensor>>& inputTensors, |
| const ScatterOptions& opts) { |
| auto& tensor = outputTensors[0]; |
| initComm(tensor.device()); |
| |
| ScattervWorkData* data = new ScattervWorkData(size_); |
| ucc_coll_args_t coll; |
| coll.root = opts.rootRank; |
| coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
| coll.flags = |
| UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; |
| coll.coll_type = UCC_COLL_TYPE_SCATTERV; |
| |
| if (getRank() == opts.rootRank) { |
|     /* src is only valid at the root rank */ |
| if (inputTensors.size() != 1) { |
| TORCH_UCC_LOG_ERROR( |
| TORCH_UCC_COLL_POST, |
| c10::str( |
|               "scatter requires a single-element input list containing a list with ", |
| getSize(), |
| " tensors.")); |
| } else if (inputTensors[0].size() != static_cast<size_t>(getSize())) { |
| TORCH_UCC_LOG_ERROR( |
| TORCH_UCC_COLL_POST, |
| c10::str( |
|               "Incorrect input list size ", |
|               inputTensors[0].size(), |
|               ". Input list size should be ", |
| getSize(), |
| ", same as size of the process group.")); |
| } |
| |
| for (int i = 0; i < size_; i++) { |
| data->send_lengths[i] = (uint64_t)tensor.element_size() * tensor.numel(); |
| data->send_offsets[i] = (uint64_t)inputTensors[0][i].data_ptr(); |
| } |
|     /* use scatterv and store the non-contiguous addresses in displacements |
|      * to avoid flattening inputTensors */ |
| coll.src.info_v.buffer = nullptr; |
| coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); |
| coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); |
| coll.src.info_v.datatype = UCC_DT_UINT8; |
| coll.src.info_v.mem_type = |
| to_ucc_memType(inputTensors[0][0].device().type()); |
| |
| SAVE_TENSORS(inputTensors[0], data->src); |
| } else { |
| // for non-root ranks, inputTensors should be an empty list |
| if (inputTensors.size() != 0) { |
| TORCH_UCC_LOG_ERROR( |
|           TORCH_UCC_COLL_POST, "requires empty input on non-root"); |
| } |
| } |
| |
| coll.dst.info.buffer = tensor.data_ptr(); |
| coll.dst.info.count = (uint64_t)tensor.element_size() * tensor.numel(); |
| coll.dst.info.datatype = UCC_DT_UINT8; |
| coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); |
| SAVE_TENSORS(outputTensors, data->dst); |
| |
| return collective_post( |
| OpType::SCATTER, |
| []() {}, |
| []() {}, |
| coll, |
| std::unique_ptr<WorkData>(data), |
| tensor.device(), |
| inputTensors[0], |
| outputTensors, |
| "ucc:scatter"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::send( |
| std::vector<at::Tensor>& tensors, |
| int dstRank, |
| int tag) { |
| check_tensor(tensors); |
| auto& tensor = tensors[0]; |
| initComm(tensor.device()); |
| |
| WorkData* data = new WorkData(); |
| ucc_coll_args_t coll; |
| coll.tag = tag; |
| coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG; |
| coll.flags = 0; |
| coll.coll_type = UCC_COLL_TYPE_BCAST; |
| coll.src.info.buffer = tensor.data_ptr(); |
| coll.src.info.count = tensor.numel(); |
| coll.src.info.datatype = to_ucc_dType(tensor); |
| coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
| coll.root = getRank(); |
| |
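|   // Point-to-point send is expressed as a two-rank broadcast rooted at this |
|   // rank, restricted to {sender, receiver} via the active-set fields below. |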
| coll.active_set.size = 2; |
| coll.active_set.start = getRank(); |
| coll.active_set.stride = dstRank - getRank(); |
| SAVE_TENSORS(tensors, data->dst); |
| |
| return collective_post( |
| OpType::SEND, |
| []() {}, |
| []() {}, |
| coll, |
| std::unique_ptr<WorkData>(data), |
| tensor.device(), |
| tensors, |
| tensors, |
| "ucc:send"); |
| } |
| |
| c10::intrusive_ptr<Work> ProcessGroupUCC::recv( |
| std::vector<at::Tensor>& tensors, |
| int srcRank, |
| int tag) { |
| check_tensor(tensors); |
| auto& tensor = tensors[0]; |
| initComm(tensor.device()); |
| |
| WorkData* data = new WorkData(); |
| ucc_coll_args_t coll; |
| coll.tag = tag; |
| coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG; |
| coll.flags = 0; |
| coll.coll_type = UCC_COLL_TYPE_BCAST; |
| coll.src.info.buffer = tensor.data_ptr(); |
| coll.src.info.count = tensor.numel(); |
| coll.src.info.datatype = to_ucc_dType(tensor); |
| coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
| coll.root = srcRank; |
| |
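|   // Mirror of send(): a two-rank broadcast rooted at the sender (srcRank), |
|   // with this rank as the other member of the active set. |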
| coll.active_set.size = 2; |
| coll.active_set.start = srcRank; |
| coll.active_set.stride = getRank() - srcRank; |
| SAVE_TENSORS(tensors, data->dst); |
| |
| return collective_post( |
| OpType::RECV, |
| []() {}, |
| []() {}, |
| coll, |
| std::unique_ptr<WorkData>(data), |
| tensor.device(), |
| tensors, |
| tensors, |
| "ucc:recv"); |
| } |
| |
| void ProcessGroupUCC::setSequenceNumberForGroup() {} |
| |
| uint64_t ProcessGroupUCC::getSequenceNumberForGroup() { |
| return seq_; |
| } |
| |
| c10::intrusive_ptr<Backend> ProcessGroupUCC::createProcessGroupUCC( |
| const c10::intrusive_ptr<::c10d::Store>& store, |
| int rank, |
| int size, |
| const std::chrono::duration<float>& timeout) { |
| return c10::make_intrusive<ProcessGroupUCC>(store, rank, size, timeout); |
| } |
| |
| void ProcessGroupUCC::initComm(c10::Device dev) { |
| if (!comm) { |
| #ifdef USE_CUDA |
| if (dev.is_cuda()) { |
| c10::cuda::set_device(dev.index()); |
| } |
| #endif |
| comm = Comm::get_comm(comm_id, dev, oob, logger); |
|     TORCH_UCC_LOG_INFO( |
|         TORCH_UCC_INIT, "Successfully initialized UCC communicator"); |
| comm->ucc_create_team(team, oob); |
| TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library"); |
| logger->setPhase(TORCH_UCC_READY); |
| } else { |
| if (dev.is_cuda()) { |
| if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) && |
| (comm->cuda_device_index != dev.index())) { |
| TORCH_UCC_LOG_ERROR( |
| TORCH_UCC_INIT, |
|             "ucc communicator was initialized with a different cuda device, " |
|             "multi device is not supported"); |
| throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); |
| } |
| comm->cuda_device_index = dev.index(); |
| } |
| } |
| #ifdef USE_CUDA |
| // Create UCC execution engine. |
| if (!cuda_ee && dev.is_cuda()) { |
| stream = std::make_unique<at::cuda::CUDAStream>( |
| at::cuda::getStreamFromPool(true, dev.index())); |
| ucc_ee_params_t params; |
| params.ee_type = UCC_EE_CUDA_STREAM; |
| params.ee_context = (void*)stream->stream(); |
| params.ee_context_size = sizeof(cudaStream_t); |
| TORCH_UCC_CHECK( |
| ucc_ee_create(team, ¶ms, &cuda_ee), |
| "failed to create UCC execution engine"); |
| } |
| #endif |
| } |
| |
| } // namespace c10d |
| |
| #endif // USE_C10D_UCC |