| #ifndef CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_ |
| #define CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_ |
| |
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

#include "c10/util/irange.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/math.h"
| |
| namespace caffe2 { |
| |
| template <class Context> |
| class BisectPercentileOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| template <class... Args> |
| explicit BisectPercentileOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...), |
| pct_raw_(OperatorBase::GetRepeatedArgument<float>( |
| "percentile_raw", |
| vector<float>{})), |
| pct_mapping_(OperatorBase::GetRepeatedArgument<float>( |
| "percentile_mapping", |
| vector<float>{})), |
| pct_lower_(OperatorBase::GetRepeatedArgument<float>( |
| "percentile_lower", |
| vector<float>{})), |
| pct_upper_(OperatorBase::GetRepeatedArgument<float>( |
| "percentile_upper", |
| vector<float>{})), |
| pct_lens_( |
| OperatorBase::GetRepeatedArgument<int>("lengths", vector<int>{})) { |
| CAFFE_ENFORCE_EQ( |
| pct_raw_.size(), |
| pct_mapping_.size(), |
| "Feature (raw) data and percentile value dimension should match."); |
| CAFFE_ENFORCE_EQ( |
| pct_raw_.size(), |
| pct_lower_.size(), |
| "Feature (raw) data and lower bound dimension should match."); |
| CAFFE_ENFORCE_EQ( |
| pct_raw_.size(), |
| pct_upper_.size(), |
| "Feature (raw) data and upper bound dimension should match."); |
| n_features = pct_lens_.size(); |
| index.resize(n_features + 1); |
| index[0] = 0; |
| for (int i = 1; i <= n_features; ++i) { |
| index[i] = index[i - 1] + pct_lens_[i - 1]; |
| } |
| CAFFE_ENFORCE_EQ( |
| index[n_features], // The sum of lengths_data |
| pct_raw_.size(), |
| "Sum of lengths should be equal to the total number of percentile " |
| "mapping data samples"); |
| } |
| |
| bool RunOnDevice() override { |
| // Input |
| const auto& raw = Input(RAW); |
| CAFFE_ENFORCE_EQ(raw.dim(), 2); |
| const auto batch_size = raw.size(0); |
| const auto num_features = raw.size(1); |
| CAFFE_ENFORCE_EQ(num_features, pct_lens_.size()); |
| const float *const raw_data = raw.template data<float>(); |
| |
| // Output |
| |
| auto *const pct = Output(PCT, raw.sizes(), at::dtype<float>()); |
| float *const pct_output = pct->template mutable_data<float>(); |
| |
| // Compute percentile for each raw feature value |
| int feature_start_index = 0; |
| int feature_length = 0; |
| int cur_index = 0; |
| |
| for (const auto i : c10::irange(num_features)) { |
| cur_index = i; |
| feature_start_index = index[i]; |
| feature_length = pct_lens_[i]; |
| for (const auto j : c10::irange(batch_size)) { |
| (void)j; // Suppress unused variable warning |
| pct_output[cur_index] = compute_percentile( |
| pct_raw_.begin() + feature_start_index, |
| pct_mapping_.begin() + feature_start_index, |
| pct_lower_.begin() + feature_start_index, |
| pct_upper_.begin() + feature_start_index, |
| feature_length, |
| raw_data[cur_index]); |
| cur_index += num_features; |
| } |
| } |
| return true; |
| } |
| |
| protected: |
| INPUT_TAGS(RAW); |
| OUTPUT_TAGS(PCT); |
| |
| private: |
| int n_features; |
| vector<float> pct_raw_; |
| vector<float> pct_mapping_; |
| vector<float> pct_lower_; |
| vector<float> pct_upper_; |
| vector<int> pct_lens_; |
| vector<int> index; |
| vector<std::map<float, float>> fast_pct; |
| |
| static constexpr float kEPSILON = 1e-10; |
| |
| int64_t binary_search( |
| const std::vector<float>::iterator& data, |
| int64_t lo, |
| int64_t hi, |
| const float val) { |
| while (lo < hi) { |
| const auto mid = lo + (hi - lo) / 2; |
| const bool low_cond = (data[mid] <= val); |
| const bool high_cond = (val < data[mid + 1]); |
| if (low_cond && high_cond) { |
| return mid; |
| } else if (!low_cond) { |
| hi = mid - 1; |
| } else { |
| lo = mid + 1; |
| } |
| } |
| return lo; |
| } |
| |
| float compute_percentile( |
| const std::vector<float>::iterator& pct_raw_it, |
| const std::vector<float>::iterator& pct_mapping_it, |
| const std::vector<float>::iterator& pct_lower_it, |
| const std::vector<float>::iterator& pct_upper_it, |
| const int size, |
| const float val) { |
| // Corner cases where no interpolation is needed. |
| if (val < pct_raw_it[0]) { |
| return 0.; |
| } |
| if (val > pct_raw_it[size - 1]) { |
| return 1.; |
| } |
| |
| // Interpolation by binary search |
| const auto k = binary_search(pct_raw_it, 0, size - 1, val); |
| |
| if (pct_raw_it[k] == val) { |
| // Exact match |
| return pct_mapping_it[k]; |
| } else { |
| // interpolation |
| const float w = (val - pct_raw_it[k]) / |
| (pct_raw_it[k + 1] - pct_raw_it[k] + kEPSILON); |
| return (1 - w) * pct_upper_it[k] + w * pct_lower_it[k + 1]; |
| } |
| } |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_ |