| #include "caffe2/core/operator.h" |
| #include "c10/util/irange.h" |
| |
| namespace caffe2 { |
| |
| namespace { |
| |
// Applies one Adadelta step to N contiguous float elements.
//
// For every element i:
//   nh[i] = decay * h[i] + (1 - decay) * g[i]^2
//   ng    = sqrt(d[i] + epsilon) / sqrt(nh[i] + epsilon) * g[i]
//   nw[i] = w[i] + lr[0] * ng
//   nd[i] = decay * d[i] + (1 - decay) * ng^2
//
// Parameters:
//   N        number of elements to process
//   w, g     current parameters and their gradient
//   h, d     running averages of squared gradients / squared updates
//   epsilon  term added inside both square roots
//   decay    weight of the previous running averages
//   lr       pointer to the learning rate; only lr[0] is read
//   nw/nh/nd outputs for parameters and the two running averages
//
// Each input element is read before the matching output element is written,
// so nw/nh/nd may point at the same memory as w/h/d.
//
// N is 64-bit because callers pass tensor element counts (int64_t); an `int`
// parameter would silently truncate counts above INT_MAX.
template <typename Context>
void AdadeltaUpdate(
    int64_t N,
    const float* w,
    const float* g,
    const float* h,
    const float* d,
    const float epsilon,
    const float decay,
    const float* lr,
    float* nw,
    float* nh,
    float* nd,
    Context* /*context*/) {
  for (int64_t i = 0; i < N; ++i) {
    const float gi = g[i];
    const float di = d[i];
    // Updated running average of squared gradients.
    const float hi = decay * h[i] + (1.0f - decay) * gi * gi;
    nh[i] = hi;
    // Gradient rescaled by the ratio of the two RMS terms.
    const float ng = (std::sqrt(di + epsilon) / std::sqrt(hi + epsilon)) * gi;
    nw[i] = w[i] + lr[0] * ng;
    // Updated running average of squared updates.
    nd[i] = decay * di + (1.0f - decay) * ng * ng;
  }
}
| |
| } // namespace |
| |
| template <class Context> |
| class AdadeltaOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| AdadeltaOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws), |
| OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f), |
| OP_SINGLE_ARG(float, "decay", decay_, 0.95f) {} |
| |
| bool RunOnDevice() override { |
| CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_GRAD).numel()); |
| CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_DELTA).numel()); |
| CAFFE_ENFORCE(Input(GRAD).numel() == Input(PARAM).numel()); |
| CAFFE_ENFORCE_GE(epsilon_, 0.0f); |
| CAFFE_ENFORCE_GT(decay_, 0.0f); |
| CAFFE_ENFORCE_LT(decay_, 1.0f); |
| |
| Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM)); |
| Output(OUTPUT_MOMENT_GRAD)->ResizeLike(Input(MOMENT_GRAD)); |
| Output(OUTPUT_MOMENT_DELTA)->ResizeLike(Input(MOMENT_DELTA)); |
| AdadeltaUpdate<Context>( |
| Input(GRAD).numel(), |
| Input(PARAM).template data<float>(), |
| Input(GRAD).template data<float>(), |
| Input(MOMENT_GRAD).template data<float>(), |
| Input(MOMENT_DELTA).template data<float>(), |
| epsilon_, |
| decay_, |
| Input(LR).template data<float>(), |
| Output(OUTPUT_PARAM)->template mutable_data<float>(), |
| Output(OUTPUT_MOMENT_GRAD)->template mutable_data<float>(), |
| Output(OUTPUT_MOMENT_DELTA)->template mutable_data<float>(), |
| &context_); |
| return true; |
| } |
| |
| protected: |
| const float epsilon_; |
| const float decay_; |
| INPUT_TAGS(PARAM, MOMENT_GRAD, MOMENT_DELTA, GRAD, LR); |
| OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_GRAD, OUTPUT_MOMENT_DELTA); |
| }; |
| |
| template <class Context> |
| class SparseAdadeltaOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| SparseAdadeltaOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws), |
| OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f), |
| OP_SINGLE_ARG(float, "decay", decay_, 0.95f) {} |
| |
| bool RunOnDevice() override { |
| // Enforce shapes |
| CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_GRAD).numel()); |
| CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_DELTA).numel()); |
| CAFFE_ENFORCE_EQ(Input(LR).numel(), 1); |
| CAFFE_ENFORCE_EQ( |
| Input(PARAM).size_from_dim(1), |
| Input(GRAD).size_from_dim(Input(INDICES).dim())); |
| |
| // Enforce domain constraints for attributes |
| CAFFE_ENFORCE_GE(epsilon_, 0.0f); |
| CAFFE_ENFORCE_GT(decay_, 0.0f); |
| CAFFE_ENFORCE_LT(decay_, 1.0f); |
| |
| return DispatchHelper<TensorTypes<int32_t, int64_t>>::call( |
| this, Input(INDICES)); |
| } |
| |
| template <typename SIndex> |
| bool DoRunWithType() { |
| const auto* lr = Input(LR).template data<float>(); |
| const auto* indices = Input(INDICES).template data<SIndex>(); |
| const auto* gradIn = Input(GRAD).template data<float>(); |
| const auto* paramIn = Input(PARAM).template data<float>(); |
| const auto* momentIn = Input(MOMENT_GRAD).template data<float>(); |
| const auto* momentDeltaIn = Input(MOMENT_DELTA).template data<float>(); |
| auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<float>(); |
| auto* momentOut = |
| Output(OUTPUT_MOMENT_GRAD)->template mutable_data<float>(); |
| auto* momentDeltaOut = |
| Output(OUTPUT_MOMENT_DELTA)->template mutable_data<float>(); |
| |
| auto n = Input(INDICES).numel(); |
| if (n == 0) { |
| return true; |
| } |
| |
| auto block_size = Input(GRAD).numel() / n; |
| for (const auto i : c10::irange(n)) { |
| auto idx = indices[i]; |
| if (block_size == 1) { |
| float gi = gradIn[i]; |
| float di = momentDeltaIn[idx]; |
| float hi = momentOut[idx] = |
| decay_ * momentIn[idx] + (1.0f - decay_) * gi * gi; |
| float ng = (std::sqrt(di + epsilon_) / std::sqrt(hi + epsilon_)) * gi; |
| paramOut[idx] = paramIn[idx] + lr[0] * ng; |
| momentDeltaOut[idx] = decay_ * di + (1.0f - decay_) * ng * ng; |
| } else { |
| auto offsetI = i * block_size; |
| auto offsetIdx = idx * block_size; |
| |
| #ifndef NDEBUG |
| CAFFE_ENFORCE_GE( |
| Input(PARAM).numel(), |
| block_size + offsetIdx, |
| this->debug_def().input(PARAM), |
| ", out of bound, idx:", |
| idx, |
| " for input i:", |
| i, |
| " and block size:", |
| block_size); |
| CAFFE_ENFORCE_GE( |
| Input(GRAD).numel(), |
| block_size + offsetI, |
| this->debug_def().input(GRAD), |
| ", out of bound idx, idx:", |
| idx, |
| " for input i:", |
| i); |
| #endif |
| AdadeltaUpdate( |
| block_size, |
| paramIn + offsetIdx, |
| gradIn + offsetI, |
| momentIn + offsetIdx, |
| momentDeltaIn + offsetIdx, |
| epsilon_, |
| decay_, |
| lr, |
| paramOut + offsetIdx, |
| momentOut + offsetIdx, |
| momentDeltaOut + offsetIdx, |
| &context_); |
| } |
| } |
| return true; |
| } |
| |
| protected: |
| const float epsilon_; |
| const float decay_; |
| INPUT_TAGS(PARAM, MOMENT_GRAD, MOMENT_DELTA, INDICES, GRAD, LR); |
| OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_GRAD, OUTPUT_MOMENT_DELTA); |
| }; |
| |
| } // namespace caffe2 |