blob: ab04afb86975616d61725d3e7d7d0c77b6b30acb [file] [log] [blame]
#ifndef CAFFE2_MPI_MPI_COMMON_H_
#define CAFFE2_MPI_MPI_COMMON_H_
#include <mpi.h>
#include <mutex>
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
namespace caffe2 {
inline void CheckInitializedMPI() {
int flag;
MPI_Initialized(&flag);
CAFFE_ENFORCE(flag, "MPI does not seem to have been initialized.");
}
// Compile-time mapping from a C++ type to its corresponding MPI_Datatype.
// The primary template is intentionally left undefined, so using an
// unsupported type is a compile-time error. Supported types: char, float,
// double.
template <typename T>
class MPIDataTypeWrapper;

template <>
class MPIDataTypeWrapper<char> {
 public:
  inline static MPI_Datatype type() {
    return MPI_CHAR;
  }
};

template <>
class MPIDataTypeWrapper<float> {
 public:
  inline static MPI_Datatype type() {
    return MPI_FLOAT;
  }
};

template <>
class MPIDataTypeWrapper<double> {
 public:
  inline static MPI_Datatype type() {
    return MPI_DOUBLE;
  }
};
// Note(Yangqing): as necessary, add more specializations.
// For all Caffe MPI calls, we will wrap it inside an MPI mutex lock guard.
// This serializes MPI calls issued through MPI_CHECK, which keeps Caffe2
// safe with MPI implementations that are not fully thread-safe.
TORCH_API std::mutex& MPIMutex();

// Evaluates an MPI call expression while holding the global MPI mutex, and
// raises (via CAFFE_ENFORCE) with file/line context plus the numeric MPI
// error code if the call does not return MPI_SUCCESS. Note that `condition`
// itself is evaluated inside the lock.
#define MPI_CHECK(condition) \
do { \
std::lock_guard<std::mutex> guard(::caffe2::MPIMutex()); \
int error = (condition); \
CAFFE_ENFORCE( \
error == MPI_SUCCESS, \
"Caffe2 MPI Error at: ", \
__FILE__, \
":", \
__LINE__, \
": ", \
error); \
} while (0)
/**
* @brief Gets the global MPI communicator used by Caffe2. In default, this
* is MPI_COMM_WORLD unless you call SetGlobalMPIComm().
*/
TORCH_API MPI_Comm GlobalMPIComm();
/**
* @brief Sets the global MPI communicator. Caffe2 takes over the ownership
* of the passed in communicator.
*/
TORCH_API void SetGlobalMPIComm(MPI_Comm new_comm);
/**
* @brief A helper function to return the size of the given communicator.
*/
TORCH_API int MPICommSize(MPI_Comm comm);
/**
* @brief A helper function to return the rank of the given communicator.
*/
TORCH_API int MPICommRank(MPI_Comm comm);
/**
* @brief A simple wrapper over an MPI common world.
*/
class MPICommonWorldWrapper {
public:
/**
* @brief Creates a common world wrapper.
*
* The new common world is created by taking the existing communicator
* passed in as src_comm, and splitting it using the color and the rank
* specified. In default, we will split from Caffe2's global communicator,
* and use color 0 as well as rank implicitly given by src_comm. As a result,
* the default constructor basically creates a comm identical to the source
* comm world.
*/
explicit MPICommonWorldWrapper(
MPI_Comm src_comm = MPI_COMM_NULL,
int color = 0,
int rank = -1) {
if (src_comm == MPI_COMM_NULL) {
src_comm = GlobalMPIComm();
}
if (rank == -1) {
MPI_CHECK(MPI_Comm_rank(src_comm, &rank));
}
MPI_CHECK(MPI_Comm_split(src_comm, color, rank, &comm_));
MPI_CHECK(MPI_Comm_size(comm_, &size_));
MPI_CHECK(MPI_Comm_rank(comm_, &rank_));
}
~MPICommonWorldWrapper() {
int ret;
MPI_CHECK(MPI_Finalized(&ret));
if (!ret) {
MPI_Comm_free(&comm_);
}
}
/**
* @brief Returns the common world held by the wrapper.
*/
inline MPI_Comm comm() const {
return comm_;
}
/**
* @brief Returns the size of the world.
*/
inline int size() const {
return size_;
}
/**
* @brief Returns the rank of this process in the world.
*/
inline int rank() const {
return rank_;
}
private:
MPI_Comm comm_;
int size_;
int rank_;
};
/**
* A function used to perform peer setup so one does not need to use
* mpirun / mpiexec to run the binary. Note that if you use mpirun or mpiexec
* to set up the common world, do not use this function - MPI_Init would have
* already set that up.
*
* This also assumes that you have a common path (like NFS) that multiple
* instances can read from.
*
* Inputs:
* replicas (int): the number of replicas that mpi will run with.
* role (string): the role of this process, "server" or "client".
* job_path (string): a file name that the server will write its port into
* and the clients will read the server's port from.
*/
void MPISetupPeers(
const int replicas,
const string& role,
const string& job_path);
} // namespace caffe2
#endif // CAFFE2_MPI_MPI_COMMON_H_