// blob: e55ba601b067118464cc9f232aee6264380f439f [file] [log] [blame]
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/predictor/predictor.h"
#include "caffe2/utils/math.h"
#include <gtest/gtest.h>
namespace caffe2 {
namespace {
// Text-format definition of the "predict" net (type "dag"): a single FC op
// that reads the external inputs "data", "W" and "b" and writes the external
// output "y". Parsed with parseNetDef() in PredictorTest::SetUp().
const char* predictSpec = R"DOC(
name: "predict"
type: "dag"
external_input: "data"
external_input: "W"
external_input: "b"
external_output: "y"
op {
input: "data"
input: "W"
input: "b"
output: "y"
type: "FC"
}
)DOC";
// Text-format definition of the "init" net (type "dag"): two ConstantFill ops
// that produce "W" with shape [10, 4] and "b" with shape [10], both using the
// fill value 2.0. Parsed with parseNetDef() in PredictorTest::SetUp().
const char* initSpec = R"DOC(
name: "init"
type: "dag"
op {
type: "ConstantFill"
output: "W"
arg {
name: "shape"
ints: 10
ints: 4
}
arg {
name: "value"
f: 2.0
}
}
op {
type: "ConstantFill"
output: "b"
arg {
name: "shape"
ints: 10
}
arg {
name: "value"
f: 2.0
}
}
)DOC";
// Text-format definition that bundles both nets plus blob metadata:
//   - blobs: INPUTS_BLOB_TYPE -> "data", OUTPUTS_BLOB_TYPE -> "y"
//   - nets:  GLOBAL_INIT_NET_TYPE -> an "init" net that ConstantFills "data"
//            (shape [1, 4]), "W" (shape [10, 4]) and "b" (shape [10]) with 2.0
//            PREDICT_NET_TYPE -> the same single-FC "predict" net as
//            predictSpec
// NOTE(review): presumably intended as input for parseMetaNetDef(); it is not
// referenced by any code visible in this file - confirm before removing.
const char* metaSpec = R"DOC(
blobs {
key: "INPUTS_BLOB_TYPE"
value: "data"
}
blobs {
key: "OUTPUTS_BLOB_TYPE"
value: "y"
}
nets {
key: "GLOBAL_INIT_NET_TYPE"
value: {
name: "init"
type: "dag"
op {
type: "ConstantFill"
output: "data"
arg {
name: "shape"
ints: 1
ints: 4
}
arg {
name: "value"
f: 2.0
}
}
op {
type: "ConstantFill"
output: "W"
arg {
name: "shape"
ints: 10
ints: 4
}
arg {
name: "value"
f: 2.0
}
}
op {
type: "ConstantFill"
output: "b"
arg {
name: "shape"
ints: 10
}
arg {
name: "value"
f: 2.0
}
}
}
}
nets {
key: "PREDICT_NET_TYPE"
value: {
name: "predict"
type: "dag"
external_input: "data"
external_input: "W"
external_input: "b"
external_output: "y"
op {
input: "data"
input: "W"
input: "b"
output: "y"
type: "FC"
}
}
}
)DOC";
// Builds a heap-allocated Blob whose CPU tensor is resized to `dims` and then
// filled through math::RandUniform<float, CPUContext> with the bounds -1.0 and
// 1.0. `ctx` is forwarded to RandUniform as its context argument.
// Returns ownership of the new Blob to the caller.
std::unique_ptr<Blob> randomTensor(
    const std::vector<int64_t>& dims,
    CPUContext* ctx) {
  auto result = make_unique<Blob>();
  auto* tensor = BlobGetMutableTensor(result.get(), CPU);
  tensor->Resize(dims);

  // Fetch the element count and the float buffer up front so the fill call
  // itself reads as a plain (count, low, high, dst, ctx) invocation.
  const auto count = tensor->numel();
  float* values = tensor->mutable_data<float>();
  math::RandUniform<float, CPUContext>(count, -1.0, 1.0, values, ctx);
  return result;
}
// Parses `value` into a NetDef using TextFormat::ParseFromString.
// If parsing reports failure, CAFFE_ENFORCE fires with a message that includes
// the text that could not be parsed.
// Returns the populated NetDef by value.
NetDef parseNetDef(const std::string& value) {
  NetDef def;
  CAFFE_ENFORCE(
      TextFormat::ParseFromString(value, &def),
      "Failed to parse NetDef with value: ",
      value);
  return def;
}
// Parses `value` into a MetaNetDef using TextFormat::ParseFromString.
// If parsing reports failure, CAFFE_ENFORCE fires with a message that names
// the MetaNetDef type and includes the text that could not be parsed.
// Returns the populated MetaNetDef by value.
MetaNetDef parseMetaNetDef(const std::string& value) {
  MetaNetDef def;
  CAFFE_ENFORCE(
      TextFormat::ParseFromString(value, &def),
      "Failed to parse MetaNetDef with value: ",
      value);
  return def;
}
} // namespace
// Test fixture: before each test it creates a CPUContext from a DeviceOption
// with a fixed random seed and a Predictor configured from the parsed
// initSpec / predictSpec nets.
class PredictorTest : public testing::Test {
 public:
  void SetUp() override {
    // Fixed seed is set on the DeviceOption used to build the context that
    // the tests pass to randomTensor().
    DeviceOption op;
    op.set_random_seed(1701);
    ctx_ = std::make_unique<CPUContext>(op);
    p_ = std::make_unique<Predictor>(
        makePredictorConfig(parseNetDef(initSpec), parseNetDef(predictSpec)));
  }

