#ifndef CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_
#define CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_

#include <iterator>
#include <string>
#include <tuple>
#include <vector>

#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/operators/elementwise_ops_utils.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

using NumericTypes = TensorTypes<int32_t, int64_t, float, double>;
using IntTypes = TensorTypes<int32_t, int64_t>;
using BoolTypes = TensorTypes<bool>;
using IntBoolTypes = TensorTypes<int32_t, int64_t, bool>; // discrete types

struct SameTypeAsInput {
  template <typename T>
  using type = T;
};

template <typename R>
struct FixedType {
  template <typename T>
  using type = R;
};
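
// A quick illustration (not part of the library): an OutputTypeMap picks the
// output dtype from the input dtype T. SameTypeAsInput maps T -> T, while
// FixedType<R> maps every T -> R; comparison operators use FixedType<bool> to
// produce boolean outputs regardless of the input type.
//
//   static_assert(
//       std::is_same<SameTypeAsInput::type<float>, float>::value, "");
//   static_assert(
//       std::is_same<FixedType<bool>::type<double>, bool>::value, "");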

template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput>
class UnaryElementwiseWithArgsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit UnaryElementwiseWithArgsOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...), functor_(*this) {}

  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& X = Input(0);

    auto* Y = Output(
        0, X.sizes(), at::dtype<typename OutputTypeMap::template type<T>>());
    return functor_(
        X.numel(),
        X.template data<T>(),
        Y->template mutable_data<typename OutputTypeMap::template type<T>>(),
        &context_);
  }

 private:
  Functor functor_;
};
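
// A hypothetical sketch of a functor this wrapper can host: it reads its
// arguments from the operator in its constructor and does the elementwise
// work in operator(). The "scale" argument and ScaleFunctor name are made up
// for illustration; math::Scale is the existing Caffe2 math routine.
//
//   template <class Context>
//   struct ScaleFunctor {
//     explicit ScaleFunctor(OperatorBase& op)
//         : scale(op.GetSingleArgument<float>("scale", 1.0f)) {}
//
//     template <typename T>
//     bool operator()(const int n, const T* x, T* y, Context* context) const {
//       math::Scale(n, scale, x, y, context);
//       return true;
//     }
//
//     const float scale;
//   };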

// UnaryFunctorWithDefaultCtor is a functor that can be used as the functor of
// a UnaryElementwiseWithArgsOp. It simply forwards the operator() call to an
// inner functor whose constructor takes no arguments.
template <class Functor>
struct UnaryFunctorWithDefaultCtor {
  explicit UnaryFunctorWithDefaultCtor(OperatorBase& /* op */) {}

  template <typename TIn, typename TOut, class Context>
  bool operator()(const int size, const TIn* X, TOut* Y, Context* context)
      const {
    return functor(size, X, Y, context);
  }

  Functor functor{};
};

// UnaryElementwiseOp is a wrapper around UnaryElementwiseWithArgsOp, with the
// difference that it takes a functor with a default constructor, i.e. one
// that does not need to consider any arguments during operator creation.
template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput>
using UnaryElementwiseOp = UnaryElementwiseWithArgsOp<
    InputTypes,
    Context,
    UnaryFunctorWithDefaultCtor<Functor>,
    OutputTypeMap>;
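
// For example, the Not operator is registered on CPU roughly like this (a
// sketch; the exact registration lives in elementwise_ops.cc):
//
//   REGISTER_CPU_OPERATOR(
//       Not,
//       UnaryElementwiseOp<BoolTypes, CPUContext, NotFunctor<CPUContext>>);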

template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput>
class BinaryElementwiseWithArgsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit BinaryElementwiseWithArgsOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(bool, "broadcast", legacy_broadcast_, false),
        OP_SINGLE_ARG(int, "axis", axis_, -1),
        OP_SINGLE_ARG(string, "axis_str", axis_str_, string("")),
        OP_SINGLE_ARG(string, "order", order_, "NCHW"),
        functor_(*this) {
    if (legacy_broadcast_) {
      if (axis_ != -1) {
        // Get axis from an explicit axis argument.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(),
            0U,
            "Args axis and axis_str cannot be used simultaneously.");
      } else if (axis_str_.size()) {
        // Get the axis index semantically.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(), 1U, "Unsupported axis string ", axis_str_);
        const size_t semantic_axis_ = order_.find(axis_str_);
        CAFFE_ENFORCE_NE(
            semantic_axis_,
            string::npos,
            "Unrecognizable axis string ",
            axis_str_,
            " from order string ",
            order_);
        axis_ = semantic_axis_;
      }
    } else {
      CAFFE_ENFORCE(
          axis_ == -1 && axis_str_.empty(),
          "Do not specify axis or axis_str if broadcast is not enabled.");
    }
  }

  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& A = Input(0);
    const auto& B = Input(1);

    const T* A_data = A.template data<T>();
    const T* B_data = B.template data<T>();
    std::vector<int> A_dims;
    std::vector<int> B_dims;
    std::vector<int64_t> C_dims;

    if (legacy_broadcast_) {
      CAFFE_ENFORCE(
          !IsInputOutputAlias(1, 0),
          "In-place is allowed only with the first tensor when "
          "legacy-broadcasting");
      C_dims = A.sizes().vec();
      if (B.numel() == 1) {
        A_dims = {static_cast<int>(A.numel())};
        B_dims = {1};
      } else {
        size_t pre, n, post;
        std::tie(pre, n, post) =
            elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
        A_dims = {
            static_cast<int>(pre), static_cast<int>(n), static_cast<int>(post)};
        B_dims = {static_cast<int>(n), 1};
      }
    } else {
      std::copy(
          A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims));
      std::copy(
          B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
      // TODO: change the types to vector<int64_t>
      auto C_dims_int =
          elementwise_ops_utils::ComputeBinaryBroadcastForwardDims(
              A_dims, B_dims);
      std::copy(
          C_dims_int.cbegin(), C_dims_int.cend(), std::back_inserter(C_dims));
      if (IsInputOutputAlias(0, 0)) {
        CAFFE_ENFORCE_EQ(C_dims_int, A_dims);
      } else if (IsInputOutputAlias(1, 0)) {
        CAFFE_ENFORCE_EQ(C_dims_int, B_dims);
      }
    }

    auto* C = Output(
        0, C_dims, at::dtype<typename OutputTypeMap::template type<T>>());
    auto* C_data =
        C->template mutable_data<typename OutputTypeMap::template type<T>>();
    return functor_.Forward(A_dims, B_dims, A_data, B_data, C_data, &context_);
  }

 private:
  const bool legacy_broadcast_;
  int axis_;
  const std::string axis_str_;
  const std::string order_;

  Functor functor_;
};
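
// Shape inference sketch for the non-legacy path above: broadcasting follows
// numpy rules via ComputeBinaryBroadcastForwardDims, e.g.
//
//   std::vector<int> A_dims = {2, 3, 4};
//   std::vector<int> B_dims = {3, 1};
//   // C_dims == {2, 3, 4}
//   auto C_dims = elementwise_ops_utils::ComputeBinaryBroadcastForwardDims(
//       A_dims, B_dims);
//
// With legacy broadcasting ("broadcast"=1), the output always takes A's shape
// and B must match an inner slice of A selected by "axis".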

template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput,
    class GradientTypeMap = SameTypeAsInput>
class BinaryElementwiseWithArgsGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit BinaryElementwiseWithArgsGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(bool, "broadcast", legacy_broadcast_, false),
        OP_SINGLE_ARG(int, "axis", axis_, -1),
        OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
        OP_SINGLE_ARG(string, "order", order_, "NCHW"),
        functor_(*this) {
    if (legacy_broadcast_) {
      if (axis_ != -1) {
        // Get axis from an explicit axis argument.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(),
            0U,
            "Args axis and axis_str cannot be used simultaneously.");
      } else if (axis_str_.size()) {
        // Get the axis index semantically.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(), 1U, "Unsupported axis string ", axis_str_);
        const size_t semantic_axis_ = order_.find(axis_str_);
        CAFFE_ENFORCE_NE(
            semantic_axis_,
            string::npos,
            "Unrecognizable axis string ",
            axis_str_,
            " from order string ",
            order_);
        axis_ = semantic_axis_;
      }
    } else {
      CAFFE_ENFORCE(
          axis_ == -1 && axis_str_.empty(),
          "Do not specify axis or axis_str if broadcast is not enabled.");
    }
  }

  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(1));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& dC = Input(0);
    const auto& A = Input(1);
    const auto& B = Input(2);

    vector<int> A_dims;
    vector<int> B_dims;
    if (legacy_broadcast_) {
      if (B.numel() == 1) {
        A_dims = {static_cast<int>(A.numel())};
        B_dims = {1};
      } else {
        size_t pre, n, post;
        std::tie(pre, n, post) =
            elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
        A_dims = {
            static_cast<int>(pre), static_cast<int>(n), static_cast<int>(post)};
        B_dims = {static_cast<int>(n), 1};
      }
    } else {
      std::copy(
          A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims));
      std::copy(
          B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
    }
    const typename OutputTypeMap::template type<T>* C_data = nullptr;
    if (InputSize() == 4) {
      const auto& C = Input(3);
      C_data = C.template data<typename OutputTypeMap::template type<T>>();
    }
    const auto* dC_data =
        dC.template data<typename GradientTypeMap::template type<T>>();
    const T* A_data = A.template data<T>();
    const T* B_data = B.template data<T>();
    auto* dA = Output(
        0, A.sizes(), at::dtype<typename GradientTypeMap::template type<T>>());
    auto* dB = Output(
        1, B.sizes(), at::dtype<typename GradientTypeMap::template type<T>>());
    auto* dA_data =
        dA->template mutable_data<typename GradientTypeMap::template type<T>>();
    auto* dB_data =
        dB->template mutable_data<typename GradientTypeMap::template type<T>>();
    return functor_.Backward(
        A_dims,
        B_dims,
        dC_data,
        A_data,
        B_data,
        C_data,
        dA_data,
        dB_data,
        &context_);
  }

 private:
  const bool legacy_broadcast_;
  int axis_;
  const std::string axis_str_;
  const std::string order_;

  Functor functor_;
};
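
// Input/output convention for the gradient op above: Input(0) = dC,
// Input(1) = A, Input(2) = B, and optionally Input(3) = C (read only when
// InputSize() == 4, so C_data may be nullptr); Output(0) = dA and
// Output(1) = dB. A gradient functor therefore needs to implement:
//
//   template <typename TGrad, typename TIn, typename TOut, class Context>
//   bool Backward(
//       const std::vector<int>& A_dims,
//       const std::vector<int>& B_dims,
//       const TGrad* dC,
//       const TIn* A,
//       const TIn* B,
//       const TOut* C, // may be nullptr
//       TGrad* dA,
//       TGrad* dB,
//       Context* context) const;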

template <class Functor>
struct BinaryFunctorWithDefaultCtor {
  explicit BinaryFunctorWithDefaultCtor(OperatorBase& /* op */) {}

  template <typename TIn, typename TOut, class Context>
  bool Forward(
      const std::vector<int>& A_dims,
      const std::vector<int>& B_dims,
      const TIn* A_data,
      const TIn* B_data,
      TOut* C_data,
      Context* context) const {
    return functor.Forward(A_dims, B_dims, A_data, B_data, C_data, context);
  }

  template <typename TGrad, typename TIn, typename TOut, class Context>
  bool Backward(
      const std::vector<int>& A_dims,
      const std::vector<int>& B_dims,
      const TGrad* dC_data,
      const TIn* A_data,
      const TIn* B_data,
      const TOut* C_data,
      TGrad* dA_data,
      TGrad* dB_data,
      Context* context) const {
    return functor.Backward(
        A_dims,
        B_dims,
        dC_data,
        A_data,
        B_data,
        C_data,
        dA_data,
        dB_data,
        context);
  }

  Functor functor{};
};

template <class Functor>
struct BinaryFunctorWithBroadcastOptionsCtor {
  explicit BinaryFunctorWithBroadcastOptionsCtor(OperatorBase& op)
      : functor{
            op.GetSingleArgument<bool>("allow_broadcast_fastpath", false)} {}

  template <typename TIn, typename TOut, class Context>
  bool Forward(
      const std::vector<int>& A_dims,
      const std::vector<int>& B_dims,
      const TIn* A_data,
      const TIn* B_data,
      TOut* C_data,
      Context* context) const {
    return functor.Forward(A_dims, B_dims, A_data, B_data, C_data, context);
  }

  template <typename TGrad, typename TIn, typename TOut, class Context>
  bool Backward(
      const std::vector<int>& A_dims,
      const std::vector<int>& B_dims,
      const TGrad* dC_data,
      const TIn* A_data,
      const TIn* B_data,
      const TOut* C_data,
      TGrad* dA_data,
      TGrad* dB_data,
      Context* context) const {
    return functor.Backward(
        A_dims,
        B_dims,
        dC_data,
        A_data,
        B_data,
        C_data,
        dA_data,
        dB_data,
        context);
  }

  Functor functor;
};

// BinaryElementwiseOp is a wrapper around BinaryElementwiseWithArgsOp, with
// the difference that it takes a functor with a default constructor, i.e. one
// that does not need to consider any arguments during operator creation.
template <
    typename InputTypes,
    class Context,
    class Functor,
    class TypeMap = SameTypeAsInput>
using BinaryElementwiseOp = BinaryElementwiseWithArgsOp<
    InputTypes,
    Context,
    BinaryFunctorWithDefaultCtor<Functor>,
    TypeMap>;
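
// A registration sketch (the real ones live in the per-op .cc files): an
// arithmetic op such as Add is typically declared as
//
//   REGISTER_CPU_OPERATOR(
//       Add,
//       BinaryElementwiseOp<NumericTypes, CPUContext, AddFunctor<CPUContext>>);
//
// where AddFunctor supplies the Forward/Backward interface expected by
// BinaryFunctorWithDefaultCtor.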

// BinaryElementwiseGradientOp is a wrapper around
// BinaryElementwiseWithArgsGradientOp, with the difference that it takes a
// functor with a default constructor, i.e. one that does not need to consider
// any arguments during operator creation.
template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput,
    class GradientTypeMap = SameTypeAsInput>
using BinaryElementwiseGradientOp = BinaryElementwiseWithArgsGradientOp<
    InputTypes,
    Context,
    BinaryFunctorWithDefaultCtor<Functor>,
    OutputTypeMap,
    GradientTypeMap>;

// BinaryElementwiseBroadcastOp is a wrapper around BinaryElementwiseWithArgsOp,
// with the difference that it takes a functor whose constructor accepts
// broadcast-related arguments (at the time of writing, a single boolean that
// controls whether broadcast fast paths are allowed).
template <
    typename InputTypes,
    class Context,
    class Functor,
    class TypeMap = SameTypeAsInput>
using BinaryElementwiseBroadcastOp = BinaryElementwiseWithArgsOp<
    InputTypes,
    Context,
    BinaryFunctorWithBroadcastOptionsCtor<Functor>,
    TypeMap>;

// BinaryElementwiseGradientBroadcastOp is a wrapper around
// BinaryElementwiseWithArgsGradientOp, with the difference that it takes a
// functor whose constructor accepts broadcast-related arguments (at the time
// of writing, a single boolean that controls whether broadcast fast paths are
// allowed).
template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput,
    class GradientTypeMap = SameTypeAsInput>
using BinaryElementwiseGradientBroadcastOp =
    BinaryElementwiseWithArgsGradientOp<
        InputTypes,
        Context,
        BinaryFunctorWithBroadcastOptionsCtor<Functor>,
        OutputTypeMap,
        GradientTypeMap>;
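
// Ops declared through these two aliases accept an extra
// "allow_broadcast_fastpath" boolean argument, which the wrapper reads at
// operator-creation time and passes to the functor's constructor (see
// BinaryFunctorWithBroadcastOptionsCtor above). A hypothetical registration,
// assuming SomeFunctor's constructor takes that boolean:
//
//   REGISTER_CPU_OPERATOR(
//       SomeOp,
//       BinaryElementwiseBroadcastOp<
//           NumericTypes,
//           CPUContext,
//           SomeFunctor<CPUContext>>);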

// Forward-only Unary Functors.
template <class Context>
struct NotFunctor {
  bool operator()(const int N, const bool* X, bool* Y, Context* context)
      const {
    math::Not(N, X, Y, context);
    return true;
  }
};

template <class Context>
struct SignFunctor {
  template <typename T>
  bool operator()(const int N, const T* X, T* Y, Context* context) const {
    math::Sign(N, X, Y, context);
    return true;
  }
};
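
// These plug directly into UnaryElementwiseOp; e.g. Sign is registered on CPU
// roughly as (sketch):
//
//   REGISTER_CPU_OPERATOR(
//       Sign,
//       UnaryElementwiseOp<NumericTypes, CPUContext, SignFunctor<CPUContext>>);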

// Forward-only Binary Functors.
#define C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(FunctorName) \
  template <class Context>                                    \
  struct FunctorName##Functor {                               \
    template <typename TIn, typename TOut>                    \
    bool Forward(                                             \
        const std::vector<int>& A_dims,                       \
        const std::vector<int>& B_dims,                       \
        const TIn* A,                                         \
        const TIn* B,                                         \
        TOut* C,                                              \
        Context* context) const {                             \
      math::FunctorName(                                      \
          A_dims.size(),                                      \
          A_dims.data(),                                      \
          B_dims.size(),                                      \
          B_dims.data(),                                      \
          A,                                                  \
          B,                                                  \
          C,                                                  \
          context);                                           \
      return true;                                            \
    }                                                         \
  };

// Compare functors.
C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(EQ);
C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(NE);
C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(LT);
C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(LE);
C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(GT);
C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(GE);

// Logical functors.
C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(And);
C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(Or);
C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(Xor);

// Bitwise functors.
C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(BitwiseAnd);
C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(BitwiseOr);
C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR(BitwiseXor);

#undef C10_DECLARE_FORWARD_ONLY_BINARY_FUNCTOR
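
// Each functor generated by the macro above forwards to the broadcasting math
// routine of the same name (math::EQ, math::And, ...). Comparison ops combine
// these with FixedType<bool> so the output is boolean whatever the input
// type, e.g. (sketch):
//
//   REGISTER_CPU_OPERATOR(
//       EQ,
//       BinaryElementwiseOp<
//           TensorTypes<bool, int32_t, int64_t, float, double>,
//           CPUContext,
//           EQFunctor<CPUContext>,
//           FixedType<bool>>);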

namespace SRLHelper {

template <typename T>
void sum2one(const T* a, T* y, size_t n);

template <typename T>
void RunWithBroadcastFront(const T* a, T* y, size_t pre, size_t n, CPUContext*);

template <typename T>
void RunWithBroadcastBack(const T* a, T* y, size_t post, size_t n, CPUContext*);

template <typename T>
void RunWithBroadcast2(
    const T* a,
    T* y,
    size_t pre,
    size_t n,
    size_t post,
    CPUContext*);

} // namespace SRLHelper

// Sum reduction operator that is used for computing the gradient in cases
// where the forward op is in broadcast mode.
template <class Context>
class SumReduceLikeOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit SumReduceLikeOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "axis", axis_, -1),
        OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
        OP_SINGLE_ARG(string, "order", order_, "NCHW") {
    if (axis_ != -1) {
      // Get axis from an explicit axis argument.
      CAFFE_ENFORCE_EQ(
          axis_str_.size(),
          0U,
          "Args axis and axis_str cannot be used simultaneously.");
    } else if (axis_str_.size()) {
      // Get the axis index semantically.
      CAFFE_ENFORCE_EQ(
          axis_str_.size(), 1U, "Unsupported axis string ", axis_str_);
      size_t semantic_axis = order_.find(axis_str_);
      CAFFE_ENFORCE_NE(
          semantic_axis,
          string::npos,
          "Unrecognizable axis string ",
          axis_str_,
          " from order string ",
          order_);
      axis_ = semantic_axis;
    }
  }

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, double>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType();

 private:
  int axis_;
  string axis_str_;
  string order_;
  Tensor ones_{Context::GetDeviceType()};
  Tensor sum_buffer_{Context::GetDeviceType()};
};
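
// Usage sketch: for a legacy-broadcast Add-style forward op
// C = A + broadcast(B), the gradient w.r.t. B is dC sum-reduced back to B's
// shape, which is what this operator computes. Schematically, a gradient
// definition feeds it
//
//   Inputs:  dC (same shape as A), B (the smaller tensor)
//   Output:  dB (same shape as B)
//
// with the same "axis"/"axis_str"/"order" arguments as the forward op.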

} // namespace caffe2

#endif // CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_