#ifndef CAFFE2_MPI_MPI_OPS_H_
#define CAFFE2_MPI_MPI_OPS_H_
#include <mpi.h>
#include "caffe2/core/operator.h"
#include "caffe2/mpi/mpi_common.h"
namespace caffe2 {
// TODO(jiayq): if needed, write up the use of color and key with MPI split.
// Currently, the operator simply creates a communicator that has the
// same topology as the Caffe2 global communicator.
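// For reference, a color/key split would boil down to a call like the one
// sketched below (illustrative only; GlobalMPIComm() stands for the global
// communicator accessor assumed to be provided by mpi_common.h):
//
//   int color = ...;  // ranks with the same color join the same communicator
//   int key = ...;    // orders the ranks within the new communicator
//   MPI_Comm new_comm;
//   MPI_CHECK(MPI_Comm_split(GlobalMPIComm(), color, key, &new_comm));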
template <class Context>
class MPICreateCommonWorldOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
MPICreateCommonWorldOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
bool RunOnDevice() override {
OperatorBase::Outputs()[0]->Reset(new MPICommonWorldWrapper());
return true;
}
};
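
// Example (illustrative sketch only; the op names "MPICreateCommonWorld" and
// "MPIBroadcast" are assumed from the registrations in the corresponding .cc
// file): the communicator blob produced above is passed as the first input of
// every MPI op below, e.g. in a NetDef:
//
//   op { type: "MPICreateCommonWorld" output: "comm" }
//   op { type: "MPIBroadcast" input: "comm" input: "X" output: "X"
//        arg { name: "root" i: 0 } }

// MPIBroadcastOp broadcasts the output tensor from the root rank to all other
// ranks. It operates in place on the raw bytes of the output, so the output
// must already be allocated (typically by binding it in place to the input).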
template <class Context>
class MPIBroadcastOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
MPIBroadcastOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
root_(OperatorBase::template GetSingleArgument<int>("root", 0)) {}
~MPIBroadcastOp() {}
bool RunOnDevice() override {
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
CAFFE_ENFORCE(
OperatorBase::OutputIsTensorType(0, Context::GetDeviceType()),
"Output is of wrong type.");
auto* output = Output(0);
// Make sure that output is already allocated.
CAFFE_ENFORCE(
output->numel() > 0,
"Broadcast op uses in-place operation so the output "
"should be already allocated.");
MPI_CHECK(MPI_Bcast(
output->raw_mutable_data(),
output->nbytes(),
MPIDataTypeWrapper<char>::type(),
root_,
comm));
return true;
}
protected:
int root_;
};
// MPIReduceOp does Reduce using MPI. Currently, only SUM is supported.
template <typename T, class Context>
class MPIReduceOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
MPIReduceOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
root_(OperatorBase::template GetSingleArgument<int>("root", 0)) {}
~MPIReduceOp() {}
bool RunOnDevice() override {
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
auto& input = Input(1);
auto* output = Output(0, input.sizes(), at::dtype<T>());
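    // The const_cast is needed because the send buffer parameter of
    // MPI_Reduce was not declared const before MPI-3.0.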
MPI_CHECK(MPI_Reduce(
const_cast<T*>(input.template data<T>()),
output->template mutable_data<T>(),
input.numel(),
MPIDataTypeWrapper<T>::type(),
MPI_SUM,
root_,
comm));
return true;
}
protected:
int root_;
};
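
// Note: as with plain MPI_Reduce, only the root rank's output holds the
// reduced sum after this op runs; on non-root ranks the output is allocated
// but its contents are unspecified. Illustrative NetDef usage (the op name
// "MPIReduce" is assumed from the registration in the .cc file):
//
//   op { type: "MPIReduce" input: "comm" input: "X" output: "X_sum"
//        arg { name: "root" i: 0 } }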
// MPIAllgatherOp does Allgather using MPI.
template <typename T, class Context>
class MPIAllgatherOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_SIMPLE_CTOR_DTOR(MPIAllgatherOp);
bool RunOnDevice() override {
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
auto& input = Input(1);
auto* output = Output(0);
vector<int64_t> output_dims = input.sizes().vec();
output_dims[0] *= OperatorBase::Input<MPICommonWorldWrapper>(0).size();
output->Resize(output_dims);
MPI_CHECK(MPI_Allgather(
const_cast<T*>(input.template data<T>()),
input.numel(),
MPIDataTypeWrapper<T>::type(),
output->template mutable_data<T>(),
input.numel(),
MPIDataTypeWrapper<T>::type(),
comm));
return true;
}
};
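
// Note: every rank must pass an input of the same shape, and the output
// stacks each rank's input along the first dimension in rank order, so an
// input of shape (N, ...) on each of P ranks yields an output of shape
// (P * N, ...), with rank i's data occupying rows [i * N, (i + 1) * N).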
// MPIAllreduceOp does Allreduce using MPI. Currently, only SUM is supported.
template <typename T, class Context>
class MPIAllreduceOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_SIMPLE_CTOR_DTOR(MPIAllreduceOp);
bool RunOnDevice() override {
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
auto& input = Input(1);
auto* output = Output(0, input.sizes(), at::dtype<T>());
void* source;
if (output->template mutable_data<T>() == input.template data<T>()) {
// We are doing in-place call. Special case handling.
source = MPI_IN_PLACE;
} else {
// Normal allreduce takes the source from the input.
source = const_cast<T*>(input.template data<T>());
}
MPI_CHECK(MPI_Allreduce(
source,
output->template mutable_data<T>(),
input.numel(),
MPIDataTypeWrapper<T>::type(),
MPI_SUM,
comm));
return true;
}
};
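
// Example (illustrative sketch only; the op name "MPIAllreduce" is assumed
// from the registration in the .cc file): summing a tensor across all ranks
// in place, which makes RunOnDevice take the MPI_IN_PLACE branch above:
//
//   op { type: "MPIAllreduce" input: "comm" input: "X" output: "X" }

// MPISendTensorOp sends a tensor's raw bytes to a destination rank via
// MPI_Send. The destination rank and tag are given either via the "dst" and
// "tag" arguments or, when four inputs are provided, via scalar int tensors
// in the DST and TAG input slots; MPI_ANY_SOURCE and MPI_ANY_TAG only serve
// as "unset" sentinels for the arguments. Only raw_buffer == true is
// currently supported.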
template <class Context>
class MPISendTensorOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
MPISendTensorOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
OP_SINGLE_ARG(int, "dst", dst_, MPI_ANY_SOURCE),
OP_SINGLE_ARG(int, "tag", tag_, MPI_ANY_TAG),
OP_SINGLE_ARG(bool, "raw_buffer", raw_buffer_, false) {
CAFFE_ENFORCE(raw_buffer_, "non-raw-buffer transfer not supported yet.");
CAFFE_ENFORCE(
dst_ != MPI_ANY_SOURCE || def.input_size() == 4,
"You should explicitly specify the to rank either via "
"argument or via input blobs.");
CAFFE_ENFORCE(
tag_ != MPI_ANY_TAG || def.input_size() == 4,
"You should explicitly specify the tag either via "
"argument or via input blobs.");
}
bool RunOnDevice() override {
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(COMM).comm();
auto& input = Input(INPUT);
if (InputSize() == 4) {
dst_ = OperatorBase::Input<Tensor>(DST, CPU).template data<int>()[0];
tag_ = OperatorBase::Input<Tensor>(TAG, CPU).template data<int>()[0];
}
if (raw_buffer_) {
// We need to do a const cast to cope with the fact that, before OpenMPI
// 1.7, MPI_Send expects a non-const pointer although it uses it in a
// const way.
MPI_CHECK(MPI_Send(
const_cast<void*>(input.raw_data()),
input.nbytes(),
MPI_CHAR,
dst_,
tag_,
comm));
} else {
CAFFE_NOT_IMPLEMENTED;
}
return true;
}
protected:
int dst_;
int tag_;
bool raw_buffer_;
INPUT_TAGS(COMM, INPUT, DST, TAG);
};
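
// Example (illustrative sketch only; the op names "MPISendTensor" and
// "MPIReceiveTensor" are assumed from the registrations in the .cc file):
// rank 0 sends blob "X" to rank 1, which receives it into the pre-allocated
// blob "Y":
//
//   // net run on rank 0
//   op { type: "MPISendTensor" input: "comm" input: "X"
//        arg { name: "dst" i: 1 } arg { name: "tag" i: 7 }
//        arg { name: "raw_buffer" i: 1 } }
//   // net run on rank 1
//   op { type: "MPIReceiveTensor" input: "comm" input: "Y"
//        output: "Y" output: "src" output: "tag"
//        arg { name: "src" i: 0 } arg { name: "tag" i: 7 }
//        arg { name: "raw_buffer" i: 1 } }

// MPIReceiveTensorOp receives a tensor's raw bytes via MPI_Recv into the
// output blob, which must already be allocated to the expected size
// (typically by binding it in place to the second input). The source rank and
// tag can be given as arguments, as scalar int tensor inputs, or left as
// MPI_ANY_SOURCE / MPI_ANY_TAG wildcards; the matched source and tag are
// written to the SRC_OUT and TAG_OUT outputs.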
template <class Context>
class MPIReceiveTensorOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
MPIReceiveTensorOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
OP_SINGLE_ARG(int, "src", src_, MPI_ANY_SOURCE),
OP_SINGLE_ARG(int, "tag", tag_, MPI_ANY_TAG),
OP_SINGLE_ARG(bool, "raw_buffer", raw_buffer_, false) {
CAFFE_ENFORCE(raw_buffer_, "non-raw-buffer transfer not supported yet.");
}
bool RunOnDevice() override {
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(COMM).comm();
if (InputSize() == 4) {
src_ = OperatorBase::Input<Tensor>(SRC_IN, CPU).template data<int>()[0];
tag_ = OperatorBase::Input<Tensor>(TAG_IN, CPU).template data<int>()[0];
}
MPI_Status status;
if (raw_buffer_) {
auto* output = Output(OUTPUT);
MPI_CHECK(MPI_Recv(
output->raw_mutable_data(),
output->nbytes(),
MPI_CHAR,
src_,
tag_,
comm,
&status));
} else {
CAFFE_NOT_IMPLEMENTED;
}
auto* src_out = OperatorBase::Output<Tensor>(SRC_OUT, CPU);
src_out->Resize();
src_out->template mutable_data<int>()[0] = status.MPI_SOURCE;
auto* tag_out = OperatorBase::Output<Tensor>(TAG_OUT, CPU);
tag_out->Resize();
tag_out->template mutable_data<int>()[0] = status.MPI_TAG;
return true;
}
protected:
int src_;
int tag_;
bool raw_buffer_;
INPUT_TAGS(COMM, INPUT, SRC_IN, TAG_IN);
OUTPUT_TAGS(OUTPUT, SRC_OUT, TAG_OUT);
};
} // namespace caffe2
#endif // CAFFE2_MPI_MPI_OPS_H_