| #pragma once |
| |
| #include <c10/util/irange.h> |
| #include "caffe2/core/context.h" |
| #include "caffe2/core/operator.h" |
| #include "caffe2/utils/math.h" |
| |
| #include <limits> |
| #include <vector> |
| |
| namespace caffe2 { |
| /** |
|  * Groups the positions of the key tensor (input 0) by key value. |
|  * |
|  * For every key value k in [0, categorical_limit), output k is a 1-D int |
|  * tensor holding, in increasing order, each position i with keys[i] == k. |
|  * Its length is the number of occurrences of k, which may be zero. |
|  * |
|  * @tparam T        Element type of the key tensor; values are used directly |
|  *                  as indices into the per-key tables. |
|  * @tparam Context  Device context forwarded to Operator<Context>. |
|  * |
|  * NOTE(review): key data is dereferenced directly in host code, so this |
|  * presumably targets a CPU context; confirm against the registration. |
|  * NOTE(review): the number of declared outputs is not checked against |
|  * categorical_limit here; confirm the schema guarantees they agree. |
|  */ |
| template <typename T, class Context> |
| class KeySplitOp : public Operator<Context> { |
|  public: |
|   USE_OPERATOR_CONTEXT_FUNCTIONS; |
| |
|   /** |
|    * Forwards all arguments to Operator<Context> and reads the integer |
|    * argument "categorical_limit" (default 0). Enforces that it is positive. |
|    */ |
|   template <class... Args> |
|   explicit KeySplitOp(Args&&... args) |
|       : Operator<Context>(std::forward<Args>(args)...), |
|         categorical_limit_( |
|             this->template GetSingleArgument<int>("categorical_limit", 0)) { |
|     CAFFE_ENFORCE_GT(categorical_limit_, 0); |
|   } |
| |
|   /** |
|    * Two-pass split: count the occurrences of every key, size the outputs |
|    * accordingly, then scatter each position into the output of its key. |
|    * |
|    * Enforces that every key lies in [0, categorical_limit) and that the |
|    * number of keys fits in an int. |
|    * |
|    * @return true once all outputs have been written. |
|    */ |
|   bool RunOnDevice() override { |
|     auto& keys = Input(0); |
|     const auto N = keys.numel(); |
|     // Counters and emitted positions are stored as int, so a larger input |
|     // would silently wrap; reject it before doing any work. |
|     CAFFE_ENFORCE_LE( |
|         N, |
|         std::numeric_limits<int>::max(), |
|         "KeySplit stores positions as int; the key tensor is too large."); |
|     const T* const keys_data = keys.template data<T>(); |
| |
|     // Pass 1: histogram of key values. The vector is value-initialised, so |
|     // every counter already starts at zero. |
|     std::vector<int> counts(categorical_limit_); |
|     for (const auto i : c10::irange(N)) { |
|       const auto k = keys_data[i]; |
|       // Validate before indexing; pass 2 relies on these checks. |
|       CAFFE_ENFORCE_GT(categorical_limit_, k); |
|       CAFFE_ENFORCE_GE(k, 0); |
|       counts[k]++; |
|     } |
| |
|     // Size one output per key value and remember where its data starts. |
|     std::vector<int*> eids(categorical_limit_); |
|     for (const auto k : c10::irange(categorical_limit_)) { |
|       auto* const eid = Output(k, {counts[k]}, at::dtype<int>()); |
|       eids[k] = eid->template mutable_data<int>(); |
|       // Reuse the counter as the write cursor for pass 2. |
|       counts[k] = 0; |
|     } |
| |
|     // Pass 2: append each position to the output of its key. |
|     for (const auto i : c10::irange(N)) { |
|       const auto k = keys_data[i]; |
|       // Safe narrowing: N was checked against the int range above. |
|       eids[k][counts[k]++] = static_cast<int>(i); |
|     } |
|     return true; |
|   } |
| |
|  private: |
|   // Exclusive upper bound on key values; also the number of outputs written. |
|   int categorical_limit_; |
| }; |
| } // namespace caffe2 |