| #ifndef CAFFE2_OPERATORS_ELEMENTWISE_SUB_OP_H_ |
| #define CAFFE2_OPERATORS_ELEMENTWISE_SUB_OP_H_ |
| |
| #include <algorithm> |
| #include <functional> |
| #include <vector> |
| |
| #include "caffe2/operators/elementwise_ops.h" |
| #include "caffe2/operators/elementwise_ops_utils.h" |
| #include "caffe2/utils/math.h" |
| |
| namespace caffe2 { |
| |
| template <class Context> |
| struct SubFunctor { |
| template <typename TIn, typename TOut> |
| bool Forward( |
| const std::vector<int>& A_dims, |
| const std::vector<int>& B_dims, |
| const TIn* A, |
| const TIn* B, |
| TOut* C, |
| Context* context) const { |
| math::Sub( |
| A_dims.size(), |
| A_dims.data(), |
| B_dims.size(), |
| B_dims.data(), |
| A, |
| B, |
| C, |
| context); |
| return true; |
| } |
| |
| template <typename TGrad, typename TIn, typename TOut> |
| bool Backward( |
| const std::vector<int>& A_dims, |
| const std::vector<int>& B_dims, |
| const TGrad* dC, |
| const TIn* /* A */, |
| const TIn* /* B */, |
| const TOut* /* C */, |
| TGrad* dA, |
| TGrad* dB, |
| Context* context) const { |
| const std::vector<int> C_dims = |
| elementwise_ops_utils::ComputeBinaryBroadcastForwardDims( |
| A_dims, B_dims); |
| std::vector<int> A_back_dims; |
| std::vector<int> B_back_dims; |
| elementwise_ops_utils::ComputeBinaryBroadcastBackwardDims( |
| A_dims, B_dims, &A_back_dims, &B_back_dims); |
| math::ReduceSum( |
| C_dims.size(), |
| C_dims.data(), |
| A_back_dims.data(), |
| TGrad(1), |
| dC, |
| dA, |
| context, |
| true); |
| math::ReduceSum( |
| C_dims.size(), |
| C_dims.data(), |
| B_back_dims.data(), |
| TGrad(-1), |
| dC, |
| dB, |
| context, |
| true); |
| return true; |
| } |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_OPERATORS_ELEMENTWISE_SUB_OP_H_ |