| #ifndef CAFFE2_MPI_MPI_OPS_H_ |
| #define CAFFE2_MPI_MPI_OPS_H_ |
| |
#include <mpi.h>

#include <limits>

#include "caffe2/core/operator.h"
#include "caffe2/mpi/mpi_common.h"
| |
| namespace caffe2 { |
| |
| // TODO(jiayq): if needed, write up the use of color and key with MPI split. |
| // Currently, the operator simply creates a communicator that has the |
| // same topology as the Caffe2 global communicator. |
| template <class Context> |
| class MPICreateCommonWorldOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| MPICreateCommonWorldOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws) {} |
| |
| bool RunOnDevice() override { |
| OperatorBase::Outputs()[0]->Reset(new MPICommonWorldWrapper()); |
| return true; |
| } |
| }; |
| |
| template <class Context> |
| class MPIBroadcastOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| MPIBroadcastOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws), |
| root_(OperatorBase::template GetSingleArgument<int>("root", 0)) {} |
| ~MPIBroadcastOp() {} |
| |
| bool RunOnDevice() override { |
| MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm(); |
| CAFFE_ENFORCE( |
| OperatorBase::OutputIsTensorType(0, Context::GetDeviceType()), |
| "Output is of wrong type."); |
| auto* output = Output(0); |
| // Make sure that output is already allocated. |
| CAFFE_ENFORCE( |
| output->numel() > 0, |
| "Broadcast op uses in-place operation so the output " |
| "should be already allocated."); |
| MPI_CHECK(MPI_Bcast( |
| output->raw_mutable_data(), |
| output->nbytes(), |
| MPIDataTypeWrapper<char>::type(), |
| root_, |
| comm)); |
| return true; |
| } |
| |
| protected: |
| int root_; |
| }; |
| |
| // MPIReduceOp does Reduce using MPI. Currently, only SUM is supported. |
| template <typename T, class Context> |
| class MPIReduceOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| MPIReduceOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws), |
| root_(OperatorBase::template GetSingleArgument<int>("root", 0)) {} |
| ~MPIReduceOp() {} |
| |
| bool RunOnDevice() override { |
| MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm(); |
| auto& input = Input(1); |
| auto* output = Output(0, input.sizes(), at::dtype<T>()); |
| MPI_CHECK(MPI_Reduce( |
| const_cast<T*>(input.template data<T>()), |
| output->template mutable_data<T>(), |
| input.numel(), |
| MPIDataTypeWrapper<T>::type(), |
| MPI_SUM, |
| root_, |
| comm)); |
| return true; |
| } |
| |
| protected: |
| int root_; |
| }; |
| |
| // MPIAllgatherOp does MPIAllgather using MPI. |
| template <typename T, class Context> |
| class MPIAllgatherOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| USE_SIMPLE_CTOR_DTOR(MPIAllgatherOp); |
| |
| bool RunOnDevice() override { |
| MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm(); |
| auto& input = Input(1); |
| auto* output = Output(0); |
| vector<int64_t> output_dims = input.sizes().vec(); |
| output_dims[0] *= OperatorBase::Input<MPICommonWorldWrapper>(0).size(); |
| output->Resize(output_dims); |
| MPI_CHECK(MPI_Allgather( |
| const_cast<T*>(input.template data<T>()), |
| input.numel(), |
| MPIDataTypeWrapper<T>::type(), |
| output->template mutable_data<T>(), |
| input.numel(), |
| MPIDataTypeWrapper<T>::type(), |
| comm)); |
| return true; |
| } |
| }; |
| |
| // MPIAllreduceOp does MPIAllreduce using MPI. Currently, only SUM is supported. |
| template <typename T, class Context> |
| class MPIAllreduceOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| USE_SIMPLE_CTOR_DTOR(MPIAllreduceOp); |
| |
| bool RunOnDevice() override { |
| MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm(); |
| auto& input = Input(1); |
| auto* output = Output(0, input.sizes(), at::dtype<T>()); |
| void* source; |
| if (output->template mutable_data<T>() == input.template data<T>()) { |
| // We are doing in-place call. Special case handling. |
| source = MPI_IN_PLACE; |
| } else { |
| // Normal allreduce takes the source from the input. |
| source = const_cast<T*>(input.template data<T>()); |
| } |
| MPI_CHECK(MPI_Allreduce( |
| source, |
| output->template mutable_data<T>(), |
| input.numel(), |
| MPIDataTypeWrapper<T>::type(), |
| MPI_SUM, |
| comm)); |
| return true; |
| } |
| }; |
| |
| template <class Context> |
| class MPISendTensorOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| MPISendTensorOp(const OperatorDef& def, Workspace* ws) |
| : Operator<Context>(def, ws), |
| OP_SINGLE_ARG(int, "dst", dst_, MPI_ANY_SOURCE), |
| OP_SINGLE_ARG(int, "tag", tag_, MPI_ANY_TAG), |
| OP_SINGLE_ARG(bool, "raw_buffer", raw_buffer_, false) { |
| CAFFE_ENFORCE(raw_buffer_, "non-raw-buffer transfer not supported yet."); |
| CAFFE_ENFORCE( |
| dst_ != MPI_ANY_SOURCE || def.input_size() == 4, |
| "You should explicitly specify the to rank either via " |
| "argument or via input blobs."); |
| CAFFE_ENFORCE( |
| tag_ != MPI_ANY_TAG || def.input_size() == 4, |
| "You should explicitly specify the tag either via " |
| "argument or via input blobs."); |
| } |
| |
| bool RunOnDevice() override { |
| MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(COMM).comm(); |
| auto& input = Input(INPUT); |
| if (InputSize() == 4) { |
| dst_ = OperatorBase::Input<Tensor>(DST, CPU).template data<int>()[0]; |
| tag_ = OperatorBase::Input<Tensor>(TAG, CPU).template data<int>()[0]; |
| } |
| if (raw_buffer_) { |
| // We need to do a const cast to cope with the fact that, before OpenMPI |
| // 1.7, MPI_Send expects a non-const pointer although it uses it in a |
| // const way. |
| MPI_CHECK(MPI_Send( |
| const_cast<void*>(input.raw_data()), |
| input.nbytes(), |
| MPI_CHAR, |
| dst_, |
| tag_, |
| comm)); |
| } else { |
| CAFFE_NOT_IMPLEMENTED; |
| } |
| return true; |
| } |
| |
| protected: |
| int dst_; |
| int tag_; |
| bool raw_buffer_; |
| |
| INPUT_TAGS(COMM, INPUT, DST, TAG); |
| }; |
| |
| template <class Context> |
| class MPIReceiveTensorOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| MPIReceiveTensorOp(const OperatorDef& def, Workspace* ws) |
| : Operator<Context>(def, ws), |
| OP_SINGLE_ARG(int, "src", src_, MPI_ANY_SOURCE), |
| OP_SINGLE_ARG(int, "tag", tag_, MPI_ANY_TAG), |
| OP_SINGLE_ARG(bool, "raw_buffer", raw_buffer_, false) { |
| CAFFE_ENFORCE(raw_buffer_, "non-raw-buffer transfer not supported yet."); |
| } |
| |
| bool RunOnDevice() override { |
| MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(COMM).comm(); |
| if (InputSize() == 4) { |
| src_ = OperatorBase::Input<Tensor>(SRC_IN, CPU).template data<int>()[0]; |
| tag_ = OperatorBase::Input<Tensor>(TAG_IN, CPU).template data<int>()[0]; |
| } |
| MPI_Status status; |
| if (raw_buffer_) { |
| auto* output = Output(OUTPUT); |
| MPI_CHECK(MPI_Recv( |
| output->raw_mutable_data(), |
| output->nbytes(), |
| MPI_CHAR, |
| src_, |
| tag_, |
| comm, |
| &status)); |
| } else { |
| CAFFE_NOT_IMPLEMENTED; |
| } |
| auto* src_out = OperatorBase::Output<Tensor>(SRC_OUT, CPU); |
| src_out->Resize(); |
| src_out->template mutable_data<int>()[0] = status.MPI_SOURCE; |
| auto* tag_out = OperatorBase::Output<Tensor>(TAG_OUT, CPU); |
| tag_out->Resize(); |
| tag_out->template mutable_data<int>()[0] = status.MPI_TAG; |
| return true; |
| } |
| |
| protected: |
| int src_; |
| int tag_; |
| bool raw_buffer_; |
| INPUT_TAGS(COMM, INPUT, SRC_IN, TAG_IN); |
| OUTPUT_TAGS(OUTPUT, SRC_OUT, TAG_OUT); |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_MPI_MPI_OPS_H_ |