// blob: 8c26b111de1c36ed94ce303e4395aa182930bc75 [file] [log] [blame]
#ifndef CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_
#define CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_
#include <cmath>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

#include "c10/util/irange.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
/**
 * Operator that converts every value of a 2-D float input into a percentile
 * by looking it up in per-feature tables supplied as operator arguments.
 *
 * Arguments (all flat arrays, concatenated feature after feature):
 *   - "percentile_raw":     raw sample values of each feature's table.
 *   - "percentile_mapping": value returned when the input equals a raw sample.
 *   - "percentile_lower":   lower bound used when interpolating.
 *   - "percentile_upper":   upper bound used when interpolating.
 *   - "lengths":            number of table entries owned by each feature.
 *
 * The table of feature i is the slice [index[i], index[i] + lengths[i]) of the
 * four flat arrays. The lookup assumes the raw values inside each slice are
 * sorted in ascending order; this is not verified here.
 *
 * Input  RAW: 2-D float tensor; dimension 0 is the batch, dimension 1 must
 *             have exactly one column per entry of "lengths".
 * Output PCT: float tensor with the same sizes as RAW.
 */
template <class Context>
class BisectPercentileOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  /**
   * Reads the table arguments and builds the per-feature start offsets.
   *
   * Fails (via CAFFE_ENFORCE) when the four table arrays differ in size, when
   * any length is not strictly positive, or when the lengths do not sum up to
   * the table size.
   */
  template <class... Args>
  explicit BisectPercentileOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        pct_raw_(OperatorBase::GetRepeatedArgument<float>(
            "percentile_raw",
            vector<float>{})),
        pct_mapping_(OperatorBase::GetRepeatedArgument<float>(
            "percentile_mapping",
            vector<float>{})),
        pct_lower_(OperatorBase::GetRepeatedArgument<float>(
            "percentile_lower",
            vector<float>{})),
        pct_upper_(OperatorBase::GetRepeatedArgument<float>(
            "percentile_upper",
            vector<float>{})),
        pct_lens_(
            OperatorBase::GetRepeatedArgument<int>("lengths", vector<int>{})) {
    CAFFE_ENFORCE_EQ(
        pct_raw_.size(),
        pct_mapping_.size(),
        "Feature (raw) data and percentile value dimension should match.");
    CAFFE_ENFORCE_EQ(
        pct_raw_.size(),
        pct_lower_.size(),
        "Feature (raw) data and lower bound dimension should match.");
    CAFFE_ENFORCE_EQ(
        pct_raw_.size(),
        pct_upper_.size(),
        "Feature (raw) data and upper bound dimension should match.");

    // index[i] is the offset of feature i's first table entry; index.back()
    // ends up holding the total number of entries (prefix sum of lengths).
    n_features = pct_lens_.size();
    index.resize(n_features + 1);
    index[0] = 0;
    for (int i = 1; i <= n_features; ++i) {
      // compute_percentile() reads the first and the last entry of a slice
      // unconditionally, so an empty (or negative) slice would make it read
      // outside the feature's table.
      CAFFE_ENFORCE_GT(
          pct_lens_[i - 1],
          0,
          "Each feature needs at least one percentile mapping data sample.");
      index[i] = index[i - 1] + pct_lens_[i - 1];
    }
    CAFFE_ENFORCE_EQ(
        index[n_features], // The sum of lengths_data
        pct_raw_.size(),
        "Sum of lengths should be equal to the total number of percentile "
        "mapping data samples");
  }

  /**
   * Computes the percentile of every element of RAW and writes it to the
   * element with the same flat offset in PCT.
   *
   * @return true once all elements have been processed.
   */
  bool RunOnDevice() override {
    // Input
    const auto& raw = Input(RAW);
    CAFFE_ENFORCE_EQ(raw.dim(), 2);
    const auto batch_size = raw.size(0);
    const auto num_features = raw.size(1);
    CAFFE_ENFORCE_EQ(num_features, pct_lens_.size());
    const float* const raw_data = raw.template data<float>();

    // Output
    auto* const pct = Output(PCT, raw.sizes(), at::dtype<float>());
    float* const pct_output = pct->template mutable_data<float>();

    // Compute percentile for each raw feature value, one feature (column) at
    // a time so the table slice is resolved once per feature.
    for (const auto i : c10::irange(num_features)) {
      const int feature_start_index = index[i];
      const int feature_length = pct_lens_[i];
      const auto raw_it = pct_raw_.begin() + feature_start_index;
      const auto mapping_it = pct_mapping_.begin() + feature_start_index;
      const auto lower_it = pct_lower_.begin() + feature_start_index;
      const auto upper_it = pct_upper_.begin() + feature_start_index;

      // Flat offset of element (row, i) is i + row * num_features. Kept as
      // 64-bit: batch_size * num_features can exceed the range of int.
      int64_t cur_index = i;
      for (const auto j : c10::irange(batch_size)) {
        (void)j; // Suppress unused variable warning
        pct_output[cur_index] = compute_percentile(
            raw_it,
            mapping_it,
            lower_it,
            upper_it,
            feature_length,
            raw_data[cur_index]);
        cur_index += num_features;
      }
    }
    return true;
  }

 protected:
  INPUT_TAGS(RAW);
  OUTPUT_TAGS(PCT);

 private:
  // Number of features, i.e. the number of entries in "lengths".
  int n_features;
  // Flat per-feature tables read from the operator arguments.
  vector<float> pct_raw_;
  vector<float> pct_mapping_;
  vector<float> pct_lower_;
  vector<float> pct_upper_;
  // Number of table entries owned by each feature.
  vector<int> pct_lens_;
  // Prefix sums of pct_lens_: start offset of each feature's slice.
  vector<int> index;
  // Not referenced by any code visible in this class.
  vector<std::map<float, float>> fast_pct;

  // Added to the interpolation denominator so that two equal neighbouring raw
  // values do not cause a division by zero.
  static constexpr float kEPSILON = 1e-10;

  /**
   * Bisects data[lo..hi] (both inclusive) for the position of val.
   *
   * Assuming the range is sorted ascending and data[lo] <= val, the result k
   * satisfies data[k] <= val < data[k + 1], or k == hi when no element to the
   * right of k is greater than val.
   *
   * @param data iterator to the first element of the table slice.
   * @param lo   first candidate index.
   * @param hi   last candidate index.
   * @param val  value to locate.
   * @return the index described above; lo is returned when the range
   *         collapses without finding a bracketing pair.
   */
  int64_t binary_search(
      const std::vector<float>::iterator& data,
      int64_t lo,
      int64_t hi,
      const float val) {
    while (lo < hi) {
      // lo < hi guarantees mid < hi, so data[mid + 1] stays inside the range.
      const auto mid = lo + (hi - lo) / 2;
      const bool low_cond = (data[mid] <= val);
      const bool high_cond = (val < data[mid + 1]);
      if (low_cond && high_cond) {
        return mid;
      } else if (!low_cond) {
        hi = mid - 1;
      } else {
        lo = mid + 1;
      }
    }
    return lo;
  }

  /**
   * Looks up the percentile of val in one feature's table slice.
   *
   * @param pct_raw_it     first raw sample of the slice.
   * @param pct_mapping_it first mapped percentile of the slice.
   * @param pct_lower_it   first lower bound of the slice.
   * @param pct_upper_it   first upper bound of the slice.
   * @param size           number of entries in the slice (must be >= 1).
   * @param val            raw feature value to convert.
   * @return 0 when val is below the first raw sample, 1 when it is above the
   *         last one, the mapped percentile on an exact match, otherwise a
   *         value interpolated between the upper bound of the bracketing
   *         sample on the left and the lower bound of the one on the right.
   *         A NaN input is returned unchanged.
   */
  float compute_percentile(
      const std::vector<float>::iterator& pct_raw_it,
      const std::vector<float>::iterator& pct_mapping_it,
      const std::vector<float>::iterator& pct_lower_it,
      const std::vector<float>::iterator& pct_upper_it,
      const int size,
      const float val) {
    // NaN fails every ordered comparison below and would fall through to the
    // interpolation, which needs a right-hand neighbour that a one-entry
    // table does not have. Interpolating with NaN yields NaN anyway, so
    // propagate it up front.
    if (std::isnan(val)) {
      return val;
    }

    // Corner cases where no interpolation is needed.
    if (val < pct_raw_it[0]) {
      return 0.;
    }
    if (val > pct_raw_it[size - 1]) {
      return 1.;
    }

    // Interpolation by binary search
    const auto k = binary_search(pct_raw_it, 0, size - 1, val);
    if (pct_raw_it[k] == val) {
      // Exact match
      return pct_mapping_it[k];
    }

    if (k + 1 >= size) {
      // No right-hand neighbour to interpolate with. With a well-formed table
      // this cannot happen without an exact match (it needs e.g. a NaN table
      // entry); fall back to the mapped value instead of reading past the
      // feature's slice.
      return pct_mapping_it[k];
    }

    // interpolation
    const float w = (val - pct_raw_it[k]) /
        (pct_raw_it[k + 1] - pct_raw_it[k] + kEPSILON);
    return (1 - w) * pct_upper_it[k] + w * pct_lower_it[k + 1];
  }
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_BISECT_PERCENTILE_OP_H_