blob: 998bad2e46b1884ff60c4361780fc5152c44ceee [file] [log] [blame]
#include <ATen/ATen.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <benchmark/benchmark.h>
#include <c10/core/InferenceMode.h>
#include <sstream>
// Shape/hyper-parameter description of one 2-D convolution to benchmark.
// Aggregate-initialized in the order
//   {input, weight, bias, stride, padding, dilation, groups}.
// An empty `bias` means the convolution has no bias term.
struct ConvParams {
  std::vector<int64_t> input;
  std::vector<int64_t> weight;
  std::vector<int64_t> bias;
  std::vector<int64_t> stride;
  std::vector<int64_t> padding;
  std::vector<int64_t> dilation;
  // Defaults to 1 (a dense convolution, conv2d's own default) so that a
  // default-constructed ConvParams, or an initializer that omits the last
  // field, never carries an indeterminate or zero group count.
  int64_t groups = 1;
};
// Non-owning formatting adaptor: streaming `xs(v)` writes the elements of `v`
// joined by "x" (e.g. {1, 3, 224, 224} -> "1x3x224x224"); an empty vector
// writes nothing. It only stores a reference, so it must not outlive `v`.
struct xs {
  explicit xs(const std::vector<int64_t>& values) : v(values) {}
  const std::vector<int64_t>& v;
};

