| #ifndef CAFFE2_OPERATORS_RELU_N_OP_H_ |
| #define CAFFE2_OPERATORS_RELU_N_OP_H_ |
| |
| #include <vector> |
| |
| #include "caffe2/operators/elementwise_ops.h" |
| |
| namespace caffe2 { |
| |
| template <class Context> |
| struct ReluNFunctor { |
| explicit ReluNFunctor(OperatorBase& op) |
| : n(op.GetSingleArgument<float>("n", 6.0f)) { |
| CAFFE_ENFORCE_GT(n, 0, "n should be greater than 0"); |
| } |
| |
| template <typename T> |
| bool operator()(const int N, const T* X, T* Y, Context* context) const; |
| |
| const float n; |
| }; |
| |
| template <class Context> |
| struct ReluNGradientFunctor { |
| explicit ReluNGradientFunctor(OperatorBase& op) |
| : n(op.GetSingleArgument<float>("n", 6.0f)) { |
| CAFFE_ENFORCE_GT(n, 0, "n should be greater than 0"); |
| } |
| |
| template <typename T> |
| bool Forward( |
| const std::vector<int>& Y_dims, |
| const std::vector<int>& dY_dims, |
| const T* Y, |
| const T* dY, |
| T* dX, |
| Context* context) const; |
| |
| const float n; |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_OPERATORS_RELU_N_OP_H_ |