blob: 36034114d7516bd649c53a33986b549ae147f542 [file] [log] [blame]
#pragma once
#ifdef USE_C10D_NCCL
#include <stdio.h>
#include <stdlib.h>
#include <memory>
#include <mutex>
#include <nccl.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
namespace {
// Provides additional detail into NCCL error codes based on when these are
// thrown in the NCCL codebase.
std::string getNcclErrorDetailStr(ncclResult_t error, c10::optional<std::string> processGroupFailureReason = c10::nullopt) {
// Prioritize failure reason provided by PG NCCL first, as it can abort
// communicators when it encounters collective timeouts, etc.
if (processGroupFailureReason != c10::nullopt) {
return (*processGroupFailureReason).c_str();
}
std::string interpret;
std::string err;
#ifdef ENABLE_NCCL_GET_LAST_ERROR
err = "\nLast error:\n" + std::string(ncclGetLastError(NULL));
#endif
switch (error) {
case ncclUnhandledCudaError:
interpret = "ncclUnhandledCudaError: Call to CUDA function failed.";
break;
case ncclSystemError:
interpret = "ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. "
"It can be also caused by unexpected exit of a remote peer.";
break;
case ncclInternalError:
interpret = "ncclInternalError: Internal check failed.";
break;
case ncclInvalidArgument:
interpret = "ncclInvalidArgument: Invalid value for an argument.";
break;
case ncclInvalidUsage:
interpret = "ncclInvalidUsage: This usually reflects invalid usage of NCCL library.";
break;
default:
interpret = "Unknown NCCL error!";
}
return interpret + err;
}
} // namespace
// ncclGetLastError() is enabled only for NCCL versions 2.13+
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
(NCCL_MINOR >= 13)
#define ENABLE_NCCL_GET_LAST_ERROR
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
#define ENABLE_NCCL_GET_LAST_ERROR
#endif
// Each gate below is defined when the NCCL headers report version
// (major 2, minor >= the stated minimum) or any major version >= 3.

// Error checking is enabled only for NCCL versions 2.4+ since ncclCommAbort()
// and ncclCommGetAsyncError() are not supported in earlier versions.
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR >= 3) ||  \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 4)))
#define ENABLE_NCCL_ERROR_CHECKING
#endif

// P2P is enabled only for NCCL versions 2.7+ since ncclSend()
// and ncclRecv() are not supported in earlier versions.
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR >= 3) ||  \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 7)))
#define ENABLE_NCCL_P2P_SUPPORT
#endif

// Premultiplied-sum support is enabled only for NCCL versions 2.11+.
// ncclRedOpRAII calls ncclRedOpDestroy() only when this is defined.
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR >= 3) ||  \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 11)))
#define ENABLE_NCCL_PREMUL_SUM_SUPPORT
#endif
// Macro to throw on a non-successful NCCL return value.
// `cmd` is evaluated once. When its result is not ncclSuccess, an error string
// is assembled from the call site (__FILE__, __LINE__), the text returned by
// ncclGetErrorWithVersion(), and the detail from getNcclErrorDetailStr(result,
// failureReason); it is then raised via TORCH_CHECK(false, ...).
#define C10D_NCCL_CHECK(cmd, failureReason)                \
  do {                                                     \
    ncclResult_t result = cmd;                             \
    if (result != ncclSuccess) {                           \
      std::string err = "NCCL error in: ";                 \
      err += __FILE__;                                     \
      err += ":";                                          \
      err += std::to_string(__LINE__);                     \
      err += ", ";                                         \
      err += ncclGetErrorWithVersion(result);              \
      err += "\n";                                         \
      err += getNcclErrorDetailStr(result, failureReason); \
      TORCH_CHECK(false, err);                             \
    }                                                      \
  } while (0)
// Macro to print and abort on a non-successful NCCL return value.
// `cmd` is evaluated once. On failure the call site and the text from
// ncclGetErrorWithVersion() are written to stderr and the process is
// terminated with abort(); nothing is thrown.
#define C10D_NCCL_ASSERT(cmd)                         \
  do {                                                \
    const ncclResult_t result = cmd;                  \
    if (result != ncclSuccess) {                      \
      fprintf(                                        \
          stderr,                                     \
          "NCCL error in: %s:%d, %s\n",               \
          __FILE__,                                   \
          __LINE__,                                   \
          ncclGetErrorWithVersion(result).c_str());   \
      abort();                                        \
    }                                                 \
  } while (0)
