| #pragma once |
| |
| #include "caffe2/core/operator.h" |
| |
| namespace caffe2 { |
| |
| class RowWiseCounterOp final : public Operator<CPUContext> { |
| public: |
| RowWiseCounterOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<CPUContext>(operator_def, ws), |
| counter_halflife_( |
| this->template GetSingleArgument<int64_t>("counter_halflife", -1)), |
| counter_neg_log_rho_(0.0) { |
| if (counter_halflife_ > 0) { |
| counter_neg_log_rho_ = std::log(2.0) / counter_halflife_; |
| } |
| } |
| |
| bool RunOnDevice() override { |
| CAFFE_ENFORCE_EQ(Input(PREV_ITER).numel(), Input(COUNTER).numel()); |
| CAFFE_ENFORCE_EQ(Input(ITER).numel(), 1); |
| return DispatchHelper<TensorTypes<int32_t, int64_t>>::call( |
| this, Input(INDICES)); |
| } |
| |
| template <typename SIndex> |
| bool DoRunWithType() { |
| auto* prev_iter = |
| Output(OUTPUT_PREV_ITER)->template mutable_data<int64_t>(); |
| auto* counter = Output(OUTPUT_COUNTER)->template mutable_data<double>(); |
| |
| const int64_t curr_iter = Input(ITER).template data<int64_t>()[0]; |
| const auto* indices = Input(INDICES).template data<SIndex>(); |
| |
| auto n = Input(INDICES).numel(); |
| if (n == 0) { |
| return true; |
| } |
| if (counter_halflife_ <= 0) { |
| return true; |
| } |
| |
| for (const auto i : c10::irange(n)) { |
| const std::size_t idx = indices[i]; |
| CAFFE_ENFORCE_GE( |
| Input(COUNTER).numel(), |
| idx, |
| this->debug_def().input(COUNTER), |
| ", out of bound, idx:", |
| idx, |
| " for input i:", |
| i, |
| " max size:", |
| Input(COUNTER).numel()); |
| const int64_t iter_delta = |
| std::max<int64_t>(0, curr_iter - prev_iter[idx]); |
| |
| counter[idx] = |
| 1.0 + std::exp(-iter_delta * counter_neg_log_rho_) * counter[idx]; |
| prev_iter[idx] = std::max<int64_t>(curr_iter, prev_iter[idx]); |
| } |
| return true; |
| } |
| |
| protected: |
| int64_t counter_halflife_; |
| double counter_neg_log_rho_; |
| INPUT_TAGS(PREV_ITER, COUNTER, INDICES, ITER); |
| OUTPUT_TAGS(OUTPUT_PREV_ITER, OUTPUT_COUNTER); |
| }; |
| |
| } // namespace caffe2 |