  // Context handed to randomTensor() by the tests.
  std::unique_ptr<CPUContext> ctx_;
  // Predictor under test, rebuilt in SetUp().
  std::unique_ptr<Predictor> p_;
};
// Feeds one random 1x4 tensor to the predictor as a positional TensorList and
// checks that exactly one 1x10 output comes back with the expected value at
// flat index 4.
TEST_F(PredictorTest, SimpleBatchSized) {
  auto inputData = randomTensor({1, 4}, ctx_.get());
  Predictor::TensorList input;
  auto* tensor = BlobGetMutableTensor(inputData.get(), CPU);
  input.emplace_back(tensor->Alias());

  Predictor::TensorList output;
  (*p_)(input, &output);

  // The shape checks are fatal on purpose: front() and the element read below
  // are only valid when there is one output with at least five elements, so
  // the test must stop here instead of continuing into undefined behaviour.
  ASSERT_EQ(output.size(), 1u);
  ASSERT_EQ(output.front().sizes().size(), 2u);
  ASSERT_EQ(output.front().size(0), 1);
  ASSERT_EQ(output.front().size(1), 10);

  // NOTE(review): the reference value presumably depends on the fixture's
  // random seed (1701) - confirm before changing either.
  EXPECT_NEAR(output.front().data<float>()[4], 4.9556, 1E-4);
}
// Feeds one random 1x4 tensor to the predictor through a TensorMap keyed by
// the input name "data" and checks that exactly one 1x10 output comes back
// with the expected value at flat index 4.
TEST_F(PredictorTest, SimpleBatchSizedMapInput) {
  auto inputData = randomTensor({1, 4}, ctx_.get());
  Predictor::TensorMap input;
  auto* tensor = BlobGetMutableTensor(inputData.get(), CPU);
  input.emplace("data", tensor->Alias());

  Predictor::TensorList output;
  (*p_)(input, &output);

  // The shape checks are fatal on purpose: front() and the element read below
  // are only valid when there is one output with at least five elements, so
  // the test must stop here instead of continuing into undefined behaviour.
  ASSERT_EQ(output.size(), 1u);
  ASSERT_EQ(output.front().sizes().size(), 2u);
  ASSERT_EQ(output.front().size(0), 1);
  ASSERT_EQ(output.front().size(1), 10);

  // NOTE(review): the reference value presumably depends on the fixture's
  // random seed (1701) - confirm before changing either.
  EXPECT_NEAR(output.front().data<float>()[4], 4.9556, 1E-4);
}
} // namespace caffe2