// blob: 34d2835f9d6767e3298adbf5986821b1408ddeaa [file] [log] [blame]
#pragma once
#include "caffe2/core/operator.h"
#include "c10/util/irange.h"
#include <cmath>
#include <limits>
namespace caffe2 {
/**
 * Operator that estimates a quantile over the elements of ALL of its inputs,
 * treated as one combined collection, by bisecting on the value range.
 *
 * Arguments read from the operator definition:
 *   - "quantile" (float): target quantile; must lie in [0, 1]. The default of
 *     -1.0 fails the range check, so the argument has to be supplied.
 *   - "abs" (bool, default true): when true, every element is replaced by its
 *     absolute value before it is compared or counted.
 *   - "tol" (float, default 1e-3): relative tolerance that ends the bisection;
 *     must be strictly greater than 0.
 *
 * Inputs: one or more tensors that all share the dtype of Input(0); dispatch
 * is done for float and double.
 * Output QUANTILE_VAL: a tensor resized to a single element holding the
 * result, with the same element type as the inputs.
 */
template <typename Context>
class QuantileOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  /// Reads the operator arguments and validates their ranges.
  QuantileOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        quantile_(this->template GetSingleArgument<float>("quantile", -1.0)),
        abs_(this->template GetSingleArgument<bool>("abs", true)),
        tol_(this->template GetSingleArgument<float>("tol", 1e-3)) {
    CAFFE_ENFORCE_GE(
        quantile_,
        0,
        "input quantile should be ",
        "no less than 0, got ",
        quantile_);
    CAFFE_ENFORCE_GE(
        1.0f,
        quantile_,
        "input quantile should be ",
        "no larger than 1, got ",
        quantile_);
    // The check is strict (a tolerance of exactly 0 is rejected), so the
    // message states "larger than 0" to agree with the condition.
    CAFFE_ENFORCE_GT(
        tol_, 0, "tolerance should be ", "larger than 0, got ", tol_);
  }

  /// Dispatches to DoRunWithType<T>() based on the element type of Input(0).
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, double>>::call(this, Input(0));
  }

  /**
   * Computes the quantile for element type T and writes it to the output.
   *
   * Let target_cnt = ceil(numel * quantile) where numel is the total element
   * count over all inputs. The result is a value v chosen so that the number
   * of (optionally abs-transformed) elements <= v matches target_cnt, or, if
   * the bisection interval shrinks below the relative tolerance first, the
   * current upper bound of the interval.
   *
   * Fails (via CAFFE_ENFORCE) when the inputs do not share one dtype, when
   * the total element count is 0, or when the midpoint can no longer be
   * distinguished from an interval endpoint.
   *
   * @return true on success.
   */
  template <typename T>
  bool DoRunWithType() {
    Output(QUANTILE_VAL)->Resize(1);
    auto* quantile_val = Output(QUANTILE_VAL)->template mutable_data<T>();

    // Validate that every input has the same dtype as Input(0) and
    // accumulate the total number of elements across all inputs.
    auto& input_zero = Input(0);
    int64_t numel = input_zero.numel();
    for (const auto i : c10::irange(1, InputSize())) {
      CAFFE_ENFORCE_EQ(
          Input(i).dtype(),
          input_zero.dtype(),
          "All inputs must have the same type, expected: ",
          input_zero.dtype().name(),
          " but got: ",
          Input(i).dtype().name(),
          " for input: ",
          i);
      numel += Input(i).numel();
    }
    CAFFE_ENFORCE_GT(
        numel,
        0,
        "number of total element in input tensor should be ",
        "larger than 0, got ",
        numel);

    // The expected number of elements less than or equal to the target
    // value. Note: the product is evaluated in float because quantile_ is a
    // float.
    const int64_t target_cnt =
        static_cast<int64_t>(std::ceil(numel * quantile_));

    // lo / hi start as the smallest / largest (optionally abs-transformed)
    // element found in the inputs.
    T hi = 0.0;
    T lo = 0.0;
    GetRangeFromInputs(&lo, &hi);
    if (target_cnt == 0) {
      // lowest possible value
      quantile_val[0] = lo;
      return true;
    }
    if (target_cnt == numel) {
      // highest possible value
      quantile_val[0] = hi;
      return true;
    }
    int64_t lo_cnt = CountLowerEq(lo);
    if (lo_cnt >= target_cnt) {
      // Enough elements already sit at the minimum, so the target is the
      // lowest value itself.
      quantile_val[0] = lo;
      return true;
    }

    // Bisect on the value until the interval is narrower than the relative
    // tolerance tol_ * (|hi| + |lo|).
    while (std::abs(hi - lo) > tol_ * (std::abs(hi) + std::abs(lo))) {
      // hi only ever moves to a midpoint whose count exceeds target_cnt and
      // lo only to a midpoint whose count is below target_cnt; an exact
      // match returns immediately.
      const T mid = lo + (hi - lo) / 2.0;
      const int64_t mid_cnt = CountLowerEq(mid);
      if (mid_cnt > target_cnt) {
        // If mid equals hi the interval cannot shrink any further.
        CAFFE_ENFORCE_NE(
            hi, mid, "numeric precision at limit, unable to continue bisect");
        hi = mid;
      } else if (mid_cnt < target_cnt) {
        // If mid equals lo the interval cannot shrink any further.
        CAFFE_ENFORCE_NE(
            lo, mid, "numeric precision at limit, unable to continue bisect");
        lo = mid;
      } else {
        // mid_cnt == target_cnt
        quantile_val[0] = mid;
        return true;
      }
    }
    // Interval is within tolerance: report its upper end.
    quantile_val[0] = hi;
    return true;
  }

 protected:
  float quantile_; // requested quantile, validated to be within [0, 1]
  bool abs_; // when true, elements are compared by absolute value
  float tol_; // relative tolerance for stopping the bisection, > 0
  OUTPUT_TAGS(QUANTILE_VAL);

  /**
   * Scans every element of every input and stores the minimum in *lo and the
   * maximum in *hi, taking absolute values first when abs_ is set.
   *
   * If no element updates them, *hi is left at numeric_limits<T>::lowest()
   * and *lo at numeric_limits<T>::max().
   */
  template <typename T>
  void GetRangeFromInputs(T* lo, T* hi) {
    *hi = std::numeric_limits<T>::lowest();
    *lo = std::numeric_limits<T>::max();
    for (const auto i : c10::irange(InputSize())) {
      const auto* input = Input(i).template data<T>();
      for (const auto j : c10::irange(Input(i).numel())) {
        const T val = abs_ ? std::abs(input[j]) : input[j];
        if (*hi < val) {
          *hi = val;
        }
        if (*lo > val) {
          *lo = val;
        }
      }
    }
  }

  /**
   * Returns how many elements across all inputs are <= thd, taking absolute
   * values first when abs_ is set. Each call is a full pass over the inputs.
   */
  template <typename T>
  int64_t CountLowerEq(const T& thd) {
    int64_t count = 0;
    for (const auto i : c10::irange(InputSize())) {
      const auto* input = Input(i).template data<T>();
      for (const auto j : c10::irange(Input(i).numel())) {
        const T val = abs_ ? std::abs(input[j]) : input[j];
        if (val <= thd) {
          count++;
        }
      }
    }
    return count;
  }
};
} // namespace caffe2