| #pragma once |
| |
| #include "caffe2/core/common_omp.h" |
| #include "caffe2/core/operator.h" |
| |
| namespace caffe2 { |
| |
/**
 * Element-wise RMSProp update kernel (declaration only).
 *
 * No definition is visible in this header; it is presumably supplied per
 * Context elsewhere -- TODO confirm. Parameter roles below are taken from how
 * RmsPropOp::RunOnDevice() in this file calls the function.
 *
 * @param N        Number of elements to process (the gradient's numel()).
 * @param g        Input gradient values.
 * @param ms       Input mean-squares state.
 * @param mom      Input momentum state.
 * @param ng       Output buffer for the updated gradient.
 * @param nms      Output buffer for the updated mean-squares state.
 * @param nmom     Output buffer for the updated momentum state.
 * @param decay    The operator's "decay" argument.
 * @param momentum The operator's "momentum" argument.
 * @param epsilon  The operator's "epsilon" argument.
 * @param lr       Pointer to the learning rate; the caller in this file
 *                 enforces that the backing tensor has exactly one element.
 * @param context  Context of the calling operator.
 */
template <typename Context>
void rmsprop_update(
    int N,
    const float* g,
    const float* ms,
    const float* mom,
    float* ng,
    float* nms,
    float* nmom,
    float decay,
    float momentum,
    float epsilon,
    const float* lr,
    Context* context);
| |
| template <typename T, class Context> |
| class RmsPropOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| RmsPropOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws), |
| decay_(this->template GetSingleArgument<float>("decay", 0.9f)), |
| momentum_(this->template GetSingleArgument<float>("momentum", 0.0f)), |
| epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {} |
| bool RunOnDevice() override { |
| CAFFE_ENFORCE(Input(LR).numel() == 1); |
| CAFFE_ENFORCE(Input(GRAD).numel() == Input(MEAN_SQUARES).numel()); |
| CAFFE_ENFORCE(Input(GRAD).numel() == Input(OUTPUT_MOMENTUM).numel()); |
| Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD)); |
| Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD)); |
| Output(OUTPUT_MEAN_SQUARES)->ResizeLike(Input(MEAN_SQUARES)); |
| Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM)); |
| rmsprop_update<Context>( |
| Input(GRAD).numel(), |
| Input(GRAD).template data<T>(), |
| Input(MEAN_SQUARES).template data<T>(), |
| Input(MOMENTUM).template data<T>(), |
| Output(OUTPUT_GRAD)->template mutable_data<T>(), |
| Output(OUTPUT_MEAN_SQUARES)->template mutable_data<T>(), |
| Output(OUTPUT_MOMENTUM)->template mutable_data<T>(), |
| decay_, |
| momentum_, |
| epsilon_, |
| Input(LR).template data<T>(), |
| &context_); |
| return true; |
| } |
| |
| protected: |
| T decay_{0.9}; |
| T momentum_{0.0}; |
| T epsilon_{1e-8}; |
| INPUT_TAGS(GRAD, MEAN_SQUARES, MOMENTUM, LR); |
| OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MEAN_SQUARES, OUTPUT_MOMENTUM); |
| }; |
| } |