blob: fbceb81cf93222a5e39bfec3be825685360d92b6 [file] [log] [blame]
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <c10/util/CallOnce.h>
#include <c10/util/env.h>
#ifdef USE_C10D_NCCL
#include <mutex>
namespace c10d {
// Returns the wrapped NCCL communicator handle while holding mutex_.
// If the communicator has been marked as aborted, raises via TORCH_CHECK
// instead; the message names the rank and, when one was recorded, the
// original failure reason.
ncclComm_t NCCLComm::getNcclComm() {
  std::lock_guard<std::mutex> guard(mutex_);
  if (aborted_) {
    // Only mention the original failure when a reason was stored.
    std::string failureSuffix;
    if (commFailureReason_.has_value()) {
      failureSuffix =
          c10::str(" Original reason for failure was: ", *commFailureReason_);
    }
    TORCH_CHECK(
        false,
        c10::str(
            "NCCL communicator was aborted on rank ",
            rank_,
            ". ",
            failureSuffix));
  }
  return ncclComm_;
}
std::string getNcclVersion() {
static c10::once_flag ncclGetVersionFlag;
static std::string versionString;
c10::call_once(ncclGetVersionFlag, []() {
int version;
ncclResult_t status = ncclGetVersion(&version);
// can't compute the version if call did not return successfully or version
// code < 100 (corresponding to 0.1.0)
if (status != ncclSuccess || version < 100) {
versionString = "Unknown NCCL version";
} else {
// NCCL changed version coding starting 2.9
const int majorBase = version < 2900 ? 1000 : 10000;
const int minorBase = 100;
auto ncclMajor = version / majorBase;
auto ncclMinor = (version % majorBase) / minorBase;
auto ncclPatch =
version % (ncclMajor * majorBase + ncclMinor * minorBase);
versionString = std::to_string(ncclMajor) + "." +
std::to_string(ncclMinor) + "." + std::to_string(ncclPatch);
}
});
return versionString;
}
bool nccl_use_nonblocking() {
static bool nccl_use_nonblocking_ =
c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;
if (nccl_use_nonblocking_) {
TORCH_WARN("Using experimental non-blocking NCCL communicator.");
}
return nccl_use_nonblocking_;
}
// Reads TORCH_NCCL_NONBLOCKING_TIMEOUT from the environment and returns it as
// an int. Returns -1 when the variable is unset. A positive value is also
// replaced by -1 (with a warning) when non-blocking mode is not enabled.
// Parsing uses std::stoi, so a malformed value propagates std::stoi's
// exception.
int _parse_nccl_nonblocking_timeout() {
  const char* rawValue = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT");
  if (rawValue == nullptr) {
    return -1;
  }

  int parsed = std::stoi(std::string(rawValue));

  // Query the mode after parsing so the call order matches a plain
  // parse-then-validate flow.
  const bool nonblockingEnabled = nccl_use_nonblocking();
  if (!nonblockingEnabled && parsed > 0) {
    TORCH_WARN(
        "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false.");
    parsed = -1;
  }
  return parsed;
}
// Returns the non-blocking timeout setting. The environment is parsed on the
// first call only; the result is kept in a function-local static and reused.
int nccl_nonblocking_timeout() {
  static const int cachedTimeout = _parse_nccl_nonblocking_timeout();
  return cachedTimeout;
}
// Formats an NCCL result code as "<error string>, NCCL version <version>",
// combining ncclGetErrorString() with getNcclVersion().
std::string ncclGetErrorWithVersion(ncclResult_t error) {
  std::string message(ncclGetErrorString(error));
  message += ", NCCL version ";
  message += getNcclVersion();
  return message;
}
// Produces a more descriptive explanation for an NCCL error code, based on
// where such codes are raised in the NCCL codebase.
//
// If processGroupFailureReason holds a value it is returned verbatim.
// Otherwise the result is a per-code explanation, followed (only when
// ENABLE_NCCL_GET_LAST_ERROR is defined) by the text from ncclGetLastError.
std::string getNcclErrorDetailStr(
    ncclResult_t error,
    c10::optional<std::string> processGroupFailureReason /* = c10::nullopt */
) {
  // A reason supplied by PG NCCL takes priority, since it may itself have
  // aborted communicators (e.g. after a collective timeout).
  if (processGroupFailureReason.has_value()) {
    return *processGroupFailureReason;
  }

  std::string lastError;
#ifdef ENABLE_NCCL_GET_LAST_ERROR
  lastError = "\nLast error:\n" + std::string(ncclGetLastError(nullptr));
#endif

  std::string explanation;
  switch (error) {
    case ncclUnhandledCudaError:
      explanation = "ncclUnhandledCudaError: Call to CUDA function failed.";
      break;
    case ncclSystemError:
      explanation =
          "ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. ";
#ifndef NCCL_REMOTE_ERROR
      // Prior to the introduction of ncclRemoteError, an unexpected remote
      // disconnect was reported as ncclSystemError.
      explanation +=
          "It can be also caused by unexpected exit of a remote peer.";
#endif
      break;
    case ncclInternalError:
      explanation = "ncclInternalError: Internal check failed.";
      break;
    case ncclInvalidArgument:
      explanation = "ncclInvalidArgument: Invalid value for an argument.";
      break;
    case ncclInvalidUsage:
      explanation =
          "ncclInvalidUsage: This usually reflects invalid usage of NCCL library.";
      break;
#ifdef NCCL_REMOTE_ERROR
    case ncclRemoteError:
      explanation =
          "ncclRemoteError: A call failed possibly due to a network error or a remote process exiting prematurely.";
      break;
#endif
    default:
      explanation = "Unknown NCCL error!";
      break;
  }

  return explanation + lastError;
}
} // namespace c10d
#endif // USE_C10D_NCCL