| #pragma once |
| |
| #include "caffe2/core/operator.h" |
| #include "caffe2/perfkernels/adagrad.h" |
| #if defined(USE_FBGEMM) && !defined(__NVCC__) |
| #include "fbgemm/FbgemmEmbedding.h" |
| #endif |
| |
| namespace caffe2 { |
| |
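// Dense AdaGrad update over N elements. For each element i:
//   g'_t = g_t + weight_decay * w_{t-1}
//   h_t  = decay * h_{t-1} + g'_t^2
//   w_t  = w_{t-1} + lr * g'_t / (sqrt(h_t) + epsilon)
// The parameter update is additive, so minimization requires a negative
// learning rate. This overload reads the scalar lr from a tensor and
// forwards to the perfkernels implementation.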
| template <typename Context> |
| void adagrad_update( |
| int N, |
| const float* w, |
| const float* g, |
| const float* h, |
| float* nw, |
| float* nh, |
| float epsilon, |
| float decay, |
| const float* lr, |
| Context* /*context*/, |
| float weight_decay = 0.f) { |
| return adagrad_update( |
| N, w, g, h, nw, nh, epsilon, decay, lr[0], weight_decay); |
| } |
| |
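// Same update as adagrad_update above, but additionally records the
// per-element effective learning rate lr / (sqrt(h_t) + epsilon) in
// effectiveLROut.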
| template <typename Context> |
| void adagrad_update_output_effective_lr( |
| int N, |
| const float* paramIn, |
| const float* gradIn, |
| const float* momentIn, |
| float* paramOut, |
| float* momentOut, |
| float* effectiveLROut, |
| float epsilon, |
| float decay, |
| const float* lr, |
| Context* /*context*/, |
| float weight_decay = 0.f) { |
| for (const auto i : c10::irange(N)) { |
| float grad = std::fma(weight_decay, paramIn[i], gradIn[i]); |
| float moment = momentOut[i] = decay * momentIn[i] + grad * grad; |
| float effective_lr = effectiveLROut[i] = |
| lr[0] / (std::sqrt(moment) + epsilon); |
| paramOut[i] = paramIn[i] + effective_lr * grad; |
| } |
| } |
| |
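// As adagrad_update_output_effective_lr, but also records the applied
// update effective_lr * g'_t in updateOut.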
| template <typename Context> |
| void adagrad_update_output_effective_lr_and_update( |
| int N, |
| const float* paramIn, |
| const float* gradIn, |
| const float* momentIn, |
| float* paramOut, |
| float* momentOut, |
| float* effectiveLROut, |
| float* updateOut, |
| float epsilon, |
| float decay, |
| const float* lr, |
| Context* /*context*/, |
| float weight_decay = 0.f) { |
| for (const auto i : c10::irange(N)) { |
| float grad = std::fma(weight_decay, paramIn[i], gradIn[i]); |
| float moment = momentOut[i] = decay * momentIn[i] + grad * grad; |
| float effective_lr = effectiveLROut[i] = |
| lr[0] / (std::sqrt(moment) + epsilon); |
| float update = updateOut[i] = effective_lr * grad; |
| paramOut[i] = paramIn[i] + update; |
| } |
| } |
| |
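// Dense AdaGrad operator.
// Inputs:  PARAM, MOMENT_1, GRAD, LR (scalar).
// Outputs: OUTPUT_PARAM and OUTPUT_MOMENT_1, plus OUTPUT_EFFECTIVE_LR and
//          OUTPUT_UPDATE when 3 or 4 outputs are bound.
// Arguments: epsilon (default 1e-5), decay (default 1.0),
//            weight_decay (default 0).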
| template <class Context> |
| class AdagradOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| AdagradOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws), |
| epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)), |
| decay_(this->template GetSingleArgument<float>("decay", 1.0f)), |
| weight_decay_( |
| this->template GetSingleArgument<float>("weight_decay", 0.f)) { |
| VLOG(1) << "gradient optimization operator in use: " |
| << "AdagradOp" |
| << " weight_decay_=" << weight_decay_; |
| } |
| |
| bool RunOnDevice() override { |
| CAFFE_ENFORCE_EQ( |
| Input(GRAD).numel(), |
| Input(MOMENT_1).numel(), |
| "PARAM size: ", |
| Input(PARAM).numel(), |
| ", GRAD size: ", |
| Input(GRAD).numel(), |
| ", MOMENT_1 size: ", |
| Input(MOMENT_1).numel(), |
| ", LR size: ", |
| Input(LR).numel()); |
| |
| CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(PARAM).numel()); |
| Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM)); |
| Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1)); |
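    // The number of bound outputs selects the variant: 2 -> param and moment
    // only, 3 -> plus effective LR, 4 -> plus effective LR and update.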
| if (OutputSize() == 2) { |
| adagrad_update<Context>( |
| Input(GRAD).numel(), |
| Input(PARAM).template data<float>(), |
| Input(GRAD).template data<float>(), |
| Input(MOMENT_1).template data<float>(), |
| Output(OUTPUT_PARAM)->template mutable_data<float>(), |
| Output(OUTPUT_MOMENT_1)->template mutable_data<float>(), |
| epsilon_, |
| decay_, |
| Input(LR).template data<float>(), |
| &context_, |
| weight_decay_); |
| } else if (OutputSize() == 3) { |
| Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(Input(GRAD)); |
| adagrad_update_output_effective_lr<Context>( |
| Input(GRAD).numel(), |
| Input(PARAM).template data<float>(), |
| Input(GRAD).template data<float>(), |
| Input(MOMENT_1).template data<float>(), |
| Output(OUTPUT_PARAM)->template mutable_data<float>(), |
| Output(OUTPUT_MOMENT_1)->template mutable_data<float>(), |
| Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<float>(), |
| epsilon_, |
| decay_, |
| Input(LR).template data<float>(), |
| &context_, |
| weight_decay_); |
| } else { |
| Output(OUTPUT_EFFECTIVE_LR)->ResizeLike(Input(GRAD)); |
| Output(OUTPUT_UPDATE)->ResizeLike(Input(GRAD)); |
| adagrad_update_output_effective_lr_and_update<Context>( |
| Input(GRAD).numel(), |
| Input(PARAM).template data<float>(), |
| Input(GRAD).template data<float>(), |
| Input(MOMENT_1).template data<float>(), |
| Output(OUTPUT_PARAM)->template mutable_data<float>(), |
| Output(OUTPUT_MOMENT_1)->template mutable_data<float>(), |
| Output(OUTPUT_EFFECTIVE_LR)->template mutable_data<float>(), |
| Output(OUTPUT_UPDATE)->template mutable_data<float>(), |
| epsilon_, |
| decay_, |
| Input(LR).template data<float>(), |
| &context_, |
| weight_decay_); |
| } |
| |
| return true; |
| } |
| |
| protected: |
| float epsilon_; |
| float decay_; |
| float weight_decay_; |
| INPUT_TAGS(PARAM, MOMENT_1, GRAD, LR); |
| OUTPUT_TAGS( |
| OUTPUT_PARAM, |
| OUTPUT_MOMENT_1, |
| OUTPUT_EFFECTIVE_LR, |
| OUTPUT_UPDATE); |
| }; |
| |
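// Sparse AdaGrad: applies the update only to the parameter rows selected by
// INDICES. GRAD holds one row of block_size = PARAM.size_from_dim(1) values
// per index. Moment decay is not supported (decay must be 1.0). When built
// with FBGEMM (and not compiling under NVCC), a JIT-generated SIMD kernel is
// used; otherwise the prefetching scalar kernel from perfkernels is used.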
| class SparseAdagradOp final : public Operator<CPUContext> { |
| public: |
| SparseAdagradOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<CPUContext>(operator_def, ws), |
| epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)), |
| weight_decay_( |
| this->template GetSingleArgument<float>("weight_decay", 0.f)) { |
| VLOG(1) << "gradient optimization operator in use: " |
| << "SparseAdagradOp" |
| << " weight_decay_=" << weight_decay_; |
    const float decay = this->template GetSingleArgument<float>("decay", 1.0f);
    CAFFE_ENFORCE_EQ(
        decay, 1.0f, "Decay is not supported for SparseAdagradOp");
| } |
| |
| bool RunOnDevice() override { |
    // Enforce shapes: the moment has one entry per parameter, LR is a
    // scalar, and each gradient row matches the parameter row size.
| CAFFE_ENFORCE_EQ( |
| Input(PARAM).numel(), |
| Input(MOMENT_1).numel(), |
| "Input Param size: ", |
| Input(PARAM).numel(), |
| " Input Moment size: ", |
| Input(MOMENT_1).numel()); |
| CAFFE_ENFORCE_EQ(Input(LR).numel(), 1); |
| CAFFE_ENFORCE_EQ( |
| Input(PARAM).size_from_dim(1), |
| Input(GRAD).size_from_dim(Input(INDICES).dim())); |
| |
| return DispatchHelper<TensorTypes<int32_t, int64_t>>::call( |
| this, Input(INDICES)); |
| } |
| |
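  // Body specialized on the integer type of INDICES (int32_t or int64_t),
  // as selected by DispatchHelper in RunOnDevice.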
| template <typename SIndex> |
| bool DoRunWithType() { |
| const auto* lr = Input(LR).template data<float>(); |
| |
| auto n = Input(INDICES).numel(); |
| |
| const auto* indices = Input(INDICES).template data<SIndex>(); |
| const auto* gradIn = Input(GRAD).template data<float>(); |
| auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<float>(); |
| auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<float>(); |
| |
| if (n == 0) { |
| return true; |
| } |
| auto block_size = Input(GRAD).numel() / n; |
| |
    // input(grad) must be evenly divisible by the number of indices
    CAFFE_ENFORCE_EQ(
        Input(GRAD).numel() % n,
        0,
        "Incorrect gradient size: ",
        Input(GRAD).numel(),
        " size of indices: ",
        n);
| |
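    // FBGEMM path: JIT-generate a kernel specialized for this block size and
    // cache it; it is regenerated only when block_size changes. The kernel
    // returns how many rows it processed before hitting an out-of-bound
    // index, which is used below for error reporting.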
| #if defined(USE_FBGEMM) && !defined(__NVCC__) |
| VLOG(1) << "using fbgemm::GenerateSparseAdaGrad in SparseAdagradOp"; |
| |
| if (block_size != last_block_size_) { |
| last_block_size_ = block_size; |
| if (std::is_same<SIndex, std::int32_t>::value) { |
| kernel_i32_ = fbgemm::GenerateSparseAdaGrad<std::int32_t>( |
| block_size, |
| /*rowwise=*/false, |
| /*prefetch=*/16, |
| weight_decay_ != 0.0f); |
| } else { |
| CAFFE_ENFORCE((std::is_same<SIndex, std::int64_t>::value)); |
| kernel_i64_ = fbgemm::GenerateSparseAdaGrad<std::int64_t>( |
| block_size, |
| /*rowwise=*/false, |
| /*prefetch=*/16, |
| weight_decay_ != 0.0f); |
| } |
| } |
| |
| int num_rows_processed; |
| if (std::is_same<SIndex, std::int32_t>::value) { |
| num_rows_processed = kernel_i32_( |
| n, |
| Input(PARAM).numel(), |
| paramOut, |
| gradIn, |
| momentOut, |
| reinterpret_cast<const std::int32_t*>(indices), |
| epsilon_, |
| lr[0], |
| weight_decay_, |
| /*counter=*/nullptr, |
| /*counter_halflife=*/0); |
| } else { |
| num_rows_processed = kernel_i64_( |
| n, |
| Input(PARAM).numel(), |
| paramOut, |
| gradIn, |
| momentOut, |
| reinterpret_cast<const std::int64_t*>(indices), |
| epsilon_, |
| lr[0], |
| weight_decay_, |
| /*counter=*/nullptr, |
| /*counter_halflife=*/0); |
| } |
| if (num_rows_processed < n) { |
| CAFFE_ENFORCE_GE( |
| Input(PARAM).numel(), |
| (indices[num_rows_processed] + 1) * block_size, |
| this->debug_def().input(PARAM), |
| ", out of bound, idx:", |
| indices[num_rows_processed], |
| " for input i:", |
| num_rows_processed, |
| " and block_size:", |
| block_size, |
| " max size:", |
| Input(PARAM).numel()); |
| return false; |
| } else { |
| return true; |
| } |
#else
| |
| VLOG(1) |
| << "using internal::adagrad_update_prefetch_inlined in SparseAdagradOp"; |
| |
| const auto* paramIn = Input(PARAM).template data<float>(); |
| const auto* momentIn = Input(MOMENT_1).template data<float>(); |
| |
| for (const auto i : c10::irange(n)) { |
| auto idx = indices[i]; |
| auto offsetI = i * block_size; |
| auto offsetIdx = idx * block_size; |
| |
      // Enforce that the indexed parameter block stays within range (the
      // matching gradient access is in range by construction).
| CAFFE_ENFORCE_GE( |
| Input(PARAM).numel(), |
| block_size + offsetIdx, |
| this->debug_def().input(PARAM), |
| ", out of bound, idx:", |
| idx, |
| " for input i:", |
| i, |
| " and block size:", |
| block_size, |
| " max size:", |
| Input(PARAM).numel()); |
| |
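      // block_size == 1: scalar fast path; larger blocks go through the
      // prefetching vector kernel below.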
| if (block_size == 1) { |
| float gi = std::fma(weight_decay_, paramIn[idx], gradIn[i]); |
| float hi = momentOut[idx] = momentIn[idx] + gi * gi; |
| paramOut[idx] = paramIn[idx] + lr[0] * gi / (std::sqrt(hi) + epsilon_); |
| } else { |
        // Software prefetch: also pass the rows indexed prefdist_T0
        // iterations ahead so they are in cache when the loop reaches them.
        const int prefdist_T0 = 16;
        int i_pref = (i < n - prefdist_T0) ? i + prefdist_T0 : i;
        std::size_t idx_pref = indices[i_pref];
| |
| internal::adagrad_update_prefetch_inlined( |
| block_size, |
| paramIn + offsetIdx, |
| ¶mIn[idx_pref * block_size], |
| gradIn + offsetI, |
| momentIn + offsetIdx, |
| &momentIn[idx_pref * block_size], |
| paramOut + offsetIdx, |
| ¶mOut[idx_pref * block_size], |
| momentOut + offsetIdx, |
| &momentOut[idx_pref * block_size], |
| epsilon_, |
| lr[0], |
| weight_decay_); |
| } |
| } |
    return true;
#endif // USE_FBGEMM
  }
| |
| protected: |
| float epsilon_; |
| const float weight_decay_; |
| #if defined(USE_FBGEMM) && !defined(__NVCC__) |
| fbgemm::SparseAdaGradSignature<std::int32_t>::Type kernel_i32_; |
| fbgemm::SparseAdaGradSignature<std::int64_t>::Type kernel_i64_; |
| std::int64_t last_block_size_{-1}; |
| #endif |
| |
| INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR); |
| OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1); |
| }; |
| |
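// Row-wise sparse AdaGrad: keeps a single moment value per parameter row
// instead of per element. The squared gradients of a row are averaged into
// one scalar h, and every element of the row shares the step
// lr / (sqrt(h) + epsilon), shrinking optimizer state from one float per
// parameter to one float per row. When counter_halflife is set (!= -1), the
// COUNTER input supplies per-row update counts; with counter_halflife > 0
// they rescale weight decay per row.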
| template <class Context> |
| class RowWiseSparseAdagradOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| RowWiseSparseAdagradOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws), |
| epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)), |
| weight_decay_( |
| this->template GetSingleArgument<float>("weight_decay", 0.f)), |
| counter_halflife_( |
| this->template GetSingleArgument<int64_t>("counter_halflife", -1)) { |
| VLOG(1) << "gradient optimization operator in use: " |
| << "RowWiseSparseAdagradOp" |
| << " weight_decay_=" << weight_decay_ |
| << " counter_halflife=" << counter_halflife_; |
| } |
| |
| bool RunOnDevice() override { |
| // Enforce shapes |
| CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_1).numel()); |
| CAFFE_ENFORCE_EQ(Input(LR).numel(), 1); |
| CAFFE_ENFORCE_EQ( |
| Input(PARAM).size_from_dim(1), |
| Input(GRAD).size_from_dim(Input(INDICES).dim())); |
| |
| return DispatchHelper<TensorTypes<int32_t, int64_t>>::call( |
| this, Input(INDICES)); |
| } |
| |
| template <typename SIndex> |
| bool DoRunWithType() { |
| const auto* lr = Input(LR).template data<float>(); |
| auto* param = Output(OUTPUT_PARAM)->template mutable_data<float>(); |
| auto* moment = Output(OUTPUT_MOMENT_1)->template mutable_data<float>(); |
| |
| const auto* indices = Input(INDICES).template data<SIndex>(); |
| const auto* gradIn = Input(GRAD).template data<float>(); |
| const auto* count = counter_halflife_ == -1 |
| ? nullptr |
| : Input(COUNTER).template data<double>(); |
| |
| auto n = Input(INDICES).numel(); |
| if (n == 0) { |
| return true; |
| } |
| |
| auto block_size = Input(GRAD).numel() / n; |
| |
    // Enforce: the moment has exactly one entry per parameter row
| CAFFE_ENFORCE_EQ( |
| Input(PARAM).numel() / block_size, |
| Input(MOMENT_1).numel(), |
| "Input Param size: ", |
| Input(PARAM).numel(), |
| " Block size: ", |
| block_size, |
| " Input Moment size: ", |
| Input(MOMENT_1).numel()); |
| |
    // input(grad) must be evenly divisible by the number of indices
    CAFFE_ENFORCE_EQ(
        Input(GRAD).numel() % n,
        0,
        "Incorrect gradient size: ",
        Input(GRAD).numel(),
        " size of indices: ",
        n);
| |
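    // FBGEMM path: same JIT kernel cache as in SparseAdagradOp, but generated
    // with rowwise=true so the moment is reduced to one scalar per row; row
    // counters are forwarded when counter_halflife_ > 0.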
| #if defined(USE_FBGEMM) && !defined(__NVCC__) |
| VLOG(1) << "using fbgemm::GenerateSparseAdaGrad in RowWiseSparseAdagradOp"; |
| |
| if (block_size != last_block_size_) { |
| last_block_size_ = block_size; |
| if (std::is_same<SIndex, std::int32_t>::value) { |
| kernel_i32_ = fbgemm::GenerateSparseAdaGrad<std::int32_t>( |
| block_size, |
| /*rowwise=*/true, |
| /*prefetch=*/16, |
| weight_decay_ != 0.0f); |
| } else { |
| CAFFE_ENFORCE((std::is_same<SIndex, std::int64_t>::value)); |
| kernel_i64_ = fbgemm::GenerateSparseAdaGrad<std::int64_t>( |
| block_size, |
| /*rowwise=*/true, |
| /*prefetch=*/16, |
| weight_decay_ != 0.0f); |
| } |
| } |
| |
| int num_rows_processed; |
| if (std::is_same<SIndex, std::int32_t>::value) { |
| num_rows_processed = kernel_i32_( |
| n, |
| Input(PARAM).numel(), |
| param, |
| gradIn, |
| moment, |
| reinterpret_cast<const std::int32_t*>(indices), |
| epsilon_, |
| lr[0], |
| weight_decay_, |
| (counter_halflife_ > 0) ? count : nullptr, |
| counter_halflife_); |
| } else { |
| num_rows_processed = kernel_i64_( |
| n, |
| Input(PARAM).numel(), |
| param, |
| gradIn, |
| moment, |
| reinterpret_cast<const std::int64_t*>(indices), |
| epsilon_, |
| lr[0], |
| weight_decay_, |
| (counter_halflife_ > 0) ? count : nullptr, |
| counter_halflife_); |
| } |
| |
| if (num_rows_processed < n) { |
      // The kernel stopped early: report the first out-of-bound index.
| CAFFE_ENFORCE_GE( |
| Input(PARAM).numel(), |
| (indices[num_rows_processed] + 1) * block_size, |
| this->debug_def().input(PARAM), |
| ", out of bound, idx:", |
| indices[num_rows_processed], |
| " for input i:", |
| num_rows_processed, |
| " and block size:", |
| block_size, |
| " max size:", |
| Input(PARAM).numel()); |
| return false; |
| } else { |
| return true; |
| } |
| #else |
| VLOG(1) << "using plain adagrad updates in RowWiseSparseAdagradOp"; |
| |
| for (const auto i : c10::irange(n)) { |
| auto idx = indices[i]; |
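      // Counter-based scaling: rows with a large (frequently updated) count
      // receive proportionally less weight decay.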
| float freq = (counter_halflife_ > 0 && count[idx] > 0) |
| ? counter_halflife_ / count[idx] |
| : 1.0; |
| if (block_size == 1) { |
| float gi = std::fma(weight_decay_ * freq, param[idx], gradIn[i]); |
| float hi = moment[idx] = moment[idx] + gi * gi; |
| param[idx] = param[idx] + lr[0] * gi / (std::sqrt(hi) + epsilon_); |
| } else { |
| auto offsetI = i * block_size; |
| auto offsetIdx = idx * block_size; |
| |
| #ifndef NDEBUG |
| CAFFE_ENFORCE_GE( |
| Input(PARAM).numel(), |
| block_size + offsetIdx, |
| this->debug_def().input(PARAM), |
| ", out of bound, idx:", |
| idx, |
| " for input i:", |
| i, |
| " and block size:", |
| block_size); |
| CAFFE_ENFORCE_GE( |
| Input(GRAD).numel(), |
| block_size + offsetI, |
| this->debug_def().input(GRAD), |
| ", out of bound idx, idx:", |
| idx, |
| " for input i:", |
| i); |
| #endif |
| |
        float* w = param + offsetIdx;
        const float* g = gradIn + offsetI;
        float* h = moment + idx;
        // First pass: average the squared (weight-decayed) gradients of the
        // row into the single per-row moment.
        float hs = 0.0f;
        for (const auto j : c10::irange(block_size)) {
          float gj = std::fma(weight_decay_ * freq, w[j], g[j]);
          hs += gj * gj;
        }
        float hi = h[0] = h[0] + hs / block_size;
        float step = lr[0] / (std::sqrt(hi) + epsilon_);
        // Second pass: apply the shared step; gj is recomputed to avoid
        // materializing a temporary gradient buffer.
        for (const auto j : c10::irange(block_size)) {
          float gj = std::fma(weight_decay_ * freq, w[j], g[j]);
          w[j] = w[j] + gj * step;
        }
| } |
| } |
| return true; |
#endif // USE_FBGEMM
| } |
| |
| protected: |
| float epsilon_; |
| const float weight_decay_; |
| const int64_t counter_halflife_; |
| #if defined(USE_FBGEMM) && !defined(__NVCC__) |
| fbgemm::SparseAdaGradSignature<std::int32_t>::Type kernel_i32_; |
| fbgemm::SparseAdaGradSignature<std::int64_t>::Type kernel_i64_; |
| std::int64_t last_block_size_{-1}; |
| #endif |
| |
| INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR, COUNTER); |
| OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1); |
| }; |
| } // namespace caffe2 |