blob: df6fda77bdb183d645ed85380795d3621d285073 [file] [log] [blame]
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "caffe2/operators/stump_func_op.h"
namespace caffe2 {
template <>
bool StumpFuncOp<float, float, CPUContext>::RunOnDevice() {
auto& in = Input(0);
const float* in_data = in.template data<float>();
auto* out = Output(0, in.sizes(), at::dtype<float>());
float* out_data = out->template mutable_data<float>();
for (int i = 0; i < in.numel(); i++) {
out_data[i] = (in_data[i] <= threshold_) ? low_value_ : high_value_;
}
return true;
}
template <>
bool StumpFuncIndexOp<float, int64_t, CPUContext>::RunOnDevice() {
auto& in = Input(0);
const float* in_data = in.template data<float>();
int lo_cnt = 0;
for (int i = 0; i < in.numel(); i++) {
lo_cnt += (in_data[i] <= threshold_);
}
auto* out_lo = Output(0, {lo_cnt}, at::dtype<int64_t>());
auto* out_hi = Output(1, {in.numel() - lo_cnt}, at::dtype<int64_t>());
int64_t* lo_data = out_lo->template mutable_data<int64_t>();
int64_t* hi_data = out_hi->template mutable_data<int64_t>();
int lidx = 0;
int hidx = 0;
for (int i = 0; i < in.numel(); i++) {
if (in_data[i] <= threshold_) {
lo_data[lidx++] = i;
} else {
hi_data[hidx++] = i;
}
}
return true;
}
REGISTER_CPU_OPERATOR(StumpFunc, StumpFuncOp<float, float, CPUContext>);
OPERATOR_SCHEMA(StumpFunc)
.NumInputs(1)
.NumOutputs(1)
.Input(0, "X", "tensor of float")
.Output(0, "Y", "tensor of float")
.TensorInferenceFunction([](const OperatorDef&,
const vector<TensorShape>& input_types) {
vector<TensorShape> out(1);
out.at(0) = input_types.at(0);
out.at(0).set_data_type(TensorProto_DataType::TensorProto_DataType_FLOAT);
return out;
})
.SetDoc(R"DOC(
Converts each input element into either high_ or low_value
based on the given threshold.
)DOC");
NO_GRADIENT(StumpFunc);
REGISTER_CPU_OPERATOR(
StumpFuncIndex,
StumpFuncIndexOp<float, int64_t, CPUContext>);
OPERATOR_SCHEMA(StumpFuncIndex)
.NumInputs(1)
.NumOutputs(2)
.Input(0, "X", "tensor of float")
.Output(
0,
"Index_Low",
"tensor of int64 indices for elements below/equal threshold")
.Output(
1,
"Index_High",
"tensor of int64 indices for elements above threshold")
.SetDoc(R"DOC(
Split the elements and return the indices based on the given threshold.
)DOC");
NO_GRADIENT(StumpFuncIndex);
} // namespace caffe2