namespace c10d {
// Returns the NCCL version as a string. Defined out of line; the exact format
// is not visible from this header.
std::string getNcclVersion();
// Returns a textual description of `error`. The C10D_NCCL_CHECK and
// C10D_NCCL_ASSERT macros embed this text in the messages they produce.
// NOTE(review): presumably the text also carries the NCCL version, per the
// name -- confirm against the definition.
std::string ncclGetErrorWithVersion(ncclResult_t error);
// RAII wrapper for NCCL communicator
class NCCLComm {
public:
explicit NCCLComm(ncclComm_t ncclComm)
: ncclComm_(ncclComm),
aborted_(false),
ncclAsyncErr_(ncclSuccess),
commFailureReason_(c10::nullopt) {}
NCCLComm() : NCCLComm(nullptr) {}
~NCCLComm() noexcept {
// Add lock in this destructor, as aborted_ needs to be read after memory
// barrier here.
std::unique_lock<std::mutex> lock(mutex_);
if (ncclComm_ && !aborted_) {
#ifdef ENABLE_NCCL_ERROR_CHECKING
// Use ncclCommAbort instead of ncclCommDestroy here since
// ncclCommDestroy could block forever waiting for work to complete on
// the communicator.
C10D_NCCL_ASSERT(::ncclCommAbort(ncclComm_));
#else
C10D_NCCL_ASSERT(::ncclCommDestroy(ncclComm_));
#endif
}
}
static std::shared_ptr<NCCLComm> create(
int numRanks,
int rank,
ncclUniqueId commId) {
auto comm = std::make_shared<NCCLComm>();
C10D_NCCL_CHECK(
ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
comm->ncclId_ = commId;
comm->rank_ = rank;
return comm;
}
ncclUniqueId getNcclId() {
return ncclId_;
}
// Must not be copyable
NCCLComm(const NCCLComm&) = delete;
NCCLComm& operator=(const NCCLComm&) = delete;
// Do not support move assignment as there is no valid use case
NCCLComm& operator=(NCCLComm&& other) = delete;
// Move constructable
NCCLComm(NCCLComm&& other) {
// Using other's lock, as it reads other's states
// Can not use this.mutex_, as this object is being constructed.
std::unique_lock<std::mutex> lock(other.mutex_);
std::swap(ncclComm_, other.ncclComm_);
std::swap(aborted_, other.aborted_);
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
}
ncclComm_t getNcclComm();
c10::optional<std::string> getNcclCommFailureReason() const {
std::unique_lock<std::mutex> lock(mutex_);
return commFailureReason_;
}
void ncclCommAbort(
c10::optional<std::string> commFailureReason = c10::nullopt) {
std::unique_lock<std::mutex> lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (aborted_) {
// Should not abort twice.
return;
}
// Set true failure reason if provided by ProcessGroupNCCL (e.g. work
// timeout)
commFailureReason_ = commFailureReason;
C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
aborted_ = true;
ncclComm_ = nullptr;
// Set an appropriate error so that we avoid using the communicator.
if (ncclAsyncErr_ == ncclSuccess) {
ncclAsyncErr_ = ncclSystemError;
}
#else
// This is a NOOP, if error checks are disabled.
return;
#endif
}
bool isAborted() const {
std::unique_lock<std::mutex> lock(mutex_);
return aborted_;
}
ncclResult_t checkForNcclError() {
std::unique_lock<std::mutex> lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
if (ncclAsyncErr_ != ncclSuccess) {
return ncclAsyncErr_;
}
C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
return ncclAsyncErr_;
#else
// Always return success, if error checks are disabled.
return ncclSuccess;
#endif
}
protected:
ncclComm_t ncclComm_;
// Unique nccl_id for this communicator.
ncclUniqueId ncclId_;
bool aborted_;
ncclResult_t ncclAsyncErr_;
mutable std::mutex mutex_;
// Rank that this communicator corresponds to.
int rank_;
// Optional reason for communicator failure, provided by ProcessGroupNCCL for
// better error messaging.
c10::optional<std::string> commFailureReason_;
};
// Helper that automatically cleans up premul sums.
//
// Holds a reduction op and, for ops constructed with a communicator, destroys
// the op with ncclRedOpDestroy() in the destructor (only when
// ENABLE_NCCL_PREMUL_SUM_SUPPORT is defined). Converts implicitly to and from
// ncclRedOp_t. Move-only.
struct ncclRedOpRAII {
  // Empty holder: op_ is ncclSum, no communicator, nothing to destroy.
  ncclRedOpRAII() = default;

  // Wraps an op that needs no cleanup. Intentionally not explicit so a plain
  // ncclRedOp_t converts to this type.
  ncclRedOpRAII(ncclRedOp_t op) : op_(op) {}

  // Wraps an op tied to `comm`; it is marked for destruction in the
  // destructor.
  ncclRedOpRAII(ncclRedOp_t op, ncclComm_t comm)
      : op_(op), comm_(comm), premul_sum_(true) {}

  ncclRedOpRAII(const ncclRedOpRAII&) = delete;
  ncclRedOpRAII& operator=(const ncclRedOpRAII&) = delete;

  // Takes over tmp's state and leaves tmp as an empty holder, so the op is
  // destroyed at most once. The members carry default initializers, which
  // keeps these swaps from reading indeterminate values.
  ncclRedOpRAII(ncclRedOpRAII&& tmp) noexcept : ncclRedOpRAII() {
    std::swap(tmp.op_, this->op_);
    std::swap(tmp.comm_, this->comm_);
    std::swap(tmp.premul_sum_, this->premul_sum_);
  }

#if defined(ENABLE_NCCL_PREMUL_SUM_SUPPORT)
  ~ncclRedOpRAII() {
    if (premul_sum_) {
      // The return value is ignored here, as in the original code.
      ncclRedOpDestroy(op_, comm_);
    }
  }
#endif

  operator ncclRedOp_t() const {
    return op_;
  }

  ncclRedOp_t op_ = ncclSum;
  ncclComm_t comm_ = nullptr;
  bool premul_sum_ = false;
};
} // namespace c10d
#endif // USE_C10D_NCCL