| #pragma once |
| |
| #include "caffe2/core/operator.h" |
| |
| namespace caffe2 { |
| |
| template <typename Context> |
| void storm_update( |
| const int N, |
| const float* paramIn, |
| const float* momentIn, |
| const float* gradSqSumIn, |
| const float* gradIn, |
| const float* lr, |
| float* paramOut, |
| float* momentOut, |
| float* gradSqSumOut, |
| const float momentum, |
| const float beta, |
| Context* /*context*/) { |
| float gradSqSumTmp = 0.0; |
| for (const auto i : c10::irange(N)) { |
| const float gi = gradIn[i]; |
| gradSqSumTmp += gi * gi; |
| } |
| gradSqSumOut[0] = gradSqSumIn[0] + gradSqSumTmp; |
| |
| const float nlr = lr[0] * std::pow(beta + gradSqSumOut[0], -1.0 / 3.0); |
| const float alpha = momentum * nlr * nlr; |
| for (const auto i : c10::irange(N)) { |
| const float gi = gradIn[i]; |
| const float mi = momentIn[i]; |
| float new_mi = momentOut[i] = gi + (1.0 - alpha) * (mi - gi); |
| paramOut[i] = paramIn[i] + nlr * new_mi; |
| } |
| } |
| |
| template <class Context> |
| class StormOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| StormOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws), |
| OP_SINGLE_ARG(float, "momentum", momentum_, 10.0), |
| OP_SINGLE_ARG(float, "beta", beta_, 0.1) {} |
| |
| bool RunOnDevice() override { |
| // Enforce shapes |
| CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(PARAM).numel()); |
| CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(MOMENT).numel()); |
| CAFFE_ENFORCE_EQ(Input(GRADSQSUM).numel(), 1); |
| CAFFE_ENFORCE_EQ(Input(LR).numel(), 1); |
| |
| // Resize [potentially] out-of-place blobs |
| Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM)); |
| Output(OUTPUT_MOMENT)->ResizeLike(Input(MOMENT)); |
| Output(OUTPUT_GRAGSQSUM)->ResizeLike(Input(GRADSQSUM)); |
| |
| storm_update<Context>( |
| Input(GRAD).numel(), |
| Input(PARAM).template data<float>(), |
| Input(MOMENT).template data<float>(), |
| Input(GRADSQSUM).template data<float>(), |
| Input(GRAD).template data<float>(), |
| Input(LR).template data<float>(), |
| Output(OUTPUT_PARAM)->template mutable_data<float>(), |
| Output(OUTPUT_MOMENT)->template mutable_data<float>(), |
| Output(OUTPUT_GRAGSQSUM)->template mutable_data<float>(), |
| momentum_, |
| beta_, |
| &context_); |
| return true; |
| } |
| |
| protected: |
| const float momentum_; |
| const float beta_; |
| INPUT_TAGS(PARAM, MOMENT, GRADSQSUM, GRAD, LR); |
| OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT, OUTPUT_GRAGSQSUM); |
| }; |
| |
| template <class Context> |
| class SparseStormOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| SparseStormOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws), |
| OP_SINGLE_ARG(float, "momentum", momentum_, 10.0), |
| OP_SINGLE_ARG(float, "beta", beta_, 0.1) {} |
| |
| bool RunOnDevice() override { |
| // Enforce shapes |
| CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT).numel()); |
| CAFFE_ENFORCE_EQ(Input(GRADSQSUM).numel(), 1); |
| CAFFE_ENFORCE_EQ(Input(LR).numel(), 1); |
| CAFFE_ENFORCE_EQ( |
| Input(PARAM).size_from_dim(1), |
| Input(GRAD).size_from_dim(Input(INDICES).dim())); |
| |
| return DispatchHelper<TensorTypes<int32_t, int64_t>>::call( |
| this, Input(INDICES)); |
| } |
| |
| template <typename SIndex> |
| bool DoRunWithType() { |
| const auto* paramIn = Input(PARAM).template data<float>(); |
| const auto* momentIn = Input(MOMENT).template data<float>(); |
| const auto* gradSqSumIn = Input(GRADSQSUM).template data<float>(); |
| const auto* gradIn = Input(GRAD).template data<float>(); |
| const auto* indices = Input(INDICES).template data<SIndex>(); |
| const auto* lr = Input(LR).template data<float>(); |
| auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<float>(); |
| auto* momentOut = Output(OUTPUT_MOMENT)->template mutable_data<float>(); |
| auto* gradSqSumOut = |
| Output(OUTPUT_GRAGSQSUM)->template mutable_data<float>(); |
| |
| auto n = Input(INDICES).numel(); |
| if (n == 0) { |
| return true; |
| } |
| |
| float gradSqSumTmp = 0.0; |
| for (const auto i : c10::irange(Input(GRAD).numel())) { |
| const float gi = gradIn[i]; |
| gradSqSumTmp += gi * gi; |
| } |
| gradSqSumOut[0] = gradSqSumIn[0] + gradSqSumTmp; |
| |
| const float nlr = lr[0] * std::pow(beta_ + gradSqSumOut[0], -1.0 / 3.0); |
| const float alpha = momentum_ * nlr * nlr; |
| const auto block_size = Input(GRAD).numel() / n; |
| |
| for (const auto i : c10::irange(n)) { |
| auto idx = indices[i]; |
| if (block_size == 1) { |
| const float gi = gradIn[i]; |
| const float mi = momentIn[idx]; |
| float new_mi = momentOut[idx] = gi + (1.0 - alpha) * (mi - gi); |
| paramOut[idx] = paramIn[idx] + nlr * new_mi; |
| } else { |
| auto offsetI = i * block_size; |
| auto offsetIdx = idx * block_size; |
| |
| #ifndef NDEBUG |
| CAFFE_ENFORCE_GE( |
| Input(PARAM).numel(), |
| block_size + offsetIdx, |
| this->debug_def().input(PARAM), |
| ", out of bound, idx:", |
| idx, |
| " for input i:", |
| i, |
| " and block size:", |
| block_size); |
| CAFFE_ENFORCE_GE( |
| Input(GRAD).numel(), |
| block_size + offsetI, |
| this->debug_def().input(GRAD), |
| ", out of bound idx, idx:", |
| idx, |
| " for input i:", |
| i); |
| #endif |
| |
| for (const auto j : c10::irange(block_size)) { |
| const float gi = gradIn[offsetI + j]; |
| const float mi = momentIn[offsetIdx + j]; |
| float new_mi = momentOut[offsetIdx + j] = |
| gi + (1.0 - alpha) * (mi - gi); |
| paramOut[offsetIdx + j] = paramIn[offsetIdx + j] + nlr * new_mi; |
| } |
| } |
| } |
| |
| return true; |
| } |
| |
| protected: |
| const float momentum_; |
| const float beta_; |
| INPUT_TAGS(PARAM, MOMENT, GRADSQSUM, GRAD, INDICES, LR); |
| OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT, OUTPUT_GRAGSQSUM); |
| }; |
| } // namespace caffe2 |