// blob: 1af0f76dc6068f6e5cd22672d840654746b76c32 [file] [log] [blame]
#include "caffe2/mpi/mpi_common.h"
#include <thread>
#include <c10/util/typeid.h>
#include "caffe2/utils/proto_utils.h"
namespace caffe2 {
// Registers MPICommonWorldWrapper through the CAFFE_KNOWN_TYPE macro.
// NOTE(review): presumably this is the type-id registration provided by
// <c10/util/typeid.h> (included above) -- confirm against that header.
CAFFE_KNOWN_TYPE(MPICommonWorldWrapper);
namespace {
// Single mutex instance handed out by MPIMutex(); local to this file.
std::mutex gCaffe2MPIMutex;
} // namespace

/**
 * Gives access to the file-wide MPI mutex.
 *
 * Every call returns a reference to the same std::mutex object. Locking and
 * unlocking it is entirely up to the caller.
 */
std::mutex& MPIMutex() {
  return gCaffe2MPIMutex;
}
namespace {
// Communicator currently returned by GlobalMPIComm(). It starts out as
// MPI_COMM_WORLD and is replaced through SetGlobalMPIComm().
MPI_Comm gCaffe2MPIComm = MPI_COMM_WORLD;
} // namespace

/**
 * Returns the communicator currently stored as the global one for this
 * process (MPI_COMM_WORLD until SetGlobalMPIComm() installs another).
 */
MPI_Comm GlobalMPIComm() {
  return gCaffe2MPIComm;
}
/**
 * Installs `new_comm` as the communicator returned by GlobalMPIComm().
 *
 * The communicator that was installed before is released with MPI_Comm_free,
 * except when it is MPI_COMM_WORLD or MPI_COMM_NULL (neither may be freed).
 * Installing the communicator that is already current is a no-op; freeing it
 * first would leave the global pointing at a released handle.
 *
 * A failing MPI_Comm_free is reported through MPI_CHECK, like the other MPI
 * calls in this file.
 *
 * @param new_comm Communicator to install. The handle is stored as-is, and it
 *                 is freed by a later call to this function that replaces it.
 */
void SetGlobalMPIComm(MPI_Comm new_comm) {
  if (new_comm == gCaffe2MPIComm) {
    // Nothing to replace; in particular do not free the handle we would keep.
    return;
  }
  if (gCaffe2MPIComm != MPI_COMM_WORLD && gCaffe2MPIComm != MPI_COMM_NULL) {
    MPI_CHECK(MPI_Comm_free(&gCaffe2MPIComm));
  }
  gCaffe2MPIComm = new_comm;
}
/**
 * Returns the size of `comm` as reported by MPI_Comm_size.
 *
 * The MPI return code is validated with MPI_CHECK.
 */
int MPICommSize(MPI_Comm comm) {
  int num_procs = 0;
  MPI_CHECK(MPI_Comm_size(comm, &num_procs));
  return num_procs;
}
/**
 * Returns the calling process's rank within `comm` as reported by
 * MPI_Comm_rank.
 *
 * The MPI return code is validated with MPI_CHECK.
 */
int MPICommRank(MPI_Comm comm) {
  int my_rank = 0;
  MPI_CHECK(MPI_Comm_rank(comm, &my_rank));
  return my_rank;
}
/**
* Helper function used to setup MPI intercommunicator.
*
* Produces a new intra-communicator by merging an existing intra-communicator
* with one additional process, and returns it.
*
* The call sequence below (Intercomm_merge, Intercomm_create, Intercomm_merge)
* consists of MPI operations that the participating processes must issue in
* matching order, so statement order here matters.
*
* @param intra The intra-communicator this process already belongs to, or
*              MPI_COMM_NULL when this process is the one being added.
* @param inter The inter-communicator linking this process to the other side
*              of the new connection, or MPI_COMM_NULL when this process is
*              not part of that connection.
* @return The merged intra-communicator produced by the final
*         MPI_Intercomm_merge.
*
* Side effects: `intra` is freed unless it is MPI_COMM_NULL, MPI_COMM_WORLD or
* the communicator currently returned by GlobalMPIComm(). Because `intra` is
* passed by value, the caller's own handle variable is not reset. `inter` is
* not freed by this function.
*/
static MPI_Comm AssimilateComm(MPI_Comm intra, MPI_Comm inter) {
MPI_Comm peer = MPI_COMM_NULL;
MPI_Comm newInterComm = MPI_COMM_NULL;
MPI_Comm newIntraComm = MPI_COMM_NULL;
// The spawned rank will be the "high" rank in the new intra-comm
// (a process with no pre-existing intra-comm is the one being added).
int high = (MPI_COMM_NULL == intra) ? 1 : 0;
// If this is one of the (two) ranks in the inter-comm,
// create a new intra-comm from the inter-comm
if (MPI_COMM_NULL != inter) {
MPI_CHECK(MPI_Intercomm_merge(inter, high, &peer));
} else {
// Redundant with the initialization above; peer stays MPI_COMM_NULL for
// processes that are not part of the inter-comm.
peer = MPI_COMM_NULL;
}
// Create a new inter-comm between the pre-existing intra-comm
// (all of it, not only rank zero), and the remote (spawned) rank,
// using the just-created intra-comm as the peer communicator.
// The same tag value is passed on both sides of the MPI_Intercomm_create call.
int tag = 12345;
if (MPI_COMM_NULL != intra) {
// This task is a member of the pre-existing intra-comm
// Local leader is rank 0 of intra; the remote leader is rank 1 of peer.
// NOTE(review): rank 1 of peer is presumably the spawned task, since it
// merged with high == 1 -- confirm against MPI_Intercomm_merge ordering.
MPI_CHECK(MPI_Intercomm_create(intra, 0, peer, 1, tag, &newInterComm));
} else {
// This is the remote (spawned) task
// Its local group is just itself (MPI_COMM_SELF); the remote leader is
// rank 0 of peer.
MPI_CHECK(
MPI_Intercomm_create(MPI_COMM_SELF, 0, peer, 0, tag, &newInterComm));
}
// Now convert this inter-comm into an intra-comm
MPI_CHECK(MPI_Intercomm_merge(newInterComm, high, &newIntraComm));
// Clean up the intermediaries
if (MPI_COMM_NULL != peer) {
MPI_CHECK(MPI_Comm_free(&peer));
}
MPI_CHECK(MPI_Comm_free(&newInterComm));
// Delete the original intra-comm
// (skipped for MPI_COMM_WORLD and for the communicator still installed as the
// global one).
if (MPI_COMM_NULL != intra && MPI_COMM_WORLD != intra &&
GlobalMPIComm() != intra) {
MPI_CHECK(MPI_Comm_free(&intra));
}
// Return the new intra-comm
return newIntraComm;
}
void MPISetupPeers(
const int replicas,
const string& role,
const string& job_path) {
int flag;
MPI_Initialized(&flag);
if (!flag) {
int mpi_ret;
MPI_CHECK(MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &mpi_ret));
if (mpi_ret != MPI_THREAD_MULTIPLE && mpi_ret != MPI_THREAD_SERIALIZED) {
LOG(FATAL) << "This test requires the underlying MPI to support the "
<< "MPI_THREAD_SERIALIZED or MPI_THREAD_MULTIPLE mode.";
return;
}
}
if (MPICommSize(MPI_COMM_WORLD) != 1) {
LOG(ERROR) << "MPI_COMM_WORLD size is not 1: did you already run "
"MPISetupPeers? Note that if you execute your program with "
"mpirun to launch multiple local processes, you should not "
"call MPISetupPeers.";
return;
}
if (role == "server") {
// Open a port to accept connections.
char port_name[MPI_MAX_PORT_NAME] = {'\0'};
MPI_CHECK(MPI_Open_port(MPI_INFO_NULL, port_name));
VLOG(1) << "MPI server: port: " << port_name;
// Writes the port name to the file.
CHECK(WriteStringToFile(std::string(port_name), job_path.c_str()));
VLOG(1) << "MPI server: wrote to file: " << job_path;
int comm_size = MPICommSize(GlobalMPIComm());
while (comm_size < replicas) {
MPI_Comm icomm;
VLOG(1) << "MPI server: waiting for client "
<< "(" << comm_size << "/" << replicas << " have connected)";
MPI_CHECK(
MPI_Comm_accept(port_name, MPI_INFO_NULL, 0, MPI_COMM_SELF, &icomm));
VLOG(1) << "MPI server: accepted client";
MPI_Comm new_intra_comm = AssimilateComm(GlobalMPIComm(), icomm);
SetGlobalMPIComm(new_intra_comm);
comm_size = MPICommSize(new_intra_comm);
}
} else {
// Opens the job path file to obtain server address.
std::string port_name;
while (!ReadStringFromFile(job_path.c_str(), &port_name) ||
port_name.length() == 0) {
/* sleep override */
std::this_thread::sleep_for(std::chrono::seconds(1));
}
// Connect to server.
MPI_Comm icomm;
VLOG(1) << "MPI client: connecting to port: " << port_name;
MPI_CHECK(MPI_Comm_connect(
const_cast<char*>(port_name.c_str()),
MPI_INFO_NULL,
0,
GlobalMPIComm(),
&icomm));
VLOG(1) << "MPI client: connected";
// Join the server's reference intracommunicator.
MPI_Comm new_intra_comm = AssimilateComm(MPI_COMM_NULL, icomm);
SetGlobalMPIComm(new_intra_comm);
// Let other clients join the intracommunicator we're now a part of.
while (MPICommSize(GlobalMPIComm()) < replicas) {
MPI_Comm comm = AssimilateComm(GlobalMPIComm(), MPI_COMM_NULL);
SetGlobalMPIComm(comm);
}
}
// After all peers have assimilated, do a barrier.
MPI_Barrier(GlobalMPIComm());
VLOG(1) << "MPI using a communicator of size: "
<< MPICommSize(GlobalMPIComm());
}
} // namespace caffe2