blob: 403746578dffefd9204a3f6d2dcff828028f3efe [file] [log] [blame]
#include <benchmark/benchmark.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>
namespace te = torch::jit::tensorexpr;
namespace {
class Gemm : public benchmark::Fixture {
public:
void SetUp(const benchmark::State& state) override {
M = state.range(0);
N = state.range(1);
K = state.range(2);
A = torch::randn({M, K});
B = torch::randn({K, N});
C = torch::mm(A, B);
}
void TearDown(benchmark::State& state) override {
state.counters["GFLOPS"] = benchmark::Counter(
uint64_t(state.iterations()) * 2 * M * N * K,
benchmark::Counter::kIsRate);
}
int M;
int N;
int K;
at::Tensor A;
at::Tensor B;
at::Tensor C;
};
} // namespace
// Baseline: ATen's own mm kernel, writing into the preallocated output C.
BENCHMARK_DEFINE_F(Gemm, Torch)(benchmark::State& state) {
  while (state.KeepRunning()) {
    torch::mm_out(C, A, B);
  }
}
// TensorExpr GEMM with no scheduling: the naive triple loop produced
// directly from the reduction definition, compiled with the LLVM codegen.
BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) {
  te::BufHandle aBuf("A", {M, K}, te::kFloat);
  te::BufHandle bBuf("B", {K, N}, te::kFloat);
  te::Tensor gemm = te::Reduce(
      "gemm",
      {M, N},
      te::Sum(),
      [&](const te::ExprHandle& m,
          const te::ExprHandle& n,
          const te::ExprHandle& k) {
        return aBuf.load(m, k) * bBuf.load(k, n);
      },
      {K});
  te::LoopNest nest({gemm});
  nest.prepareForCodegen();
  te::StmtPtr stmt = te::IRSimplifier::simplify(nest.root_stmt());
  auto codegen = te::CreateCodeGen("llvm_codegen", stmt, {aBuf, bBuf, gemm});
  for (auto _ : state) {
    codegen->call(
        {A.data_ptr<float>(), B.data_ptr<float>(), C.data_ptr<float>()});
  }
}
// TensorExpr GEMM tiled into 32x32 blocks of the output, with the reduction
// loop hoisted above the intra-tile loops:
//   m, n, k  ->  mo, no, k, mi, ni   (mi = 32, ni = 32)
BENCHMARK_DEFINE_F(Gemm, TensorExprTile32x32)(benchmark::State& state) {
  te::BufHandle aBuf("A", {M, K}, te::kFloat);
  te::BufHandle bBuf("B", {K, N}, te::kFloat);
  te::Tensor gemm = te::Reduce(
      "gemm",
      {M, N},
      te::Sum(),
      [&](const te::ExprHandle& m,
          const te::ExprHandle& n,
          const te::ExprHandle& k) {
        return aBuf.load(m, k) * bBuf.load(k, n);
      },
      {K});
  te::LoopNest nest({gemm});
  // Every transformation invalidates previously fetched loop pointers, so
  // the loop list is re-queried before each step.
  // m, n, k -> mo, mi, n, k
  auto loops = nest.getLoopStmtsFor(gemm);
  nest.splitWithMask(loops[0], 32);
  // mo, mi, n, k -> mo, mi, no, ni, k
  loops = nest.getLoopStmtsFor(gemm);
  nest.splitWithMask(loops[2], 32);
  // mo, mi, no, ni, k -> mo, no, mi, ni, k
  loops = nest.getLoopStmtsFor(gemm);
  nest.reorderAxis(loops[1], loops[2]);
  // mo, no, mi, ni, k -> mo, no, mi, k, ni
  loops = nest.getLoopStmtsFor(gemm);
  nest.reorderAxis(loops[3], loops[4]);
  // mo, no, mi, k, ni -> mo, no, k, mi, ni
  loops = nest.getLoopStmtsFor(gemm);
  nest.reorderAxis(loops[2], loops[3]);
  nest.prepareForCodegen();
  te::StmtPtr stmt = te::IRSimplifier::simplify(nest.root_stmt());
  auto codegen = te::CreateCodeGen("llvm_codegen", stmt, {aBuf, bBuf, gemm});
  for (auto _ : state) {
    codegen->call(
        {A.data_ptr<float>(), B.data_ptr<float>(), C.data_ptr<float>()});
  }
}
// TensorExpr GEMM tiled into 4x16 blocks of the output, with the reduction
// loop hoisted above the intra-tile loops:
//   m, n, k  ->  mo, no, k, mi, ni   (mi = 4, ni = 16)
BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16)(benchmark::State& state) {
  te::BufHandle aBuf("A", {M, K}, te::kFloat);
  te::BufHandle bBuf("B", {K, N}, te::kFloat);
  te::Tensor gemm = te::Reduce(
      "gemm",
      {M, N},
      te::Sum(),
      [&](const te::ExprHandle& m,
          const te::ExprHandle& n,
          const te::ExprHandle& k) {
        return aBuf.load(m, k) * bBuf.load(k, n);
      },
      {K});
  te::LoopNest nest({gemm});
  // Every transformation invalidates previously fetched loop pointers, so
  // the loop list is re-queried before each step.
  // m, n, k -> mo, mi, n, k
  auto loops = nest.getLoopStmtsFor(gemm);
  nest.splitWithMask(loops[0], 4);
  // mo, mi, n, k -> mo, mi, no, ni, k
  loops = nest.getLoopStmtsFor(gemm);
  nest.splitWithMask(loops[2], 16);
  // mo, mi, no, ni, k -> mo, no, mi, ni, k
  loops = nest.getLoopStmtsFor(gemm);
  nest.reorderAxis(loops[1], loops[2]);
  // mo, no, mi, ni, k -> mo, no, mi, k, ni
  loops = nest.getLoopStmtsFor(gemm);
  nest.reorderAxis(loops[3], loops[4]);
  // mo, no, mi, k, ni -> mo, no, k, mi, ni
  loops = nest.getLoopStmtsFor(gemm);
  nest.reorderAxis(loops[2], loops[3]);
  nest.prepareForCodegen();
  te::StmtPtr stmt = te::IRSimplifier::simplify(nest.root_stmt());
  auto codegen = te::CreateCodeGen("llvm_codegen", stmt, {aBuf, bBuf, gemm});
  for (auto _ : state) {
    codegen->call(
        {A.data_ptr<float>(), B.data_ptr<float>(), C.data_ptr<float>()});
  }
}
// Same 4x16 tiling as above, then the innermost tile loops are lowered to
// straight-line SIMD code: ni (16 lanes) is vectorized and mi (4 iterations)
// is fully unrolled.
BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) {
  te::BufHandle aBuf("A", {M, K}, te::kFloat);
  te::BufHandle bBuf("B", {K, N}, te::kFloat);
  te::Tensor gemm = te::Reduce(
      "gemm",
      {M, N},
      te::Sum(),
      [&](const te::ExprHandle& m,
          const te::ExprHandle& n,
          const te::ExprHandle& k) {
        return aBuf.load(m, k) * bBuf.load(k, n);
      },
      {K});
  te::LoopNest nest({gemm});
  // Every transformation invalidates previously fetched loop pointers, so
  // the loop list is re-queried before each step.
  // m, n, k -> mo, mi, n, k
  auto loops = nest.getLoopStmtsFor(gemm);
  nest.splitWithMask(loops[0], 4);
  // mo, mi, n, k -> mo, mi, no, ni, k
  loops = nest.getLoopStmtsFor(gemm);
  nest.splitWithMask(loops[2], 16);
  // mo, mi, no, ni, k -> mo, no, mi, ni, k
  loops = nest.getLoopStmtsFor(gemm);
  nest.reorderAxis(loops[1], loops[2]);
  // mo, no, mi, ni, k -> mo, no, mi, k, ni
  loops = nest.getLoopStmtsFor(gemm);
  nest.reorderAxis(loops[3], loops[4]);
  // mo, no, mi, k, ni -> mo, no, k, mi, ni
  loops = nest.getLoopStmtsFor(gemm);
  nest.reorderAxis(loops[2], loops[3]);
  // Vectorize ni, then fully unroll its parent mi (the mi pointer stays
  // valid across the vectorize call since only its body is rewritten).
  loops = nest.getLoopStmtsFor(gemm);
  te::StmtPtr unrolled;
  nest.vectorize(loops[4]);
  nest.fullUnroll(loops[3], &unrolled);
  nest.prepareForCodegen();
  te::StmtPtr stmt = te::IRSimplifier::simplify(nest.root_stmt());
  auto codegen = te::CreateCodeGen("llvm_codegen", stmt, {aBuf, bBuf, gemm});
  for (auto _ : state) {
    codegen->call(
        {A.data_ptr<float>(), B.data_ptr<float>(), C.data_ptr<float>()});
  }
}
// Same 4x16 tiling as above, plus a per-tile accumulation buffer ("C_regs")
// for the output accesses under the k loop, so partial sums are kept local
// instead of repeatedly hitting the output buffer.
BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16Cache)(benchmark::State& state) {
  te::BufHandle aBuf("A", {M, K}, te::kFloat);
  te::BufHandle bBuf("B", {K, N}, te::kFloat);
  te::Tensor gemm = te::Reduce(
      "gemm",
      {M, N},
      te::Sum(),
      [&](const te::ExprHandle& m,
          const te::ExprHandle& n,
          const te::ExprHandle& k) {
        return aBuf.load(m, k) * bBuf.load(k, n);
      },
      {K});
  te::LoopNest nest({gemm});
  // Every transformation invalidates previously fetched loop pointers, so
  // the loop list is re-queried before each step.
  // m, n, k -> mo, mi, n, k
  auto loops = nest.getLoopStmtsFor(gemm);
  nest.splitWithMask(loops[0], 4);
  // mo, mi, n, k -> mo, mi, no, ni, k
  loops = nest.getLoopStmtsFor(gemm);
  nest.splitWithMask(loops[2], 16);
  // mo, mi, no, ni, k -> mo, no, mi, ni, k
  loops = nest.getLoopStmtsFor(gemm);
  nest.reorderAxis(loops[1], loops[2]);
  // mo, no, mi, ni, k -> mo, no, mi, k, ni
  loops = nest.getLoopStmtsFor(gemm);
  nest.reorderAxis(loops[3], loops[4]);
  // mo, no, mi, k, ni -> mo, no, k, mi, ni
  loops = nest.getLoopStmtsFor(gemm);
  nest.reorderAxis(loops[2], loops[3]);
  // Cache output accesses at the k-loop level (loop index 2).
  loops = nest.getLoopStmtsFor(gemm);
  nest.cacheAccesses(gemm.buf(), "C_regs", loops[2]);
  nest.prepareForCodegen();
  te::StmtPtr stmt = te::IRSimplifier::simplify(nest.root_stmt());
  auto codegen = te::CreateCodeGen("llvm_codegen", stmt, {aBuf, bBuf, gemm});
  for (auto _ : state) {
    codegen->call(
        {A.data_ptr<float>(), B.data_ptr<float>(), C.data_ptr<float>()});
  }
}
// Register every variant on the same square 128x128x128 problem so the
// reported GFLOPS counters are directly comparable across schedules.
BENCHMARK_REGISTER_F(Gemm, Torch)->Args({128, 128, 128});
BENCHMARK_REGISTER_F(Gemm, TensorExprNoopt)->Args({128, 128, 128});
BENCHMARK_REGISTER_F(Gemm, TensorExprTile32x32)->Args({128, 128, 128});
BENCHMARK_REGISTER_F(Gemm, TensorExprTile4x16)->Args({128, 128, 128});
BENCHMARK_REGISTER_F(Gemm, TensorExprTile4x16VecUnroll)->Args({128, 128, 128});
BENCHMARK_REGISTER_F(Gemm, TensorExprTile4x16Cache)->Args({128, 128, 128});