std::ostream& operator<<(std::ostream& os, const xs& x) {
  for (auto it = x.v.begin(); it != x.v.end(); ++it) {
    // Separator goes before every element except the first one.
    if (it != x.v.begin()) {
      os << "x";
    }
    os << *it;
  }
  return os;
}
// Writes a compact, identifier-friendly tag for a convolution configuration,
// used in benchmark names:
//   I<input>_W<weight>_B<bias>_S<stride>_P<padding>_D<dilation>_G<groups>
// where every shape is "x"-joined (an empty bias prints as a bare "_B").
std::ostream& operator<<(std::ostream& os, const ConvParams& params) {
  os << "I" << xs(params.input);
  os << "_W" << xs(params.weight);
  os << "_B" << xs(params.bias);
  os << "_S" << xs(params.stride);
  os << "_P" << xs(params.padding);
  os << "_D" << xs(params.dilation);
  os << "_G" << params.groups;
  return os;
}
// Convolution layers benchmarked under the "MobileNetV3" label.
// Each entry is {input, weight, bias, stride, padding, dilation, groups},
// i.e. the field order of ConvParams. Every input has batch size 1 and the
// first layer takes a 1x3x224x224 image. Entries whose `groups` equals the
// input channel count (weight dim 1 == 1) are depthwise convolutions; every
// entry here carries a bias.
// NOTE(review): presumably transcribed from a MobileNetV3 variant; channel
// widths are not cross-checked against a reference model here.
std::vector<ConvParams> MobileNetV3Params = {
{{1, 3, 224, 224}, {16, 3, 3, 3}, {16}, {2, 2}, {1, 1}, {1, 1}, 1},
{{1, 16, 112, 112}, {16, 16, 1, 1}, {16}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 16, 112, 112}, {16, 1, 3, 3}, {16}, {2, 2}, {1, 1}, {1, 1}, 16},
{{1, 16, 56, 56}, {16, 16, 1, 1}, {16}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 16, 56, 56}, {72, 16, 1, 1}, {72}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 72, 56, 56}, {72, 1, 3, 3}, {72}, {2, 2}, {1, 1}, {1, 1}, 72},
{{1, 72, 28, 28}, {24, 72, 1, 1}, {24}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 24, 28, 28}, {88, 24, 1, 1}, {88}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 88, 28, 28}, {88, 1, 3, 3}, {88}, {1, 1}, {1, 1}, {1, 1}, 88},
{{1, 88, 28, 28}, {24, 88, 1, 1}, {24}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 24, 28, 28}, {96, 24, 1, 1}, {96}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 96, 28, 28}, {96, 1, 5, 5}, {96}, {2, 2}, {2, 2}, {1, 1}, 96},
{{1, 96, 14, 14}, {40, 96, 1, 1}, {40}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 40, 14, 14}, {240, 40, 1, 1}, {240}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 240, 14, 14}, {240, 1, 5, 5}, {240}, {1, 1}, {2, 2}, {1, 1}, 240},
{{1, 240, 14, 14}, {40, 240, 1, 1}, {40}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 40, 14, 14}, {240, 40, 1, 1}, {240}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 240, 14, 14}, {240, 1, 5, 5}, {240}, {1, 1}, {2, 2}, {1, 1}, 240},
{{1, 240, 14, 14}, {40, 240, 1, 1}, {40}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 40, 14, 14}, {120, 40, 1, 1}, {120}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 120, 14, 14}, {120, 1, 5, 5}, {120}, {1, 1}, {2, 2}, {1, 1}, 120},
{{1, 120, 14, 14}, {48, 120, 1, 1}, {48}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 48, 14, 14}, {144, 48, 1, 1}, {144}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 144, 14, 14}, {144, 1, 5, 5}, {144}, {1, 1}, {2, 2}, {1, 1}, 144},
{{1, 144, 14, 14}, {48, 144, 1, 1}, {48}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 48, 14, 14}, {288, 48, 1, 1}, {288}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 288, 14, 14}, {288, 1, 5, 5}, {288}, {2, 2}, {2, 2}, {1, 1}, 288},
{{1, 288, 7, 7}, {96, 288, 1, 1}, {96}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 96, 7, 7}, {576, 96, 1, 1}, {576}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 576, 7, 7}, {576, 1, 5, 5}, {576}, {1, 1}, {2, 2}, {1, 1}, 576},
{{1, 576, 7, 7}, {96, 576, 1, 1}, {96}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 96, 7, 7}, {576, 96, 1, 1}, {576}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 576, 7, 7}, {576, 1, 5, 5}, {576}, {1, 1}, {2, 2}, {1, 1}, 576},
{{1, 576, 7, 7}, {96, 576, 1, 1}, {96}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 96, 7, 7}, {576, 96, 1, 1}, {576}, {1, 1}, {0, 0}, {1, 1}, 1},
// 1x1 spatial input: presumably the classifier head after global pooling.
{{1, 576, 1, 1}, {1280, 576, 1, 1}, {1280}, {1, 1}, {0, 0}, {1, 1}, 1},
};
// Convolution layers benchmarked under the "ResNet18" label, in ConvParams
// field order {input, weight, bias, stride, padding, dilation, groups}.
// Batch size is 1 throughout. Bias is empty ({}) for every entry, so the
// benchmarks run these convolutions with an undefined (absent) bias tensor.
// The 1x1, stride-2 entries are presumably the residual shortcut
// (downsample) convolutions. Repeated shapes are kept as separate entries, so
// the same benchmark name is registered once per occurrence.
std::vector<ConvParams> ResNet18Params = {
{{1, 3, 224, 224}, {64, 3, 7, 7}, {}, {2, 2}, {3, 3}, {1, 1}, 1},
{{1, 64, 56, 56}, {64, 64, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 64, 56, 56}, {64, 64, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 64, 56, 56}, {64, 64, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 64, 56, 56}, {64, 64, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 64, 56, 56}, {128, 64, 3, 3}, {}, {2, 2}, {1, 1}, {1, 1}, 1},
{{1, 128, 28, 28}, {128, 128, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 64, 56, 56}, {128, 64, 1, 1}, {}, {2, 2}, {0, 0}, {1, 1}, 1},
{{1, 128, 28, 28}, {128, 128, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 128, 28, 28}, {128, 128, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 128, 28, 28}, {256, 128, 3, 3}, {}, {2, 2}, {1, 1}, {1, 1}, 1},
{{1, 256, 14, 14}, {256, 256, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 128, 28, 28}, {256, 128, 1, 1}, {}, {2, 2}, {0, 0}, {1, 1}, 1},
{{1, 256, 14, 14}, {256, 256, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 256, 14, 14}, {256, 256, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 256, 14, 14}, {512, 256, 3, 3}, {}, {2, 2}, {1, 1}, {1, 1}, 1},
{{1, 512, 7, 7}, {512, 512, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 256, 14, 14}, {512, 256, 1, 1}, {}, {2, 2}, {0, 0}, {1, 1}, 1},
{{1, 512, 7, 7}, {512, 512, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 512, 7, 7}, {512, 512, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
};
// Convolution layers benchmarked under the "ResNet50" label, in ConvParams
// field order {input, weight, bias, stride, padding, dilation, groups}.
// Batch size is 1 and bias is empty ({}) for every entry (no bias tensor).
// The entries appear to follow bottleneck blocks (1x1 reduce, 3x3, 1x1
// expand) with a 1x1 shortcut convolution in the first block of each stage;
// in the downsampling stages the stride-2 sits on the 3x3 conv. Repeated
// shapes are kept, so identical benchmark names are registered repeatedly.
std::vector<ConvParams> ResNet50Params = {
{{1, 3, 224, 224}, {64, 3, 7, 7}, {}, {2, 2}, {3, 3}, {1, 1}, 1},
{{1, 64, 56, 56}, {64, 64, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 64, 56, 56}, {64, 64, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 64, 56, 56}, {256, 64, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 64, 56, 56}, {256, 64, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 256, 56, 56}, {64, 256, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 64, 56, 56}, {64, 64, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 64, 56, 56}, {256, 64, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 256, 56, 56}, {64, 256, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 64, 56, 56}, {64, 64, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 64, 56, 56}, {256, 64, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 256, 56, 56}, {128, 256, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 128, 56, 56}, {128, 128, 3, 3}, {}, {2, 2}, {1, 1}, {1, 1}, 1},
{{1, 128, 28, 28}, {512, 128, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 256, 56, 56}, {512, 256, 1, 1}, {}, {2, 2}, {0, 0}, {1, 1}, 1},
{{1, 512, 28, 28}, {128, 512, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 128, 28, 28}, {128, 128, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 128, 28, 28}, {512, 128, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 512, 28, 28}, {128, 512, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 128, 28, 28}, {128, 128, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 128, 28, 28}, {512, 128, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 512, 28, 28}, {128, 512, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 128, 28, 28}, {128, 128, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 128, 28, 28}, {512, 128, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 512, 28, 28}, {256, 512, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 256, 28, 28}, {256, 256, 3, 3}, {}, {2, 2}, {1, 1}, {1, 1}, 1},
{{1, 256, 14, 14}, {1024, 256, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 512, 28, 28}, {1024, 512, 1, 1}, {}, {2, 2}, {0, 0}, {1, 1}, 1},
{{1, 1024, 14, 14}, {256, 1024, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 256, 14, 14}, {256, 256, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 256, 14, 14}, {1024, 256, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 1024, 14, 14}, {256, 1024, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 256, 14, 14}, {256, 256, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 256, 14, 14}, {1024, 256, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 1024, 14, 14}, {256, 1024, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 256, 14, 14}, {256, 256, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 256, 14, 14}, {1024, 256, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 1024, 14, 14}, {256, 1024, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 256, 14, 14}, {256, 256, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 256, 14, 14}, {1024, 256, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 1024, 14, 14}, {256, 1024, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 256, 14, 14}, {256, 256, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 256, 14, 14}, {1024, 256, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 1024, 14, 14}, {512, 1024, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 512, 14, 14}, {512, 512, 3, 3}, {}, {2, 2}, {1, 1}, {1, 1}, 1},
{{1, 512, 7, 7}, {2048, 512, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 1024, 14, 14}, {2048, 1024, 1, 1}, {}, {2, 2}, {0, 0}, {1, 1}, 1},
{{1, 2048, 7, 7}, {512, 2048, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 512, 7, 7}, {512, 512, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 512, 7, 7}, {2048, 512, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 2048, 7, 7}, {512, 2048, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
{{1, 512, 7, 7}, {512, 512, 3, 3}, {}, {1, 1}, {1, 1}, {1, 1}, 1},
{{1, 512, 7, 7}, {2048, 512, 1, 1}, {}, {1, 1}, {0, 0}, {1, 1}, 1},
};
// Scope guard: sets the global user-level MKLDNN switch
// (at::globalContext().setUserEnabledMkldnn) to `enable` for the lifetime of
// the object and restores the previously observed value on destruction.
struct EnableMklDnn {
  explicit EnableMklDnn(bool enable)
      : prev_(at::globalContext().userEnabledMkldnn()) {
    at::globalContext().setUserEnabledMkldnn(enable);
  }
  ~EnableMklDnn() {
    at::globalContext().setUserEnabledMkldnn(prev_);
  }

  // Non-copyable and non-movable: every copy would restore `prev_` again in
  // its own destructor, possibly overwriting a value set by another scope.
  EnableMklDnn(const EnableMklDnn&) = delete;
  EnableMklDnn& operator=(const EnableMklDnn&) = delete;
  EnableMklDnn(EnableMklDnn&&) = delete;
  EnableMklDnn& operator=(EnableMklDnn&&) = delete;

 private:
  // Value of the switch at construction time, restored by the destructor.
  bool prev_;
};
// Times the dispatching at::conv2d entry point on dense CPU tensors filled
// by at::randn. `WithMklDnn` sets the user-level MKLDNN switch (through
// EnableMklDnn) for the duration of the benchmark, so the same shape can be
// timed with the MKLDNN backend allowed and disallowed.
//
// Counters (kIsRate, so google-benchmark divides them by elapsed time):
//   "GFLOPS/s": 2 * output elements * (weight elements per output channel)
//               per iteration; the stored value is FLOP/s.
//   "GB/s":     input + weight + output bytes per iteration (bias bytes are
//               not counted); the stored value is bytes/s.
template <bool WithMklDnn>
static void BM_conv2d_native(
    benchmark::State& state,
    const ConvParams& params) {
  EnableMklDnn mkl(WithMklDnn);
  auto input = at::randn(params.input);
  auto weight = at::randn(params.weight);
  // An empty bias shape means "no bias": pass an undefined tensor.
  auto bias = params.bias.empty() ? at::Tensor{} : at::randn(params.bias);
  // Untimed warm-up call; its result also sizes the counters below.
  auto output = at::conv2d(
      input,
      weight,
      bias,
      params.stride,
      params.padding,
      params.dilation,
      params.groups);
  for (auto _ : state) {
    output = at::conv2d(
        input,
        weight,
        bias,
        params.stride,
        params.padding,
        params.dilation,
        params.groups);
  }
  // Accumulate in double (benchmark::Counter's value type). The previous
  // float arithmetic silently rounded int64 products larger than 2^24.
  const double flops_per_iter = 2.0 * static_cast<double>(output.numel()) *
      static_cast<double>(weight.numel()) /
      static_cast<double>(weight.size(0));
  const double bytes_per_iter = static_cast<double>(
      input.nbytes() + weight.nbytes() + output.nbytes());
  const double iterations = static_cast<double>(state.iterations());
  state.counters["GFLOPS/s"] = benchmark::Counter(
      flops_per_iter * iterations, benchmark::Counter::kIsRate);
  state.counters["GB/s"] = benchmark::Counter(
      bytes_per_iter * iterations, benchmark::Counter::kIsRate);
}
// Which operands BM_conv2d_mkldnn converts to MKLDNN layouts before the timed
// loop. Values are spelled out and match the implicit 0, 1, 2 numbering.
enum MklDnnReorder {
  None = 0, // input, weight and bias stay dense
  WeightOnly = 1, // weight (and bias, if present) converted/reordered
  WeightAndInput = 2, // as WeightOnly, plus input reordered to aBcd16b
};
// Times at::mkldnn_convolution directly. `Reorder` selects how much of the
// layout conversion is done up front (outside the timed loop):
//   None           - dense input/weight/bias are passed as-is;
//   WeightOnly     - weight goes through mkldnn_reorder_conv2d_weight and the
//                    bias (if any) is converted with to_mkldnn();
//   WeightAndInput - additionally, the input is reordered into the oneDNN
//                    aBcd16b layout (dimension 1 blocked by 16).
// Counters match BM_conv2d_native: "GFLOPS/s" and "GB/s", both kIsRate.
template <MklDnnReorder Reorder>
static void BM_conv2d_mkldnn(
    benchmark::State& state,
    const ConvParams& params) {
  auto input = at::randn(params.input);
  auto weight = at::randn(params.weight);
  // An empty bias shape means "no bias": pass an undefined tensor.
  auto bias = params.bias.empty() ? at::Tensor{} : at::randn(params.bias);
  if (Reorder == WeightAndInput) {
    auto it_input = at::native::itensor_from_mkldnn(input.to_mkldnn());
    auto reordered = ideep::tensor(
        params.input, ideep::data_type::f32, ideep::format_tag::aBcd16b);
    it_input.reorder_to(reordered);
    input = at::native::new_with_itensor_mkldnn(
        std::move(reordered), at::kFloat, at::Device(at::kCPU));
  }
  if (Reorder == WeightOnly || Reorder == WeightAndInput) {
    // Note the argument order: padding precedes stride for this op.
    weight = at::mkldnn_reorder_conv2d_weight(
        weight.to_mkldnn(),
        params.padding,
        params.stride,
        params.dilation,
        params.groups);
    bias = params.bias.empty() ? bias : bias.to_mkldnn();
  }
  // Untimed warm-up call; its result also sizes the counters below.
  auto output = at::mkldnn_convolution(
      input,
      weight,
      bias,
      params.padding,
      params.stride,
      params.dilation,
      params.groups);
  for (auto _ : state) {
    output = at::mkldnn_convolution(
        input,
        weight,
        bias,
        params.padding,
        params.stride,
        params.dilation,
        params.groups);
  }
  // Accumulate in double (benchmark::Counter's value type). The previous
  // float arithmetic silently rounded int64 products larger than 2^24.
  const double flops_per_iter = 2.0 * static_cast<double>(output.numel()) *
      static_cast<double>(weight.numel()) /
      static_cast<double>(weight.size(0));
  const double bytes_per_iter = static_cast<double>(
      input.nbytes() + weight.nbytes() + output.nbytes());
  const double iterations = static_cast<double>(state.iterations());
  state.counters["GFLOPS/s"] = benchmark::Counter(
      flops_per_iter * iterations, benchmark::Counter::kIsRate);
  state.counters["GB/s"] = benchmark::Counter(
      bytes_per_iter * iterations, benchmark::Counter::kIsRate);
}
// Builds the registered benchmark name "<base>_<suffix>_<params>", where
// <params> is the ConvParams stream representation.
std::string name(
    const char* base,
    const char* suffix,
    const ConvParams& params) {
  std::ostringstream label;
  label << base;
  label << "_" << suffix;
  label << "_" << params;
  return label.str();
}
void registerOne(const char* base, const ConvParams& params) {
benchmark::RegisterBenchmark(
name(base, "native", params).data(), BM_conv2d_native<true>, params);
benchmark::RegisterBenchmark(
name(base, "native_nomkl", params).data(),
BM_conv2d_native<false>,
params);
benchmark::RegisterBenchmark(
name(base, "mkldnn_none", params).data(), BM_conv2d_mkldnn<None>, params);
benchmark::RegisterBenchmark(
name(base, "mkldnn_weight", params).data(),
BM_conv2d_mkldnn<WeightOnly>,
params);
benchmark::RegisterBenchmark(
name(base, "mkldnn_input", params).data(),
BM_conv2d_mkldnn<WeightAndInput>,
params);
}
int main(int argc, char** argv) {
c10::InferenceMode guard;
#define BENCH(x) \
for (auto const& params : x##Params) { \
registerOne(#x, params); \
}
BENCH(MobileNetV3);
BENCH(ResNet18);
BENCH(ResNet50);
#undef BENCH
benchmark::Initialize(&argc, argv);
benchmark::RunSpecifiedBenchmarks();
}