#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>
#include <algorithm>
#include <cctype>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <unordered_map>

#include <test/cpp/tensorexpr/padded_buffer.h>
#include <test/cpp/tensorexpr/test_utils.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>
| |
| namespace torch { |
| namespace jit { |
| |
| using namespace torch::jit::tensorexpr; |
| |
| void checkIR(StmtPtr s, const std::string& pattern) { |
| std::ostringstream oss; |
| oss << *s; |
| torch::jit::testing::FileCheck().run(pattern, oss.str()); |
| } |
| |
| void checkExprIR(ExprPtr e, const std::string& pattern) { |
| std::string prefixed_pattern = "# CHECK: " + pattern + "\n"; |
| std::ostringstream oss; |
| oss << *e << "\n"; |
| torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str()); |
| } |
| |
| void checkExprIR(const ExprHandle& e, const std::string& pattern) { |
| checkExprIR(e.node(), pattern); |
| } |
| |
| TEST(LoopNest, ExprSimple01) { |
| Tensor tensor = |
| Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) { |
| return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
| }); |
| LoopNest l({tensor}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| |
| LoopNest::splitWithTail(loops[0], 2); |
| LoopNest::splitWithTail(loops[0], 2); |
| } |
| |
| TEST(LoopNest, ExprLower01) { |
| Tensor tensor = |
| Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) { |
| return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
| }); |
| LoopNest l({tensor}); |
| StmtPtr stmt = l.root_stmt(); |
| std::ostringstream oss; |
| oss << *stmt; |
| ASSERT_GT(oss.str().size(), 20); |
| ASSERT_LT(oss.str().size(), 200); |
| } |
| |
| TEST(LoopNest, ExprSimple02) { |
| auto func = [](const ExprHandle& x, const ExprHandle& y) { |
| return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
| }; |
| Tensor tensor = Compute("f", {26, 5}, func); |
| LoopNest l({tensor}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| |
| LoopNest::splitWithTail(loops[0], 4); |
| |
| StmtPtr stmt = l.root_stmt(); |
| std::ostringstream oss; |
| oss << *stmt; |
| ASSERT_GT(oss.str().size(), 200); |
| ASSERT_LT(oss.str().size(), 600); |
| |
| { |
| // Compare to a reference loop structure structure. |
| VarHandle x_outer("i_outer", kInt); |
| VarHandle x_inner("i_inner", kInt); |
| VarHandle y("i", kInt); |
| VarHandle x_tail("i_tail", kInt); |
| BufHandle f("f", {26, 5}, kFloat); |
| ExprHandle x_1 = x_outer * 4 + x_inner; |
| ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4; |
| ForPtr stmt1 = For::make( |
| x_outer, |
| 0, |
| x_outer_end, |
| For::make( |
| x_inner, |
| 0, |
| 4, |
| For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y))))); |
| ExprHandle x_2 = x_tail + x_outer_end * 4; |
| ForPtr stmt2 = For::make( |
| x_tail, |
| 0, |
| (ExprHandle(26) - 0) % 4, |
| For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y)))); |
| StmtPtr stmt = Block::make({stmt1, stmt2}); |
| |
| std::ostringstream oss_ref; |
| oss_ref << *stmt; |
| ASSERT_EQ(oss.str(), oss_ref.str()); |
| } |
| |
| { |
| PaddedBuffer<float> f_v(26, 5, "f_v"); |
| PaddedBuffer<float> f_ref(26, 5, "f_res"); |
| |
| stmt = FlattenIndexes(stmt); |
| SimpleIREvaluator ir_eval(stmt, {tensor}); |
| ir_eval(f_v); |
| |
| for (int x = 0; x < 26; x++) { |
| for (int y = 0; y < 5; y++) { |
| f_ref(x, y) = 1 + x * x + y * y; |
| } |
| } |
| |
| ExpectAllNear(f_v, f_ref, 1e-5); |
| } |
| } |
| |
| BlockPtr getSimplifiedBody(const LoopNest& l) { |
| StmtPtr stmt = l.root_stmt(); |
| StmtPtr simplified = IRSimplifier::simplify(stmt); |
| return to<Block>(simplified); |
| } |
| |
| void assertForRange(ForPtr f, int expected_start, int expected_stop) { |
| ASSERT_NE(f, nullptr); |
| IntImmPtr start = to<IntImm>(f->start()); |
| ASSERT_NE(start, nullptr); |
| ASSERT_EQ(start->value(), expected_start); |
| IntImmPtr stop = to<IntImm>(f->stop()); |
| ASSERT_NE(stop, nullptr); |
| ASSERT_EQ(stop->value(), expected_stop); |
| } |
| |
| void assertForRanges( |
| BlockPtr body, |
| const std::vector<std::pair<int, int>>& start_stops) { |
| ASSERT_EQ(body->nstmts(), start_stops.size()); |
| |
| auto it = body->begin(); |
| for (size_t i = 0; i < start_stops.size(); i++, it++) { |
| ForPtr loop = to<For>(*it); |
| assertForRange(loop, start_stops[i].first, start_stops[i].second); |
| } |
| } |
| |
| TEST(LoopNest, ExprSliceHeadWithLoopOptions) { |
| auto func = [](const ExprHandle& x) { |
| return ExprHandle(1.0f) + cast<float>(x); |
| }; |
| Tensor tensor = Compute("f", {10}, func); |
| LoopNest l({tensor}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr head; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail; |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); |
| LoopNest::sliceHead(loops[0], 2, &head, &tail); |
| |
| BlockPtr body = getSimplifiedBody(l); |
| assertForRanges(body, {{0, 2}, {0, 8}}); |
| |
| ASSERT_TRUE(tail->loop_options().is_gpu_block_index()); |
| ASSERT_EQ(tail->loop_options().gpu_block_index(), LoopOptions::IDX_Y); |
| |
| ASSERT_TRUE(head->loop_options().isDefault()); |
| } |
| |
| TEST(LoopNest, ExprSliceTailWithLoopOptions) { |
| auto func = [](const ExprHandle& x) { |
| return ExprHandle(1.0f) + cast<float>(x); |
| }; |
| Tensor tensor = Compute("f", {10}, func); |
| LoopNest l({tensor}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr head; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail; |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::sliceTail(loops[0], 4, &head, &tail); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail_head; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail_tail; |
| tail->set_gpu_block_index(LoopOptions::IDX_Y); |
| LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail); |
| |
| BlockPtr body = getSimplifiedBody(l); |
| assertForRanges(body, {{0, 6}, {0, 2}, {8, 10}}); |
| |
| ASSERT_TRUE(tail_head->loop_options().is_gpu_block_index()); |
| ASSERT_EQ(tail_head->loop_options().gpu_block_index(), LoopOptions::IDX_Y); |
| |
| ASSERT_TRUE(head->loop_options().isDefault()); |
| ASSERT_TRUE(tail_tail->loop_options().isDefault()); |
| } |
| |
| TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) { |
| // When factor equals the For loop's original size, keep using the original |
| // For loop. |
| auto func = [](const ExprHandle& x) { |
| return ExprHandle(1.0f) + cast<float>(x); |
| }; |
| Tensor tensor = Compute("f", {10}, func); |
| LoopNest l({tensor}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr head; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail; |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::sliceHead(loops[0], 10, &head, &tail); |
| |
| ASSERT_EQ(head, loops[0]); |
| ASSERT_EQ(tail, nullptr); |
| |
| BlockPtr body = getSimplifiedBody(l); |
| assertForRanges(body, {{0, 10}}); |
| } |
| |
| TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) { |
| auto func = [](const ExprHandle& x) { |
| return ExprHandle(1.0f) + cast<float>(x); |
| }; |
| Tensor tensor = Compute("f", {10}, func); |
| LoopNest l({tensor}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr head; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail; |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::sliceHead(loops[0], 100, &head, &tail); |
| |
| ASSERT_EQ(head, loops[0]); |
| ASSERT_EQ(tail, nullptr); |
| |
| BlockPtr body = getSimplifiedBody(l); |
| assertForRanges(body, {{0, 10}}); |
| } |
| |
| TEST(LoopNest, ExprSliceHead) { |
| auto func = [](const ExprHandle& x) { |
| return ExprHandle(1.0f) + cast<float>(x); |
| }; |
| Tensor tensor = Compute("f", {10}, func); |
| LoopNest l({tensor}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr head; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail; |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::sliceHead(loops[0], 4, &head, &tail); |
| |
| ASSERT_NE(head, nullptr); |
| ASSERT_NE(head, loops[0]); |
| ASSERT_NE(tail, nullptr); |
| ASSERT_EQ(tail, loops[0]); |
| |
| BlockPtr body = getSimplifiedBody(l); |
| assertForRanges(body, {{0, 4}, {4, 10}}); |
| } |
| |
| TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { |
| auto func = [](const ExprHandle& x) { |
| return ExprHandle(1.0f) + cast<float>(x); |
| }; |
| Tensor tensor = Compute("f", {10}, func); |
| LoopNest l({tensor}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr head; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail; |
| LoopNest::sliceTail(loops[0], 4, &head, &tail); |
| // head: [0, 6) |
| // tail: [6, 10) |
| |
| LoopNest::sliceHead(tail, 2); |
| // tail_head: [6, 8) |
| // tail_tail: [8, 10) |
| |
| BlockPtr body = getSimplifiedBody(l); |
| assertForRanges(body, {{0, 6}, {6, 8}, {8, 10}}); |
| } |
| |
| TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) { |
| // When factor equals the For loop's original size, keep using the original |
| // For loop. |
| auto func = [](const ExprHandle& x) { |
| return ExprHandle(1.0f) + cast<float>(x); |
| }; |
| Tensor tensor = Compute("f", {10}, func); |
| LoopNest l({tensor}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr head; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail; |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::sliceTail(loops[0], 10, &head, &tail); |
| |
| ASSERT_EQ(head, nullptr); |
| ASSERT_EQ(tail, loops[0]); |
| |
| BlockPtr body = getSimplifiedBody(l); |
| assertForRanges(body, {{0, 10}}); |
| } |
| |
| TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) { |
| // When factor equals the For loop's original size, keep using the original |
| // For loop. |
| auto func = [](const ExprHandle& x) { |
| return ExprHandle(1.0f) + cast<float>(x); |
| }; |
| Tensor tensor = Compute("f", {10}, func); |
| LoopNest l({tensor}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr head; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail; |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::sliceTail(loops[0], 100, &head, &tail); |
| |
| ASSERT_EQ(head, nullptr); |
| ASSERT_EQ(tail, loops[0]); |
| |
| BlockPtr body = getSimplifiedBody(l); |
| assertForRanges(body, {{0, 10}}); |
| } |
| |
| TEST(LoopNest, ExprSliceTail) { |
| auto func = [](const ExprHandle& x) { |
| return ExprHandle(1.0f) + cast<float>(x); |
| }; |
| Tensor tensor = Compute("f", {10}, func); |
| LoopNest l({tensor}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr head; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail; |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::sliceTail(loops[0], 4, &head, &tail); |
| |
| ASSERT_NE(head, nullptr); |
| ASSERT_EQ(head, loops[0]); |
| ASSERT_NE(tail, nullptr); |
| ASSERT_NE(tail, loops[0]); |
| |
| BlockPtr body = getSimplifiedBody(l); |
| assertForRanges(body, {{0, 6}, {6, 10}}); |
| } |
| |
| TEST(LoopNest, ExprSplitAndSlice) { |
| // 0: splitWithTail |
| // 1: sliceTail on inner loop |
| // 2: sliceHead on outer loop |
| auto func = [](const ExprHandle& x) { |
| return ExprHandle(1.0f) + cast<float>(x); |
| }; |
| Tensor tensor = Compute("f", {100}, func); |
| LoopNest l({tensor}); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr inner; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail; |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| // outer: [0, 4) |
| // inner: [0, 21) |
| // tail: [84, 100) |
| LoopNest::splitWithTail(loops[0], 21, &inner, &tail); |
| LoopNest::sliceTail(inner, 2); |
| LoopNest::sliceHead(loops[0], 2); |
| |
| // for (int x_outer = 0; x_outer < 2; x_outer++) { |
| // for (int x_inner = 0; x_inner < 19; x_inner++) { |
| // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); |
| // } |
| // for (int x_inner = 19; x_inner < 21; x_inner++) { |
| // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); |
| // } |
| // } |
| // for (int x_outer = 2; x_outer < 4; x_outer++) { |
| // for (int x_inner = 0; x_inner < 19; x_inner++) { |
| // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); |
| // } |
| // for (int x_inner = 19; x_inner < 21; x_inner++) { |
| // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); |
| // } |
| // } |
| // for (int x_tail = 0; x_tail < 16; x_tail++) { |
| // f[x_tail + 84] = 1.f + float(x_tail + 84); |
| // } |
| BlockPtr body = getSimplifiedBody(l); |
| assertForRanges(body, {{0, 2}, {2, 4}, {0, 16}}); |
| |
| auto biter = body->begin(); |
| |
| ForPtr loop = to<For>(*biter++); |
| assertForRanges(loop->body(), {{0, 19}, {19, 21}}); |
| |
| loop = to<For>(*biter); |
| assertForRanges(loop->body(), {{0, 19}, {19, 21}}); |
| } |
| |
| TEST(LoopNest, ExprSliceAndNormalize) { |
| // 0: sliceHead |
| // 1: normalize tail |
| auto func = [](const ExprHandle& x) { |
| return ExprHandle(1.0f) + cast<float>(x); |
| }; |
| Tensor tensor = Compute("f", {10}, func); |
| LoopNest l({tensor}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr head; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail; |
| LoopNest::sliceHead(loops[0], 2, &head, &tail); |
| // head: [0, 2) |
| // tail: [2, 10) |
| |
| LoopNest::normalize(tail); |
| // normalized_tail: [0, 8) |
| |
| BlockPtr body = getSimplifiedBody(l); |
| assertForRanges(body, {{0, 2}, {0, 8}}); |
| } |
| |
| template <typename T> |
| T evalExpr(const ExprHandle& expr, const VarHandle& var, T value) { |
| ExprEval<SimpleIREvaluator> eval(expr, {var}); |
| return eval.value<T>(value); |
| } |
| |
| TEST(LoopNest, ExprSliceWithVariableDimension) { |
| auto testWithDimension = |
| [](int dimension, |
| const std::vector<std::pair<int, int>>& expected_for_ranges) { |
| VarHandle dim("dim", kInt); |
| Tensor tensor = |
| Compute("f", {dim}, [](const ExprHandle& x) { return x; }); |
| LoopNest l({tensor}); |
| std::vector<ForPtr> loops = |
| l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr head; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr tail; |
| LoopNest::sliceHead(loops[0], 2, &head, &tail); |
| |
| LoopNest::sliceTail(tail, 2); |
| |
| BlockPtr body = getSimplifiedBody(l); |
| ASSERT_EQ(expected_for_ranges.size(), 3); |
| auto it = body->begin(); |
| for (auto& start_stop : expected_for_ranges) { |
| ForPtr loop = to<For>(*it++); |
| int start = evalExpr<int>(ExprHandle(loop->start()), dim, dimension); |
| int stop = evalExpr<int>(ExprHandle(loop->stop()), dim, dimension); |
| ASSERT_EQ(start, start_stop.first); |
| ASSERT_EQ(stop, start_stop.second); |
| } |
| }; |
| |
| testWithDimension(1, {{0, 1}, {1, 1}, {1, 1}}); |
| testWithDimension(2, {{0, 2}, {2, 2}, {2, 2}}); |
| testWithDimension(3, {{0, 2}, {2, 2}, {2, 3}}); |
| testWithDimension(4, {{0, 2}, {2, 2}, {2, 4}}); |
| testWithDimension(5, {{0, 2}, {2, 3}, {3, 5}}); |
| testWithDimension(10, {{0, 2}, {2, 8}, {8, 10}}); |
| } |
| |
| TEST(LoopNest, ExprSplitWithTail) { |
| auto func = [](const ExprHandle& x) { |
| return ExprHandle(1.0f) + cast<float>(x); |
| }; |
| Tensor tensor = Compute("f", {199}, func); |
| LoopNest l({tensor}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| LoopNest::splitWithTail(loops[0], 17); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| LoopNest::splitWithTail(loops[0], 7); |
| |
| StmtPtr stmt = l.root_stmt(); |
| StmtPtr simplified = IRSimplifier::simplify(stmt); |
| BlockPtr body = to<Block>(simplified); |
| ASSERT_EQ(body->nstmts(), 3); |
| auto biter = body->begin(); |
| |
| // Verify that the split loops are ordered correctly. |
| ForPtr loop = to<For>(*biter++); |
| assertForRange(loop, 0, 7); |
| |
| loop = to<For>(*biter++); |
| assertForRange(loop, 0, 4); |
| |
| loop = to<For>(*biter); |
| assertForRange(loop, 0, 12); |
| } |
| |
| TEST(LoopNest, ExprSplitWithTailNone) { |
| auto func = [](const ExprHandle& x, const ExprHandle& y) { |
| return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
| }; |
| Tensor tensor = Compute("f", {24, 5}, func); |
| LoopNest l({tensor}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::splitWithTail(loops[0], 4); |
| |
| StmtPtr stmt = l.root_stmt(); |
| std::ostringstream oss; |
| oss << *stmt; |
| ASSERT_GT(oss.str().size(), 200); |
| ASSERT_LT(oss.str().size(), 600); |
| |
| { |
| // Compare to a reference loop structure structure. |
| VarHandle x_outer("i_outer", kInt); |
| VarHandle x_inner("i_inner", kInt); |
| VarHandle y("i", kInt); |
| VarHandle x_tail("i_tail", kInt); |
| // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) |
| BufHandle f("f", {24, 5}, kFloat); |
| ExprHandle x_1 = x_outer * 4 + x_inner; |
| ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4; |
| StmtPtr stmt = alloc<Block>(std::vector<StmtPtr>({For::make( |
| x_outer, |
| 0, |
| x_outer_end, |
| For::make( |
| x_inner, |
| 0, |
| 4, |
| For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))})); |
| |
| std::ostringstream oss_ref; |
| oss_ref << *stmt; |
| ASSERT_EQ(oss.str(), oss_ref.str()); |
| } |
| |
| { |
| PaddedBuffer<float> f_v(24, 5, "f_v"); |
| PaddedBuffer<float> f_ref(24, 5, "f_res"); |
| |
| SimpleIREvaluator ir_eval(stmt, {tensor}); |
| ir_eval(f_v); |
| |
| for (int x = 0; x < 24; x++) { |
| for (int y = 0; y < 5; y++) { |
| f_ref(x, y) = 1 + x * x + y * y; |
| } |
| } |
| |
| ExpectAllNear(f_v, f_ref, 1e-5); |
| } |
| } |
| |
| TEST(LoopNest, ExprSplitWithMask01) { |
| const int M = 26; |
| const int N = 5; |
| BufHandle a_buf("a", {M, N}, kFloat); |
| BufHandle b_buf("b", {M, N}, kFloat); |
| Tensor tensor = |
| Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
| return a_buf.load(m, n) + b_buf.load(m, n) + 1.0f; |
| }); |
| |
| LoopNest l({tensor}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::splitWithMask(loops[1], 4); |
| |
| StmtPtr stmt = l.root_stmt(); |
| |
| PaddedBuffer<float> a_v(M, N, "a"); |
| PaddedBuffer<float> b_v(M, N, "b"); |
| PaddedBuffer<float> c_v(M, N, "c"); |
| PaddedBuffer<float> c_ref(M, N, "c_ref"); |
| for (int m = 0; m < M; m++) { |
| for (int n = 0; n < N; n++) { |
| a_v(m, n) = 2 * m; |
| b_v(m, n) = 3 * n; |
| c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; |
| } |
| } |
| |
| SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); |
| |
| ExpectAllNear(c_v, c_ref, 1e-5); |
| } |
| |
| // Tests the case where we split a loop cleanly multiple times, we should not |
| // insert any masks. |
| TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) { |
| const int M = 64; |
| BufHandle a_buf("a", {M}, kFloat); |
| BufHandle b_buf("b", {M}, kFloat); |
| Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { |
| return a_buf.load(m) + b_buf.load(m) + 1.0f; |
| }); |
| |
| LoopNest l({tensor}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::splitWithMask(loops[0], 4); |
| LoopNest::splitWithMask(loops[0], 4); |
| |
| StmtPtr stmt1 = IRSimplifier::simplify(l.root_stmt()); |
| |
| // Two splits mean 3 loops, but should need no masks in this case. |
| checkIR(stmt1, R"IR( |
| # CHECK: for ( |
| # CHECK-NOT: if ( |
| # CHECK: for ( |
| # CHECK-NOT: if ( |
| # CHECK: for ( |
| # CHECK-NOT: if ( |
| # CHECK: f[)IR"); |
| } |
| |
| TEST(LoopNest, getLoopAt) { |
| // Input IR: |
| // for (int i = 0; i < 100; i++) { |
| // for (int j = 0; j < 100; j++) { |
| // A[i, j] = sin(i * j); |
| // for (int k1 = 0; k1 < 200; k1++) { |
| // B[i, j, k1] = (A[i, j]) / (k1 + 1); |
| // } |
| // for (int k2 = 0; k2 < 300; k2++) { |
| // C[i, j, k2] = (A[i, j]) * (k2 + 1); |
| // } |
| // } |
| // } |
| BufPtr A = alloc<Buf>( |
| "A", |
| std::vector<ExprPtr>({alloc<IntImm>(100), alloc<IntImm>(100)}), |
| kInt); |
| BufPtr B = alloc<Buf>( |
| "B", |
| std::vector<ExprPtr>( |
| {alloc<IntImm>(100), alloc<IntImm>(100), alloc<IntImm>(200)}), |
| kInt); |
| BufPtr C = alloc<Buf>( |
| "C", |
| std::vector<ExprPtr>( |
| {alloc<IntImm>(100), alloc<IntImm>(100), alloc<IntImm>(300)}), |
| kInt); |
| BufHandle a_buf(A); |
| BufHandle b_buf(B); |
| BufHandle c_buf(C); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k1("k1", kInt); |
| VarHandle k2("k2", kInt); |
| auto store1 = Store::make(a_buf, {i, j}, sin(i * j)); |
| auto store2 = Store::make( |
| b_buf, {i, j, k1}, Div::make(Load::make(a_buf, {i, j}), (k1 + 1))); |
| auto store3 = Store::make( |
| c_buf, {i, j, k2}, Mul::make(Load::make(a_buf, {i, j}), (k2 + 1))); |
| auto for_k2 = For::make(k2, 0, 300, Block::make({store3})); |
| auto for_k1 = For::make(k1, 0, 200, Block::make({store2})); |
| auto for_j = For::make(j, 0, 100, Block::make({store1, for_k1, for_k2})); |
| auto for_i = For::make(i, 0, 100, for_j); |
| LoopNest l(Block::make({for_i}), {B, C}); |
| auto ret_k2 = l.getLoopAt(for_i, {0, 2}); |
| TORCH_CHECK(ret_k2 == for_k2); |
| |
| std::ostringstream oss; |
| oss << *ret_k2; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int k2 |
| # CHECK-NEXT: C[i, j, k2] = |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| } |
| |
| TEST(LoopNest, TileSimple) { |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| const int M = 64, N = 64; |
| BufHandle a_buf("a", {M, N}, kFloat); |
| BufHandle b_buf("b", {M, N}, kFloat); |
| Tensor tensor = |
| Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
| return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; |
| }); |
| |
| LoopNest l({tensor}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| l.tile(loops[0], loops[1], 4, 8); |
| |
| // IR check |
| StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); |
| checkIR(stmt, R"IR( |
| # CHECK: for (int i_outer |
| # CHECK: for (int i_outer_1 |
| # CHECK: for (int i_inner |
| # CHECK: for (int i_inner_1 |
| # CHECK: f[ |
| # CHECK-NOT: for (int i_tail |
| # CHECK-NOT: for (int i_tail)IR"); |
| |
| // Correctness check |
| PaddedBuffer<float> a_v(M, N, "a"); |
| PaddedBuffer<float> b_v(M, N, "b"); |
| PaddedBuffer<float> c_v(M, N, "c"); |
| PaddedBuffer<float> c_ref(M, N, "c_ref"); |
| for (int m = 0; m < M; m++) { |
| for (int n = 0; n < N; n++) { |
| a_v(m, n) = 2 * m; |
| b_v(m, n) = 3 * n; |
| c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; |
| } |
| } |
| |
| SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| ExpectAllNear(c_v, c_ref, 1e-5); |
| } |
| |
| TEST(LoopNest, TileWithTails) { |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| const int M = 64, N = 64; |
| BufHandle a_buf("a", {M, N}, kFloat); |
| BufHandle b_buf("b", {M, N}, kFloat); |
| Tensor tensor = |
| Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { |
| return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; |
| }); |
| |
| LoopNest l({tensor}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| l.tile(loops[0], loops[1], 5, 9); |
| |
| // IR check |
| StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); |
| checkIR(stmt, R"IR( |
| # CHECK: for (int i_outer |
| # CHECK: for (int i_outer_1 |
| # CHECK: for (int i_inner |
| # CHECK: for (int i_inner_1 |
| # CHECK: f[ |
| # CHECK: for (int i_inner |
| # CHECK: f[ |
| # CHECK: for (int i_tail)IR"); |
| |
| // Correctness check |
| PaddedBuffer<float> a_v(M, N, "a"); |
| PaddedBuffer<float> b_v(M, N, "b"); |
| PaddedBuffer<float> c_v(M, N, "c"); |
| PaddedBuffer<float> c_ref(M, N, "c_ref"); |
| for (int m = 0; m < M; m++) { |
| for (int n = 0; n < N; n++) { |
| a_v(m, n) = 2 * m; |
| b_v(m, n) = 3 * n; |
| c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; |
| } |
| } |
| |
| SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| ExpectAllNear(c_v, c_ref, 1e-5); |
| } |
| |
| TEST(LoopNest, TileInMiddle) { |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| const int M = 8, N = 8, L = 8, K = 8; |
| BufHandle a_buf("a", {M, N, L, K}, kFloat); |
| BufHandle b_buf("b", {M, N, L, K}, kFloat); |
| Tensor tensor = Compute( |
| "f", |
| {M, N, L, K}, |
| [&](const ExprHandle& m, |
| const ExprHandle& n, |
| const ExprHandle& l, |
| const ExprHandle& k) { |
| return a_buf.load({m, n, l, k}) + b_buf.load({m, n, l, k}) + 1.0f; |
| }); |
| |
| LoopNest nest({tensor}); |
| std::vector<ForPtr> loops = |
| nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| nest.tile(loops[1], loops[2], 3, 3); |
| |
| // IR check |
| StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt()); |
| checkIR(stmt, R"IR( |
| # CHECK: for (int i |
| # CHECK: for (int i_outer |
| # CHECK: for (int i_outer_1 |
| # CHECK: for (int i_inner |
| # CHECK: for (int i_inner_1 |
| # CHECK: for (int i_1 |
| # CHECK: f[ |
| # CHECK: for (int i_tail_1 |
| # CHECK: for (int i_inner_1 |
| # CHECK: for (int i_1 |
| # CHECK: f[ |
| # CHECK: for (int i_tail)IR"); |
| |
| // Correctness check |
| PaddedBuffer<float> a_v(M, N, L, K, "a"); |
| PaddedBuffer<float> b_v(M, N, L, K, "b"); |
| PaddedBuffer<float> c_v(M, N, L, K, "c"); |
| PaddedBuffer<float> c_ref(M, N, L, K, "c_ref"); |
| for (int m = 0; m < M; m++) { |
| for (int n = 0; n < N; n++) { |
| for (int l = 0; l < L; l++) { |
| for (int k = 0; k < K; k++) { |
| a_v(m, n, l, k) = 2 * (m + l); |
| b_v(m, n, l, k) = 3 * (n + k); |
| c_ref(m, n, l, k) = a_v(m, n, l, k) + b_v(m, n, l, k) + 1.0f; |
| } |
| } |
| } |
| } |
| |
| SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| ExpectAllNear(c_v, c_ref, 1e-5); |
| } |
| |
| TEST(LoopNest, SplitWithTailWithLoopOptions) { |
| const int M = 21; |
| BufHandle a_buf("a", {M}, kFloat); |
| BufHandle b_buf("b", {M}, kFloat); |
| Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { |
| return a_buf.load(m) + b_buf.load(m) + 1.0f; |
| }); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr inner, tail; |
| |
| LoopNest l({tensor}); |
| auto loops = NodeFinder<For>::find(l.root_stmt()); |
| ASSERT_GT(loops.size(), 0); |
| loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); |
| LoopNest::splitWithTail(loops[0], 4, &inner, &tail); |
| ASSERT_NE(inner, nullptr); |
| ASSERT_NE(tail, nullptr); |
| ForPtr outer = loops[0]; |
| |
| // Outer loop carries loop axis bindings. |
| ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); |
| ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); |
| |
| // Inner loop has none. |
| ASSERT_TRUE(inner->loop_options().isDefault()); |
| |
| // Tail loop has none. |
| ASSERT_TRUE(tail->loop_options().isDefault()); |
| } |
| |
| TEST(LoopNest, SplitWithMaskWithLoopOptions) { |
| const int M = 21; |
| BufHandle a_buf("a", {M}, kFloat); |
| BufHandle b_buf("b", {M}, kFloat); |
| Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { |
| return a_buf.load(m) + b_buf.load(m) + 1.0f; |
| }); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr inner; |
| |
| LoopNest l({tensor}); |
| auto loops = NodeFinder<For>::find(l.root_stmt()); |
| loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); |
| LoopNest::splitWithMask(loops[0], 4, &inner); |
| ForPtr outer = loops[0]; |
| |
| // Outer loop carries loop axis bindings. |
| ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); |
| ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); |
| |
| // Inner loop has none. |
| ASSERT_TRUE(inner->loop_options().isDefault()); |
| } |
| |
| TEST(LoopNest, ScheduleBroadcastAddBuffer) { |
| const int M = 4; |
| const int N = 5; |
| const int K = 6; |
| BufHandle a_buf("a", {M, N}, kFloat); |
| BufHandle b_buf("b", {N, K}, kFloat); |
| Tensor c = Compute( |
| "broadcast_add", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return a_buf.load(m, n) + b_buf.load(n, k); |
| }); |
| LoopNest l({c}); |
| StmtPtr stmt = l.root_stmt(); |
| |
| PaddedBuffer<float> a_v(M, N, "a_v"); |
| for (int m = 0; m < M; m++) { |
| for (int n = 0; n < N; n++) { |
| a_v(m, n) = 7 * m * n; |
| } |
| } |
| a_v.Backup(); |
| |
| PaddedBuffer<float> b_v(N, K, "b_v"); |
| for (int n = 0; n < N; n++) { |
| for (int k = 0; k < K; k++) { |
| b_v(n, k) = 11 * n * k; |
| } |
| } |
| b_v.Backup(); |
| |
| PaddedBuffer<float> c_v(M, N, K, "c_buf"); |
| SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c}); |
| ir_eval(a_v, b_v, c_v); |
| |
| a_v.CheckBackup(); |
| b_v.CheckBackup(); |
| PaddedBuffer<float> c_ref(M, N, K, "c_ref"); |
| for (int m = 0; m < M; m++) { |
| for (int n = 0; n < N; n++) { |
| for (int k = 0; k < K; k++) { |
| c_ref(m, n, k) = 7 * m * n + 11 * n * k; |
| } |
| } |
| } |
| ExpectAllNear(c_v, c_ref, 1e-5); |
| } |
| |
| TEST(LoopNest, ScheduleFunctionCall01) { |
| const int M = 4; |
| const int N = 5; |
| const int K = 6; |
| BufHandle a_buf("a", {M, N}, kFloat); |
| BufHandle b_buf("b", {N, K}, kFloat); |
| Tensor c = Compute( |
| "broadcast_add", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return a_buf.load(m, n) + b_buf.load(n, k); |
| }); |
| Tensor d = Compute( |
| "d", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return c.load(m, n, k) + 1; |
| }); |
| |
| LoopNest l({d}, {c, d}); |
| l.prepareForCodegen(); |
| StmtPtr stmt = l.root_stmt(); |
| std::ostringstream oss; |
| oss << *stmt; |
| ASSERT_GT(oss.str().size(), 100); |
| |
| PaddedBuffer<float> a_v(M, N); |
| PaddedBuffer<float> b_v(N, K); |
| PaddedBuffer<float> c_v(M, N, K); |
| PaddedBuffer<float> d_v(M, N, K); |
| PaddedBuffer<float> d_ref(M, N, K); |
| |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < N; j++) { |
| a_v(i, j) = i * i; |
| } |
| } |
| for (int i = 0; i < N; i++) { |
| for (int j = 0; j < K; j++) { |
| b_v(i, j) = j * j; |
| } |
| } |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < N; j++) { |
| for (int k = 0; k < K; k++) { |
| d_ref(i, j, k) = a_v(i, j) + b_v(j, k) + 1; |
| } |
| } |
| } |
| |
| SimpleIREvaluator eval(stmt, {a_buf, b_buf, d}); |
| eval(a_v, b_v, d_v); |
| |
| ExpectAllNear(d_v, d_ref, 1e-5); |
| } |
| |
| TEST(LoopNest, ScheduleInlineSimple) { |
| const int M = 4; |
| const int N = 5; |
| const int K = 6; |
| BufHandle a_buf("a", {M, N}, kFloat); |
| BufHandle b_buf("b", {N, K}, kFloat); |
| BufHandle c_buf("c", {M, N}, kFloat); |
| BufHandle d_buf("d", {M, K}, kFloat); |
| |
| Tensor x = Compute( |
| "x", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return a_buf.load(m, n) * b_buf.load(n, k); |
| }); |
| Tensor y = Compute( |
| "y", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); |
| }); |
| |
| LoopNest l1({y}, {x, y}); |
| LoopNest l2(l1); |
| l2.computeInline(x.buf()); |
| |
| l1.prepareForCodegen(); |
| l2.prepareForCodegen(); |
| |
| StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
| StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); |
| |
| SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y}); |
| SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y}); |
| |
| PaddedBuffer<float> a_v(M, N); |
| PaddedBuffer<float> b_v(N, K); |
| PaddedBuffer<float> c_v(M, N); |
| PaddedBuffer<float> d_v(M, K); |
| |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < N; j++) { |
| a_v(i, j) = i * i; |
| } |
| } |
| for (int i = 0; i < N; i++) { |
| for (int j = 0; j < K; j++) { |
| b_v(i, j) = j * j; |
| } |
| } |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < N; j++) { |
| c_v(i, j) = i + j; |
| } |
| } |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < K; j++) { |
| d_v(i, j) = i * j; |
| } |
| } |
| |
| PaddedBuffer<float> y_1(M, N, K); |
| PaddedBuffer<float> y_2(M, N, K); |
| |
| eval1(a_v, b_v, c_v, d_v, y_1); |
| eval2(a_v, b_v, c_v, d_v, y_2); |
| ExpectAllNear(y_1, y_2, 1e-5); |
| std::ostringstream oss1, oss2; |
| oss1 << *stmt1; |
| oss2 << *stmt2; |
| ASSERT_GT(oss1.str().size(), oss2.str().size()); |
| } |
| |
// Returns a copy of `str` with every whitespace character removed, so two
// printed IR statements can be compared independent of formatting.
//
// Each character is converted to unsigned char before calling isspace(): the
// <cctype> classifiers have undefined behavior for negative arguments other
// than EOF, and plain `char` yields negative values for non-ASCII bytes on
// signed-char platforms.
static std::string remove_space(const std::string& str) {
  std::string str_new = str;
  str_new.erase(
      std::remove_if(
          str_new.begin(),
          str_new.end(),
          [](char ch) { return isspace(static_cast<unsigned char>(ch)) != 0; }),
      str_new.end());
  return str_new;
}
| |
| void InlineFunc01Helper(const std::vector<std::string>& inline_order) { |
| const int M = 4; |
| const int N = 5; |
| const int K = 6; |
| BufHandle a_buf("a", {M, N}, kFloat); |
| BufHandle b_buf("b", {N, K}, kFloat); |
| BufHandle c_buf("c", {M, N}, kFloat); |
| BufHandle d_buf("d", {M, K}, kFloat); |
| |
| Tensor x = Compute( |
| "x", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return a_buf.load(m, n) * b_buf.load(n, k); |
| }); |
| Tensor y = Compute( |
| "y", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); |
| }); |
| Tensor z = Compute( |
| "z", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return x.load(m, n, k) + y.load(m, n, k); |
| }); |
| |
| LoopNest l({z}, {x, y, z}); |
| for (const std::string& order : inline_order) { |
| if (order == "x") { |
| l.computeInline(x.buf()); |
| } else if (order == "y") { |
| l.computeInline(y.buf()); |
| } else { |
| throw std::runtime_error("Invalid order: " + order); |
| } |
| } |
| l.prepareForCodegen(); |
| StmtPtr stmt = l.root_stmt(); |
| |
| std::ostringstream oss; |
| oss << *stmt; |
| std::string str1 = remove_space(oss.str()); |
| |
| { |
| PaddedBuffer<float> a_v(M, N); |
| PaddedBuffer<float> b_v(N, K); |
| PaddedBuffer<float> c_v(M, N); |
| PaddedBuffer<float> d_v(M, K); |
| |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < N; j++) { |
| a_v(i, j) = i * i; |
| } |
| } |
| for (int i = 0; i < N; i++) { |
| for (int j = 0; j < K; j++) { |
| b_v(i, j) = j * j; |
| } |
| } |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < N; j++) { |
| c_v(i, j) = i + j; |
| } |
| } |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < K; j++) { |
| d_v(i, j) = i * j; |
| } |
| } |
| |
| PaddedBuffer<float> z_v(M, N, K); |
| PaddedBuffer<float> z_ref(M, N, K); |
| for (int m = 0; m < M; m++) { |
| for (int n = 0; n < N; n++) { |
| for (int k = 0; k < K; k++) { |
| z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); |
| } |
| } |
| } |
| |
| SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); |
| eval(a_v, b_v, c_v, d_v, z_v); |
| ExpectAllNear(z_v, z_ref, 1e-5); |
| } |
| |
| if (inline_order.size() == 2) { |
| Tensor z2 = Compute( |
| "z", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return a_buf.load(m, n) * b_buf.load(n, k) + |
| (c_buf.load(m, n) * d_buf.load(m, k) + |
| a_buf.load(m, n) * b_buf.load(n, k)); |
| }); |
| LoopNest l2({z2}); |
| l2.prepareForCodegen(); |
| StmtPtr stmt2 = l2.root_stmt(); |
| |
| std::ostringstream oss2; |
| oss2 << *stmt2; |
| std::string str2 = remove_space(oss2.str()); |
| |
| ASSERT_EQ(str1, str2); |
| ASSERT_GT(str1.size(), 100); |
| } |
| } |
| |
| TEST(LoopNest, ScheduleInlineFunc01) { |
| InlineFunc01Helper({"x", "y"}); |
| InlineFunc01Helper({"y", "x"}); |
| InlineFunc01Helper({"x"}); |
| InlineFunc01Helper({"y"}); |
| InlineFunc01Helper({}); |
| } |
| |
| // Make sure we cache random vars if we should. |
| TEST(LoopNest, ScheduleInlineRandom) { |
| const int M = 4; |
| const int N = 5; |
| const int K = 6; |
| |
| Tensor x = Compute( |
| "x", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return Mod::make(Intrinsics::make(kRand, kInt), 5); |
| }); |
| Tensor y = Compute( |
| "y", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return x.load(m, n, k) + x.load(m, n, k); |
| }); |
| |
| LoopNest l1({y}, {x, y}); |
| l1.computeInline(x.buf()); |
| |
| // would normally compare results but Rand isn't implemented in the |
| // SimpleIREvaluator, even if we could seed it. |
| StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
| |
| // Check the IR we produced |
| checkIR(stmt1, R"IR( |
| # CHECK: for (int i = 0; i < 4; i++) |
| # CHECK: for (int i_1 = 0; i_1 < 5; i_1++) |
| # CHECK: for (int i_2 = 0; i_2 < 6; i_2++) |
| # CHECK: int x = rand(); |
| # CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR"); |
| } |
| |
| // Make sure we don't cache random vars that are not being inlined. |
| TEST(LoopNest, ScheduleInlineRandomUnrelated) { |
| const int M = 4; |
| const int N = 5; |
| const int K = 6; |
| |
| Tensor x = Compute( |
| "x", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return m * n * k; |
| }); |
| Tensor y = Compute( |
| "y", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return x.load(m, n, k) + Intrinsics::make(kRand, kInt) + |
| Intrinsics::make(kRand, kInt); |
| }); |
| |
| LoopNest l1({y}, {x, y}); |
| l1.computeInline(x.buf()); |
| |
| // would normally compare results but Rand isn't implemented in the |
| // SimpleIREvaluator, even if we could seed it. |
| StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
| |
| // Check the IR we produced |
| checkIR(stmt1, R"IR( |
| # CHECK: for (int i = 0; i < 4; i++) |
| # CHECK: for (int i_1 = 0; i_1 < 5; i_1++) |
| # CHECK: for (int i_2 = 0; i_2 < 6; i_2++) |
| # CHECK: y[i, i_1, i_2] = ((i * i_1) * i_2 + (rand())) + (rand());)IR"); |
| } |
| |
| // Make sure we generate the right number of random values == the dimensionality |
| // of the production tensor. |
| TEST(LoopNest, ScheduleInlineRandomLowerDimensions) { |
| const int M = 4; |
| const int N = 5; |
| const int K = 6; |
| |
| Tensor x = Compute("x", {M}, [&](const VarHandle& m) { |
| return Mod::make(Intrinsics::make(kRand, kInt), 5); |
| }); |
| Tensor y = Compute( |
| "y", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return x.load(m) + x.load(m); |
| }); |
| |
| LoopNest l1({y}, {x, y}); |
| l1.computeInline(x.buf()); |
| |
| // would normally compare results but Rand isn't implemented in the |
| // SimpleIREvaluator, even if we could seed it. |
| StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
| |
| // Check the IR we produced |
| checkIR(stmt1, R"IR( |
| # CHECK: for (int i = 0; i < 4; i++) |
| # CHECK: int x = rand(); |
| # CHECK: for (int i_1 = 0; i_1 < 5; i_1++) |
| # CHECK: for (int i_2 = 0; i_2 < 6; i_2++) |
| # CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR"); |
| } |
| |
| // Make sure we don't screw up intrinsics thinking they're rand. |
| TEST(LoopNest, ScheduleInlineIntrinsics) { |
| const int M = 4; |
| const int N = 5; |
| const int K = 6; |
| BufHandle a_buf("a", {M, N}, kFloat); |
| BufHandle b_buf("b", {N, K}, kFloat); |
| |
| Tensor x = Compute( |
| "x", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return a_buf.load(m, n) * b_buf.load(n, k); |
| }); |
| Tensor y = Compute( |
| "y", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return Intrinsics::make(kSqrt, x.load(m, n, k)); |
| }); |
| |
| PaddedBuffer<float> a_v(M, N); |
| PaddedBuffer<float> b_v(N, K); |
| |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < N; j++) { |
| a_v(i, j) = i * i; |
| } |
| } |
| for (int i = 0; i < N; i++) { |
| for (int j = 0; j < K; j++) { |
| b_v(i, j) = j * j; |
| } |
| } |
| |
| LoopNest l1({y}, {x, y}); |
| LoopNest l2(l1); |
| l2.computeInline(x.buf()); |
| |
| l1.prepareForCodegen(); |
| l2.prepareForCodegen(); |
| |
| StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
| StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); |
| |
| SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); |
| SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); |
| |
| PaddedBuffer<float> y_1(M, N, K); |
| PaddedBuffer<float> y_2(M, N, K); |
| |
| eval1(a_v, b_v, y_1); |
| eval2(a_v, b_v, y_2); |
| ExpectAllNear(y_1, y_2, 1e-5); |
| std::ostringstream oss1, oss2; |
| oss1 << *stmt1; |
| oss2 << *stmt2; |
| ASSERT_GT(oss1.str().size(), oss2.str().size()); |
| } |
| |
| // Make sure we can handle rand and non-rand intrinsics. |
| TEST(LoopNest, ScheduleInlineRandWithIntrinsics) { |
| const int M = 4; |
| const int N = 5; |
| const int K = 6; |
| |
| Tensor x = Compute( |
| "x", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return Intrinsics::make(kRand, kFloat); |
| }); |
| Tensor y = Compute( |
| "y", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return Intrinsics::make(kSqrt, x.load(m, n, k)); |
| }); |
| |
| LoopNest l1({y}, {x, y}); |
| l1.computeInline(x.buf()); |
| |
| StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
| |
| // Check the IR we produced |
| checkIR(stmt1, R"IR( |
| # CHECK: for (int i = 0; i < 4; i++) |
| # CHECK: for (int i_1 = 0; i_1 < 5; i_1++) |
| # CHECK: for (int i_2 = 0; i_2 < 6; i_2++) |
| # CHECK: float x = rand(); |
| # CHECK: y[i, i_1, i_2] = sqrt(x);)IR"); |
| } |
| |
| // Split a Compute then inline it into another compute. |
| TEST(LoopNest, ScheduleSplitAThenInline) { |
| Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); |
| Tensor b = Compute( |
| "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
| |
| LoopNest l({b}, {a, b}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
| LoopNest::splitWithMask(loops[0], 4); |
| ASSERT_FALSE(l.computeInline(a.buf())); |
| } |
| |
| // Split a Compute then inline another Compute into it. |
| TEST(LoopNest, ScheduleSplitBThenInline) { |
| Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); |
| Tensor b = Compute( |
| "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
| |
| LoopNest l({b}, {a, b}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); |
| LoopNest::splitWithMask(loops[0], 3); |
| l.computeInline(a.buf()); |
| l.prepareForCodegen(); |
| StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
| |
| std::vector<int> output(6, 0); |
| SimpleIREvaluator eval(s, {b}); |
| eval(output); |
| |
| for (int i = 0; i < 6; ++i) { |
| ASSERT_EQ(output[i], (i + 8) * (i + 8)); |
| } |
| } |
| |
| // Split a Compute twice then inline it. |
| TEST(LoopNest, ScheduleSplitTwiceThenInline) { |
| Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); |
| Tensor b = Compute( |
| "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr i_inner; |
| |
| LoopNest l({b}, {a, b}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
| LoopNest::splitWithMask(loops[0], 4, &i_inner); |
| LoopNest::splitWithMask(i_inner, 2); |
| ASSERT_FALSE(l.computeInline(a.buf())); |
| } |
| |
| // Inline a Compute, then split. |
| TEST(LoopNest, ScheduleInlineThenSplit) { |
| Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); |
| Tensor b = Compute( |
| "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
| |
| LoopNest l({b}, {a, b}); |
| l.computeInline(a.buf()); |
| |
| std::vector<ForPtr> loops = NodeFinder<For>::find(l.root_stmt()); |
| LoopNest::splitWithMask(loops.back(), 3); |
| l.prepareForCodegen(); |
| StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
| std::vector<int> output(6, 0); |
| SimpleIREvaluator eval(s, {b}); |
| eval(output); |
| |
| for (int i = 0; i < 6; ++i) { |
| ASSERT_EQ(output[i], (i + 8) * (i + 8)); |
| } |
| } |
| |
| // Split a Compute, inline it, then split the result. |
| TEST(LoopNest, ScheduleSplitInlineThenSplit) { |
| Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); |
| Tensor b = Compute( |
| "b", {16}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
| |
| LoopNest l({b}, {a, b}); |
| auto loops = NodeFinder<For>::find(l.root_stmt()); |
| LoopNest::splitWithMask(loops.back(), 2); |
| l.computeInline(a.buf()); |
| |
| loops = NodeFinder<For>::find(l.root_stmt()); |
| LoopNest::splitWithMask(loops.front(), 2); |
| l.prepareForCodegen(); |
| StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
| std::vector<int> output(16, 0); |
| SimpleIREvaluator eval(s, {b}); |
| eval(output); |
| |
| for (int i = 0; i < 16; ++i) { |
| ASSERT_EQ(output[i], (i + 8) * (i + 8)); |
| } |
| } |
| |
| // Oversplit a loop that is simplified out after inlining. |
| TEST(LoopNest, ScheduleSplitInlineSimplify) { |
| Tensor a = Compute("a", {18}, [&](const VarHandle& i) { |
| return ExprHandle(4) * i - ExprHandle(2) * i; |
| }); |
| Tensor b = Compute( |
| "b", {2}, [&](const VarHandle& j) { return a.load(j) - ExprHandle(1); }); |
| |
| LoopNest l({b}, {a, b}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
| LoopNest::splitWithMask(loops[0], 4); |
| ASSERT_FALSE(l.computeInline(a.buf())); |
| } |
| |
| // Inline a Compute with two consumers. |
| TEST(LoopNest, ScheduleInlineThreeMixedOnce) { |
| Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); |
| Tensor b = Compute( |
| "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
| Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { |
| return a.load(k) * b.load(l); |
| }); |
| |
| LoopNest l({c}, {a, b, c}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
| l.computeInline(a.buf()); |
| l.prepareForCodegen(); |
| |
| StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
| std::vector<int> output(4 * 3, 0); |
| SimpleIREvaluator eval(s, {c}); |
| eval(output); |
| |
| for (int k = 0; k < 4; ++k) { |
| for (int l = 0; l < 3; ++l) { |
| ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); |
| } |
| } |
| } |
| |
| // Inline Compute A into B, then inline B into C. |
| TEST(LoopNest, ScheduleInlineThreeMixedTwice) { |
| Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); |
| Tensor b = Compute( |
| "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
| Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { |
| return a.load(k) * b.load(l); |
| }); |
| |
| LoopNest l({c}, {a, b, c}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
| l.computeInline(a.buf()); |
| l.computeInline(b.buf()); |
| l.prepareForCodegen(); |
| |
| StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
| std::vector<int> output(4 * 3, 0); |
| SimpleIREvaluator eval(s, {c}); |
| eval(output); |
| |
| for (int k = 0; k < 4; ++k) { |
| for (int l = 0; l < 3; ++l) { |
| ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); |
| } |
| } |
| } |
| |
| // Inline a Compute that is both a producer and consumer. |
| TEST(LoopNest, ScheduleInlineThreeMixedInner) { |
| Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); |
| Tensor b = Compute( |
| "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
| Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { |
| return a.load(k) * b.load(l); |
| }); |
| |
| LoopNest l({c}, {a, b, c}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
| l.computeInline(b.buf()); |
| l.prepareForCodegen(); |
| |
| StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
| std::vector<int> output(4 * 3, 0); |
| SimpleIREvaluator eval(s, {c}); |
| eval(output); |
| |
| for (int k = 0; k < 4; ++k) { |
| for (int l = 0; l < 3; ++l) { |
| ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); |
| } |
| } |
| } |
| |
| // Split 3 Computes, then inline the first two into the last. |
| TEST(LoopNest, ScheduleInlineThreeMixedSplit) { |
| Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); |
| Tensor b = Compute( |
| "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); |
| Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { |
| return a.load(k) * b.load(l); |
| }); |
| |
| LoopNest l({c}, {a, b, c}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); |
| LoopNest::splitWithMask(loops[0], 4); |
| loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); |
| LoopNest::splitWithMask(loops[0], 3); |
| loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
| LoopNest::splitWithMask(loops[0], 2); |
| |
| ASSERT_FALSE(l.computeInline(a.buf())); |
| } |
| |
| // Check that inlining works for output tensors too |
| TEST(LoopNest, ScheduleInlineOutputTensors) { |
| const int M = 4; |
| const int N = 5; |
| const int K = 6; |
| |
| Tensor x = Compute( |
| "x", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return m * n * k; |
| }); |
| Tensor y = Compute( |
| "y", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return x.load(m, n, k) + m; |
| }); |
| |
| LoopNest l1({x, y}); |
| l1.computeInline(x.buf()); |
| |
| // would normally compare results but Rand isn't implemented in the |
| // SimpleIREvaluator, even if we could seed it. |
| StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); |
| |
| // Check the IR we produced |
| checkIR(stmt1, R"IR( |
| # CHECK: for (int i = 0; i < 4; i++) |
| # CHECK: for (int i_1 = 0; i_1 < 5; i_1++) |
| # CHECK: for (int i_2 = 0; i_2 < 6; i_2++) |
| # CHECK: x[i, i_1, i_2] = (i * i_1) * i_2; |
| # CHECK: for (int i_3 = 0; i_3 < 4; i_3++) |
| # CHECK: for (int i_4 = 0; i_4 < 5; i_4++) |
| # CHECK: for (int i_5 = 0; i_5 < 6; i_5++) |
| # CHECK: y[i_3, i_4, i_5] = i_3 + (i_3 * i_4) * i_5;)IR"); |
| } |
| |
| TEST(LoopNest, ScheduleInlineWithCompoundIndices) { |
| // Input IR: |
| // for (int64_t i = 0; i < 100; i++) { |
| // A[i*2,i] = i * 500ll; |
| // } |
| // for (int64_t j = 0; j < 100; j++) { |
| // B[0ll,j] = A[0, j] + j * 100ll; |
| // } |
| BufHandle a_buf("A", {20, 100}, kLong); |
| BufHandle b_buf("B", {20, 100}, kLong); |
| VarHandle i("i", kLong); |
| VarHandle j("j", kLong); |
| auto forI = For::make( |
| i, |
| 0, |
| 100, |
| Store::make(a_buf, {i * 2, i}, Mul::make(i, static_cast<int64_t>(500)))); |
| auto forJ = For::make( |
| j, |
| 0, |
| 100, |
| Store::make( |
| b_buf, |
| {static_cast<int64_t>(0), j}, |
| Add::make( |
| Load::make(a_buf, {static_cast<int64_t>(0), j}), |
| Mul::make(j, static_cast<int64_t>(100))))); |
| auto par = Block::make({forI, forJ}); |
| |
| LoopNest l(par, {b_buf.node()}); |
| // Inlining should fail since the producer has compound expr as index. |
| ASSERT_FALSE(l.computeInline(a_buf.node())); |
| |
| // The input statement must remain as is. |
| checkIR(l.root_stmt(), R"IR( |
| # CHECK: for (int64_t i = 0; |
| # CHECK-NEXT: A[ |
| # CHECK: for (int64_t j = 0; |
| # CHECK-NEXT: B[)IR"); |
| } |
| |
| TEST(LoopNest, ScheduleInlineConsumerIndicesWithCast) { |
| // Input IR: |
| // for (int64_t i = 0; i < 100; i++) { |
| // A[0ll,i] = i * 500ll; |
| // } |
| // for (int64_t j = 0; j < 100; j++) { |
| // B[0ll,j] = A[(int64_t)0, j] + j * 100ll; |
| // } |
| BufHandle a_buf("A", {20, 100}, kLong); |
| BufHandle b_buf("B", {20, 100}, kLong); |
| VarHandle i("i", kLong); |
| VarHandle j("j", kLong); |
| auto forI = For::make( |
| i, |
| 0, |
| 100, |
| Store::make( |
| a_buf, |
| {static_cast<int64_t>(0), i}, |
| Mul::make(i, static_cast<int64_t>(500)))); |
| auto forJ = For::make( |
| j, |
| 0, |
| 100, |
| Store::make( |
| b_buf, |
| {static_cast<int64_t>(0), j}, |
| Add::make( |
| Load::make(a_buf, {0, j}), |
| Mul::make(j, static_cast<int64_t>(100))))); |
| auto par = Block::make({forI, forJ}); |
| |
| LoopNest l(par, {b_buf.node()}); |
| ASSERT_TRUE(l.computeInline(a_buf.node())); |
| |
| checkIR(l.root_stmt(), R"IR( |
| # CHECK: for (int64_t j = 0; j < 100; j++) { |
| # CHECK: B[0ll, j] = j * 500ll + j * 100ll; |
| # CHECK: })IR"); |
| } |
| |
| TEST(LoopNest, ScheduleInlineProducerIndicesWithCast) { |
| // Input IR: |
| // for (int64_t i = 0; i < 100; i++) { |
| // A[(int64_t)0,i] = i * 500ll; |
| // } |
| // for (int64_t j = 0; j < 100; j++) { |
| // B[0ll,j] = A[0ll, j] + j * 100ll; |
| // } |
| BufHandle a_buf("A", {20, 100}, kLong); |
| BufHandle b_buf("B", {20, 100}, kLong); |
| VarHandle i("i", kLong); |
| VarHandle j("j", kLong); |
| auto forI = For::make( |
| i, |
| 0, |
| 100, |
| Store::make(a_buf, {0, i}, Mul::make(i, static_cast<int64_t>(500)))); |
| auto forJ = For::make( |
| j, |
| 0, |
| 100, |
| Store::make( |
| b_buf, |
| {static_cast<int64_t>(0), j}, |
| Add::make( |
| Load::make(a_buf, {static_cast<int64_t>(0), j}), |
| Mul::make(j, static_cast<int64_t>(100))))); |
| auto par = Block::make({forI, forJ}); |
| |
| LoopNest l(par, {b_buf.node()}); |
| ASSERT_TRUE(l.computeInline(a_buf.node())); |
| |
| checkIR(l.root_stmt(), R"IR( |
| # CHECK: for (int64_t j = 0; j < 100; j++) { |
| # CHECK: B[0ll, j] = j * 500ll + j * 100ll; |
| # CHECK: })IR"); |
| } |
| |
| TEST(LoopNest, ScheduleFuserStyle) { |
| const int kVectorSize = 8; |
| const int kVectorCount = 128; |
| const int kTotalSize = kVectorSize * kVectorCount; |
| |
| BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); |
| |
| Tensor b = |
| Compute("f", {kTotalSize}, [&](const std::vector<VarHandle>& axes) { |
| return a_buf.load(axes[0]) + 11.0f; |
| }); |
| |
| Tensor c = |
| Compute("g", {kTotalSize}, [&](const std::vector<VarHandle>& axes) { |
| return b.load(axes[0]) + 1.0f; |
| }); |
| |
| LoopNest l({b, c}); |
| l.prepareForCodegen(); |
| StmtPtr s = l.root_stmt(); |
| |
| std::vector<float> a_data(kTotalSize, 7.0f); |
| std::vector<float> b_data(kTotalSize, 0.0f); |
| std::vector<float> c_data(kTotalSize, 0.0f); |
| SimpleIREvaluator(s, {a_buf, b, c})(a_data, b_data, c_data); |
| |
| for (int i = 0; i < kTotalSize; i++) { |
| ASSERT_EQ(b_data[i], 18.0f); |
| ASSERT_EQ(c_data[i], 19.0f); |
| } |
| } |
| |
| TEST(LoopNest, ScheduleFuserThreeArg) { |
| const int kVectorSize = 8; |
| const int kVectorCount = 128; |
| const int kTotalSize = kVectorSize * kVectorCount; |
| |
| BufHandle a("A", {ExprHandle(kTotalSize)}, kFloat); |
| BufHandle b("B", {ExprHandle(kTotalSize)}, kFloat); |
| BufHandle c("C", {ExprHandle(kTotalSize)}, kFloat); |
| BufHandle d("D", {ExprHandle(kTotalSize)}, kFloat); |
| |
| Tensor e = Compute("e", {kTotalSize}, [&](const VarHandle& i) { |
| return a.load(i) + b.load(i); |
| }); |
| Tensor f = Compute("f", {kTotalSize}, [&](const VarHandle& i) { |
| return e.load(i) + c.load(i); |
| }); |
| Tensor g = Compute("g", {kTotalSize}, [&](const VarHandle& i) { |
| return f.load(i) + d.load(i); |
| }); |
| |
| LoopNest l({g}, {e, f, g}); |
| l.computeInline(l.getLoopBodyFor(e)); |
| l.computeInline(l.getLoopBodyFor(f)); |
| l.prepareForCodegen(); |
| StmtPtr s = l.root_stmt(); |
| |
| std::vector<float> a_data(kTotalSize, 1.0f); |
| std::vector<float> b_data(kTotalSize, 2.0f); |
| std::vector<float> c_data(kTotalSize, 3.0f); |
| std::vector<float> d_data(kTotalSize, 4.0f); |
| std::vector<float> g_data(kTotalSize, 0.0f); |
| SimpleIREvaluator(s, {a, b, c, d, g})(a_data, b_data, c_data, d_data, g_data); |
| |
| for (int i = 0; i < kTotalSize; i++) { |
| ASSERT_EQ(g_data[i], 10.0f); |
| } |
| } |
| |
| TEST(LoopNest, ScheduleDynamicShape2D) { |
| auto testWithSize = [](int32_t M, int32_t N) { |
| VarHandle m("m", kInt); |
| VarHandle n("n", kInt); |
| BufHandle a("a", {m, n}, kFloat); |
| BufHandle b("b", {m, n}, kFloat); |
| Tensor c = |
| Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { |
| return a.load(i, j) + b.load(i, j); |
| }); |
| LoopNest l({c}); |
| StmtPtr s = l.root_stmt(); |
| SimpleIREvaluator cg(s, {a, b, c, m, n}); |
| std::vector<float> aData(M * N, 1.0f); |
| std::vector<float> bData(M * N, 2.0f); |
| std::vector<float> cData(M * N, 0.0f); |
| cg.call({aData, bData, cData, M, N}); |
| ExpectAllNear(cData, std::vector<float>(M * N, 3.0f), 1e-7); |
| }; |
| testWithSize(1, 8); |
| testWithSize(16, 32); |
| testWithSize(37, 11); |
| } |
| |
| TEST(LoopNest, LoopNestComputeAt_1) { |
| // Verify that compute_at works on the following example: |
| // |
| // for (int i_a = 0; i_a < N; i_a++) { |
| // A[i_a] = i_a * i_a |
| // } |
| // for (int i_b = 0; i_b < N; i_b++) { |
| // B[i_b] = A[i_b] |
| // } |
| // |
| // After the transformation the i_b loop should have an allocation for a temp |
| // buffer and that buffer should be used in computation of B. No use of A |
| // should be in that loop after the transformation. Also, computation of A |
| // should not be inlined into B. Instead, it should be computed into the temp, |
| // and the temp should be used in B. |
| VarHandle N("N", kInt); |
| Tensor A = Compute("A", {N}, [&](const VarHandle& i_a) { return i_a * i_a; }); |
| Tensor B = |
| Compute("B", {N}, [&](const VarHandle& i_b) { return A.load(i_b); }); |
| LoopNest l({B}, {A, B}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); |
| LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); |
| l.prepareForCodegen(); |
| SimpleIREvaluator cg(l.root_stmt(), {B, N}); |
| StmtPtr s = cg.stmt(); |
| |
| checkIR(s, R"IR( |
| # CHECK: Allocate(temp); // dtype=int, dims=[1] |
| # CHECK: for (int i = 0; i < N; i++) |
| # CHECK: temp[ |
| # CHECK-NOT: A[ |
| # CHECK: B[i_1] = temp[0] |
| # CHECK: Free(temp))IR"); |
| |
| // Now check that the loop still produces the correct result. |
| std::vector<int> b_data(100, 0); |
| cg.call({b_data, 100}); |
| |
| std::vector<int> b_ref(100, 0); |
| for (int i = 0; i < 100; i++) { |
| b_ref[i] = i * i; |
| } |
| assertAllEqual(b_data, b_ref); |
| } |
| |
| TEST(LoopNest, LoopNestComputeAt_2) { |
| // Verify that compute_at works on the following example: |
| // |
| // for (int py = 0; py < H+1; py++) { |
| // for (int px = 0; px < W+1; px++) { |
| // p[py, px] = py*px |
| // } |
| // } |
| // for (int cy = 0; cy < H; cy++) { |
| // for (int cx = 0; cx < W; cx++) { |
| // c[py, px] = p[cy,cx] + p[cy+1,cx] + |
| // p[cy,cx+1] + p[cy+1,cx+1] |
| // } |
| // } |
| |
| const int kW = 16, kH = 16; |
| VarHandle W("W", kInt); |
| VarHandle H("H", kInt); |
| Tensor p = Compute( |
| "prod", {H + 1, W + 1}, [&](const VarHandle& py, const VarHandle& px) { |
| return px * py; |
| }); |
| Tensor c = |
| Compute("cons", {H, W}, [&](const VarHandle& y, const VarHandle& x) { |
| return p.load(y, x) + p.load(y + 1, x) + p.load(y, x + 1) + |
| p.load(y + 1, x + 1); |
| }); |
| |
| std::vector<int> c_ref(kW * kH, 0); |
| for (int y = 0; y < kH; y++) { |
| for (int x = 0; x < kW; x++) { |
| c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); |
| } |
| } |
| LoopNest orig_loopnest({c}, {p, c}); |
| |
| { |
| // First let's try to compute P at axis cy (the outer loop) |
| LoopNest l(orig_loopnest); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
| LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); |
| l.prepareForCodegen(); |
| SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); |
| StmtPtr s = cg.stmt(); |
| |
| // Check the IR we produced |
| checkIR(s, R"IR( |
| # CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] |
| # CHECK: for (int i_2 = 0; i_2 < H; i_2++) |
| # CHECK: for |
| # CHECK: for |
| # CHECK: for (int i_3 = 0; i_3 < W; i_3++) |
| # CHECK-NOT: prod[ |
| # CHECK: cons[ |
| # CHECK: Free(temp))IR"); |
| |
| // Now check that the loop still produces the correct result. |
| std::vector<int> c_data(kW * kH, 0); |
| cg.call({c_data, kW, kH}); |
| |
| assertAllEqual(c_data, c_ref); |
| } |
| { |
| // Now let's try to compute P at axis cx (the inner loop) |
| LoopNest l(orig_loopnest); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
| LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); |
| l.prepareForCodegen(); |
| SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); |
| StmtPtr s = cg.stmt(); |
| |
| // Check the IR we produced |
| checkIR(s, R"IR( |
| # CHECK: Allocate(temp); // dtype=int, dims=[2, 2] |
| # CHECK: for (int i_2 = 0; i_2 < H; i_2++) |
| # CHECK: for (int i_3 = 0; i_3 < W; i_3++) |
| # CHECK: for |
| # CHECK: for |
| # CHECK-NOT: prod[ |
| # CHECK: cons[ |
| # CHECK: Free(temp))IR"); |
| |
| // Now check that the loop still produces the correct result. |
| std::vector<int> c_data(kW * kH, 0); |
| cg.call({c_data, kW, kH}); |
| |
| assertAllEqual(c_data, c_ref); |
| } |
| } |
| |
| TEST(LoopNest, LoopNestComputeAt_3) { |
| // Verify that compute_at works on the following example: |
| // |
| // A(x,y) = x*y |
| // B(x,y) = A(x, y) |
| // C(x,y) = B(x+1, y) |
| // D(x,y) = A(x, y+1) + C(x, y) |
| // |
| // i.e. when 'A' comes to 'D' directly and indirectly through 'C'. |
| |
| const int kW = 16, kH = 16; |
| VarHandle W("W", kInt); |
| VarHandle H("H", kInt); |
| Tensor A = Compute( |
| "A", {H + 1, W + 1}, [&](const VarHandle& ay, const VarHandle& ax) { |
| return ax * ay; |
| }); |
| Tensor B = Compute( |
| "B", {H + 1, W + 1}, [&](const VarHandle& by, const VarHandle& bx) { |
| return A.load(by, bx); |
| }); |
| Tensor C = |
| Compute("C", {H, W}, [&](const VarHandle& cy, const VarHandle& cx) { |
| return B.load(cy, cx + 1); |
| }); |
| Tensor D = |
| Compute("D", {H, W}, [&](const VarHandle& dy, const VarHandle& dx) { |
| return A.load(dy + 1, dx) + C.load(dy, dx); |
| }); |
| |
| std::vector<int> c_ref(kW * kH, 0); |
| for (int y = 0; y < kH; y++) { |
| for (int x = 0; x < kW; x++) { |
| c_ref[y * kW + x] = (y + 1) * x + y * (x + 1); |
| } |
| } |
| |
| LoopNest orig_loopnest({D}, {A, B, C, D}); |
| { |
| // First let's try to compute A at axis dy (the outer loop) |
| LoopNest l(orig_loopnest); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); |
| LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); |
| l.prepareForCodegen(); |
| SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); |
| StmtPtr s = cg.stmt(); |
| |
| // Check the IR we produced |
| checkIR(s, R"IR( |
| # CHECK: Allocate(temp); // dtype=int, dims=[1, W] |
| # CHECK: for (int i = 0; i < H + 1; i++) |
| # CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) |
| # CHECK: A[ |
| # CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) |
| # CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) |
| # CHECK: B[ |
| # CHECK: for (int i_4 = 0; i_4 < H; i_4++) |
| # CHECK: for (int i_5 = 0; i_5 < W; i_5++) |
| # CHECK: C[ |
| # CHECK: for (int i_6 = 0; i_6 < H; i_6++) |
| # CHECK: for (int i_7 = 0; i_7 < W; i_7++) |
| # CHECK-NOT: A[)IR"); |
| |
| // Now check that the loop still produces the correct result. |
| std::vector<int> c_data(kW * kH, 0); |
| cg.call({c_data, kW, kH}); |
| |
| assertAllEqual(c_data, c_ref); |
| } |
| { |
| // Now let's try to compute A at axis dx (the inner loop) |
| LoopNest l(orig_loopnest); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); |
| LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]); |
| l.prepareForCodegen(); |
| SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); |
| StmtPtr s = cg.stmt(); |
| |
| // Check the IR we produced |
| checkIR(s, R"IR( |
| # CHECK: Allocate(temp); // dtype=int, dims=[1, 1] |
| # CHECK: for (int i = 0; i < H + 1; i++) |
| # CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) |
| # CHECK: A[ |
| # CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) |
| # CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) |
| # CHECK: B[ |
| # CHECK: for (int i_4 = 0; i_4 < H; i_4++) |
| # CHECK: for (int i_5 = 0; i_5 < W; i_5++) |
| # CHECK: C[ |
| # CHECK: for (int i_6 = 0; i_6 < H; i_6++) |
| # CHECK: for (int i_7 = 0; i_7 < W; i_7++) |
| # CHECK-NOT: A[)IR"); |
| |
| // Now check that the loop still produces the correct result. |
| std::vector<int> c_data(kW * kH, 0); |
| cg.call({c_data, kW, kH}); |
| |
| assertAllEqual(c_data, c_ref); |
| } |
| } |
| |
| using Axis = const VarHandle&; |
| |
| TEST(LoopNest, Reduce2dComputeAt) { |
| const int kW = 16, kH = 16; |
| VarHandle W("W", kInt); |
| VarHandle H("H", kInt); |
| |
| Tensor p = Compute( |
| "prod", {H + 1, W + 1}, [&](Axis py, Axis px) { return px * py; }); |
| Tensor c = Reduce( |
| "cons", |
| {H, W}, |
| Sum(), |
| [&](Axis y, Axis x, Axis r, Axis s) { return p.load(y + r, x + s); }, |
| {2, 2}); |
| |
| std::vector<int> c_ref(kW * kH, 0); |
| for (int y = 0; y < kH; y++) { |
| for (int x = 0; x < kW; x++) { |
| c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); |
| } |
| } |
| LoopNest orig_loopnest({c}, {p, c}); |
| checkIR(orig_loopnest.root_stmt(), R"IR( |
| # CHECK: for (int i = 0; i < H + 1; i++) { |
| # CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) { |
| # CHECK: prod[i, i_1] = i_1 * i; |
| # CHECK: } |
| # CHECK: } |
| # CHECK: for (int i_2 = 0; i_2 < H; i_2++) { |
| # CHECK: for (int i_3 = 0; i_3 < W; i_3++) { |
| # CHECK: cons[i_2, i_3] = int(0); |
| # CHECK: for (int i_4 = 0; i_4 < 2; i_4++) { |
| # CHECK: for (int i_5 = 0; i_5 < 2; i_5++) { |
| # CHECK: cons[i_2, i_3] = ReduceOp((cons[i_2, i_3]) + (prod[i_2 + i_4, i_3 + i_5]), reduce_args={i_4, i_5}); |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
| )IR"); |
| |
| { |
| // First let's try to compute P at axis cy (the outer loop) |
| LoopNest l(orig_loopnest); |
| auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
| LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); |
| // FIXME: Calling simplify here breaks the IR: |
| // MALFORMED INPUT: could not find base node in Load - temp[...] |
| // l.simplify(); |
| l.eliminateDeadStores(); |
| l.prepareForCodegen(); |
| SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); |
| checkIR(cg.stmt(), R"IR( |
| # CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] |
| # CHECK: for (int i = 0; i < H; i++) { |
| # CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { |
| # CHECK: for (int idx1 = 0; idx1 < W + 1; idx1++) { |
| # CHECK: temp[(0 + idx0 * (1 * (W + 1))) + idx1 * 1] = (idx0 + i) * (idx1 + 0); |
| # CHECK: } |
| # CHECK: } |
| # CHECK: for (int i_1 = 0; i_1 < W; i_1++) { |
| # CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = int(0); |
| # CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { |
| # CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { |
| # CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * (W + 1))) + (i_1 + i_3) * 1]); |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
| # CHECK: Free(temp); |
| )IR"); |
| |
| // Now check that the loop still produces the correct result. |
| std::vector<int> c_data(kW * kH, 0); |
| cg.call({c_data, kW, kH}); |
| assertAllEqual(c_data, c_ref); |
| } |
| { |
| // Now let's try to compute P at axis cx (the inner loop) |
| LoopNest l(orig_loopnest); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
| LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); |
| l.simplify(); |
| l.eliminateDeadStores(); |
| l.prepareForCodegen(); |
| SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); |
| checkIR(cg.stmt(), R"IR( |
| # CHECK: Allocate(temp); // dtype=int, dims=[2, 2] |
| # CHECK: for (int i = 0; i < H; i++) { |
| # CHECK: for (int i_1 = 0; i_1 < W; i_1++) { |
| # CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { |
| # CHECK: for (int idx1 = 0; idx1 < 2; idx1++) { |
| # CHECK: temp[(0 + idx0 * (1 * 2)) + idx1 * 1] = (i + idx0) * (i_1 + idx1); |
| # CHECK: } |
| # CHECK: } |
| # CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = 0; |
| # CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { |
| # CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { |
| # CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * 2)) + i_3 * 1]); |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
| # CHECK: Free(temp); |
| )IR"); |
| |
| // Now check that the loop still produces the correct result. |
| std::vector<int> c_data(kW * kH, 0); |
| cg.call({c_data, kW, kH}); |
| assertAllEqual(c_data, c_ref); |
| } |
| } |
| |
| TEST(LoopNest, DISABLED_Conv1d_NH) { |
| // Lots of stuff is broken here. The computeAt swaps the axes for some odd |
| // reason. Even without that, the index flattener fails due to "dimensions |
| // mismatch in flatten index". |
| |
| int N = 4; |
| int H = 256; |
| int R = 3; |
| int Pad = 1; |
| BufHandle IP("input", {H}, kFloat); |
| |
| Tensor A = Compute("A", {N, H + 2 * Pad}, [&](Axis n, Axis h) { |
| auto cond = CompareSelect::make(h, Pad, 1, 0, kLT); |
| cond = CompareSelect::make(h, H + Pad, 1, cond, kGE); |
| return ifThenElse(cond, 0.f, IP.load(n, h - Pad)); |
| }); |
| Tensor B = Reduce( |
| "B", |
| {N, H}, |
| Sum(), |
| [&](Axis n, Axis h, Axis r) { return A.load(n, h + r); }, |
| {R}); |
| LoopNest l({B}); |
| checkIR(l.root_stmt(), R"IR( |
| # CHECK: for (int np = 0; np < 4; np++) { |
| # CHECK: for (int hp = 0; hp < 258; hp++) { |
| # CHECK: A[np, hp] = IfThenElse(hp>=257 ? 1 : (hp<1 ? 1 : 0), 0.f, input[np, hp - 1]); |
| # CHECK: } |
| # CHECK: } |
| # CHECK: for (int n = 0; n < 4; n++) { |
| # CHECK: for (int h = 0; h < 256; h++) { |
| # CHECK: B[n, h] = float(0); |
| # CHECK: for (int r = 0; r < 3; r++) { |
| # CHECK: B[n, h] = ReduceOp((B[n, h]) + (A(n, h + r)), reduce_args={r}); |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
| )IR"); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); |
| LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); |
| // FIXME: The current IR is totally broken. The body of the inlined loop is: |
| |
| // temp[idx0, idx1] = IfThenElse(idx0 + n>=257 ? 1 : (idx0 + n<1 ? 1 : 0), |
| // 0.f, input[idx1 + 0, (idx0 + n) - 1]); |
| |
| // Which seems to mix up the axes. The CHECK below is my best guess at what |
| // the input "should" look like |
| |
| checkIR(l.root_stmt(), R"IR( |
| # CHECK: for (int n = 0; n < 4; n++) { |
| # CHECK: for (int idx0 = 0; idx0 < 1; idx0++) { |
| # CHECK: for (int idx1 = 0; idx1 < 258; idx1++) { |
| temp[idx0, idx1] = IfThenElse(idx1>=257 ? 1 : (idx1<1 ? 1 : 0), 0.f, input[n, idx1 - 1]); |
| # CHECK: } |
| # CHECK: } |
| # CHECK: for (int h = 0; h < 256; h++) { |
| # CHECK: B[n, h] = float(0); |
| # CHECK: for (int r = 0; r < 3; r++) { |
| # CHECK: B[n, h] = ReduceOp((B[n, h]) + (temp[0, r + h]), reduce_args={r}); |
| # CHECK: } |
| # CHECK: } |
| # CHECK: } |
| )IR"); |
| |
| l.simplify(); |
| l.prepareForCodegen(); |
| StmtPtr s = l.root_stmt(); |
| |
| SimpleIREvaluator cg(s, {IP, B}); |
| // auto At = at::ones({N, H}, at::kFloat); |
| auto At = at::arange(N * H, at::kFloat).reshape({N, H}); |
| auto Rt = at::conv1d( |
| At, at::ones({1, 1, 3}), at::Tensor(), /*stride=*/1, /*padding=*/3); |
| auto Bt = at::empty_like(Rt); |
| cg.call({At.data_ptr<float>(), Bt.data_ptr<float>()}); |
| ASSERT_TRUE(at::allclose(Rt, Bt)); |
| } |
| |
| class LoopOrderHelper : public IRVisitor { |
| std::stringstream ordering; |
| |
| public: |
| std::string getOrder(StmtPtr s) { |
| ordering.str(""); |
| s->accept(this); |
| return ordering.str(); |
| } |
| |
| void visit(const ForPtr& v) final { |
| ordering << v->var()->name_hint() << ","; |
| IRVisitor::visit(v); |
| } |
| }; |
| |
| TEST(LoopNest, LoopNestReorderAxis1) { |
| Tensor tensor = |
| Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) { |
| return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
| }); |
| LoopNest l({tensor}); |
| StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
| |
| std::vector<int> stmt1_output(6, 0); |
| SimpleIREvaluator cg(stmt1, {tensor}); |
| cg.call({stmt1_output}); |
| |
| auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::reorderAxis(loops[0], loops[1]); |
| StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
| |
| ASSERT_NE(stmt1, stmt2); |
| LoopOrderHelper loopOrderHelper; |
| std::string order1 = loopOrderHelper.getOrder(stmt1); |
| std::string order2 = loopOrderHelper.getOrder(stmt2); |
| |
| ASSERT_EQ(order1, "j,i,"); |
| ASSERT_EQ(order2, "i,j,"); |
| |
| std::vector<int> stmt2_output(6, 0); |
| SimpleIREvaluator cg2(stmt2, {tensor}); |
| cg.call({stmt2_output}); |
| |
| for (int i = 0; i < 6; ++i) { |
| ASSERT_EQ(stmt1_output[i], stmt2_output[i]); |
| } |
| |
| // Reorder them back. |
| loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::reorderAxis(loops[0], loops[1]); |
| StmtPtr stmt3 = l.root_stmt(); |
| |
| std::string order3 = loopOrderHelper.getOrder(stmt3); |
| ASSERT_EQ(order3, order1); |
| |
| std::ostringstream oss1, oss2; |
| oss1 << *stmt1; |
| oss2 << *stmt3; |
| |
| // Should be identical to the unreordered statement. |
| ASSERT_EQ(oss1.str(), oss2.str()); |
| } |
| |
| TEST(LoopNest, LoopNestReorderPartialAxes) { |
| Tensor tensor = Compute( |
| "f", |
| {2, 3, 4}, |
| [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { |
| return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y + |
| cast<float>(z) * z; |
| }); |
| LoopNest l({tensor}); |
| |
| LoopOrderHelper loopOrderHelper; |
| StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
| ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,"); |
| |
| std::vector<int> stmt1_output(24, 0); |
| SimpleIREvaluator cg(stmt1, {tensor}); |
| cg.call({stmt1_output}); |
| |
| auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::reorderAxis(loops[0], loops[1]); |
| ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,i,k,"); |
| |
| StmtPtr stmt2 = Stmt::clone(l.root_stmt()); |
| |
| std::vector<int> stmt2_output(24, 0); |
| SimpleIREvaluator cg2(stmt2, {tensor}); |
| cg2.call({stmt2_output}); |
| |
| for (int i = 0; i < 24; ++i) { |
| ASSERT_EQ(stmt1_output[i], stmt2_output[i]); |
| } |
| |
| loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::reorderAxis(loops[1], loops[2]); |
| ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,k,i,"); |
| |
| StmtPtr stmt3 = Stmt::clone(l.root_stmt()); |
| |
| std::vector<int> stmt3_output(24, 0); |
| SimpleIREvaluator cg3(stmt3, {tensor}); |
| cg3.call({stmt3_output}); |
| |
| for (int i = 0; i < 24; ++i) { |
| ASSERT_EQ(stmt1_output[i], stmt3_output[i]); |
| } |
| } |
| |
| TEST(LoopNest, LoopNestReorderInternalAxis) { |
| Tensor tensor = Compute( |
| "f", |
| {1, 2, 3, 4}, |
| [](const VarHandle& w, |
| const VarHandle& x, |
| const VarHandle& y, |
| const VarHandle& z) { |
| return ExprHandle(1.0f) + w + cast<float>(x) * x + cast<float>(y) * y + |
| cast<float>(z) * z; |
| }); |
| LoopNest l({tensor}); |
| |
| LoopOrderHelper loopOrderHelper; |
| StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
| ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,l,"); |
| |
| std::vector<int> stmt1_output(24, 0); |
| SimpleIREvaluator cg(stmt1, {tensor}); |
| cg.call({stmt1_output}); |
| |
| auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::reorderAxis(loops[2], loops[1]); |
| ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "i,k,j,l,"); |
| |
| StmtPtr stmt2 = l.root_stmt(); |
| |
| std::vector<int> stmt2_output(24, 0); |
| SimpleIREvaluator cg2(stmt2, {tensor}); |
| cg2.call({stmt2_output}); |
| |
| for (int i = 0; i < 24; ++i) { |
| ASSERT_EQ(stmt1_output[i], stmt2_output[i]); |
| } |
| } |
| |
| TEST(LoopNest, LoopNestReorderEnclosingAxis) { |
| Tensor tensor = Compute( |
| "f", |
| {1, 2, 3, 4}, |
| [](const VarHandle& w, |
| const VarHandle& x, |
| const VarHandle& y, |
| const VarHandle& z) { |
| return ExprHandle(1.0f) + w + cast<float>(x) * x + cast<float>(y) * y + |
| cast<float>(z) * z; |
| }); |
| LoopNest l({tensor}); |
| |
| LoopOrderHelper loopOrderHelper; |
| StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
| |
| std::vector<int> stmt1_output(24, 0); |
| SimpleIREvaluator cg(stmt1, {tensor}); |
| cg.call({stmt1_output}); |
| |
| auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::reorderAxis(loops[0], loops[3]); |
| ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "l,j,k,i,"); |
| |
| StmtPtr stmt2 = l.root_stmt(); |
| |
| std::vector<int> stmt2_output(24, 0); |
| SimpleIREvaluator cg2(stmt2, {tensor}); |
| cg2.call({stmt2_output}); |
| |
| for (int i = 0; i < 24; ++i) { |
| ASSERT_EQ(stmt1_output[i], stmt2_output[i]); |
| } |
| } |
| |
| TEST(LoopNest, LoopNestReorderSameAxis) { |
| Tensor tensor = |
| Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) { |
| return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
| }); |
| LoopNest l({tensor}); |
| StmtPtr stmt1 = Stmt::clone(l.root_stmt()); |
| |
| auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::reorderAxis(loops[1], loops[1]); |
| StmtPtr stmt2 = Stmt::clone(l.root_stmt()); |
| |
| std::ostringstream oss, oss2; |
| oss << *stmt1; |
| oss2 << *stmt2; |
| ASSERT_EQ(oss.str(), oss2.str()); |
| } |
| |
| TEST(LoopNest, LoopNestReorderExtraStatements) { |
| /* We're going for a structure like this: |
| * for i in ... |
| * Stmt 1 |
| * for j in ... |
| * Stmt 2 |
| * for k in ... |
| * Stmt 3 |
| * Stmt 4 |
| */ |
| |
| Tensor tensor = Compute( |
| "f", |
| {2, 3, 4}, |
| [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { |
| return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y + |
| cast<float>(z) * z; |
| }); |
| LoopNest l({tensor}); |
| |
| BufHandle extra("res", {6, 3}, kFloat); |
| |
| auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| |
| VarHandle i = VarHandle(loops[0]->var()); |
| |
| StmtPtr store_1 = Store::make(extra, {i, 0}, 1.f); |
| StmtPtr store_2 = Store::make(extra, {i, 1}, 2.f); |
| // stmt 3 is the Function body. |
| StmtPtr store_3 = Store::make(extra, {i, 2}, 4.f); |
| |
| loops[0]->body()->prepend_stmt(store_1); |
| loops[1]->body()->prepend_stmt(store_2); |
| loops[1]->body()->append_stmt(store_3); |
| StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
| |
| std::vector<int> extra1(6, 0); |
| std::vector<int> res1(24, 0); |
| SimpleIREvaluator cg(stmt1, {tensor, extra}); |
| cg.call({res1, extra1}); |
| |
| /* Then we reorder loop y and z, we want it to look like: |
| * |
| * for i in ... |
| * Stmt 1 |
| * for j in ... |
| * Stmt 2 |
| * for j_1 in ... |
| * for k in ... |
| * Stmt 3 |
| * for j_2 in ... |
| * Stmt 4 |
| * |
| * We need extra loops because we don't have dependency info about stmt 3 |
| * and 4. |
| * |
| */ |
| |
| LoopNest::reorderAxis(loops[1], loops[2]); |
| StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
| |
| // Check the IR we produced |
| checkIR(stmt2, R"IR( |
| # CHECK: for |
| # CHECK: res[i, 0] = 1 |
| # CHECK: for |
| # CHECK: res[i, 1] = 2 |
| # CHECK: for |
| # CHECK: for |
| # CHECK: f[ |
| # CHECK: for |
| # CHECK: res[i, 2] = 4 |
| )IR"); |
| |
| std::vector<int> extra2(6, 0); |
| std::vector<int> res2(24, 0); |
| SimpleIREvaluator cg2(stmt2, {tensor, extra}); |
| cg2.call({res2, extra2}); |
| |
| for (int i = 0; i < 24; ++i) { |
| ASSERT_EQ(res1[i], res2[i]); |
| } |
| for (int i = 0; i < 6; ++i) { |
| ASSERT_EQ(extra1[i], extra2[i]); |
| } |
| |
| /* Now reorder x and the y above stmt 3: |
| * |
| * |
| * for x in ... |
| * Stmt 1 |
| * for y in ... |
| * Stmt 2 |
| * |
| * for y in ... |
| * for z in ... |
| * for x in ... |
| * Stmt 3 |
| * |
| * for x in ... |
| * for y in ... |
| * Stmt 4 |
| * |
| * |
| */ |
| loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); |
| LoopNest::reorderAxis(loops[0], loops[2]); |
| StmtPtr stmt3 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); |
| |
| // Check the IR we produced |
| checkIR(stmt3, R"IR( |
| # CHECK: for |
| # CHECK: res[i, 0] = 1 |
| # CHECK: for |
| # CHECK: res[i, 1] = 2 |
| # CHECK: for |
| # CHECK: for |
| # CHECK: for |
| # CHECK: f[ |
| # CHECK: for |
| # CHECK: for |
| # CHECK: res[i_2, 2] = 4 |
| )IR"); |
| |
| std::vector<int> extra3(6, 0); |
| std::vector<int> res3(24, 0); |
| SimpleIREvaluator cg3(stmt3, {tensor, extra}); |
| cg3.call({res3, extra3}); |
| |
| for (int i = 0; i < 24; ++i) { |
| ASSERT_EQ(res1[i], res3[i]); |
| } |
| for (int i = 0; i < 6; ++i) { |
| ASSERT_EQ(extra1[i], extra3[i]); |
| } |
| } |
| |
| void LoopNestReorderTestHelper( |
| bool prepend, |
| bool append, |
| int index1, |
| int index2) { |
| Tensor c = Compute( |
| "5d", {2, 3, 2, 3, 2}, [](const std::vector<VarHandle>&) { return -1; }); |
| LoopNest l({c}); |
| |
| BufHandle extra("extra", {5}, kInt); |
| |
| auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
| int j = 0; |
| for (auto l : loops) { |
| // Add an increment at each layer of the loop which counts the number of |
| // times the loop executes. |
| LoadPtr load = |
| alloc<Load>(extra.node(), std::vector<ExprPtr>({alloc<IntImm>(j)})); |
| AddPtr add = alloc<Add>(load, alloc<IntImm>(1)); |
| StmtPtr store = alloc<Store>( |
| extra.node(), std::vector<ExprPtr>({alloc<IntImm>(j)}), add); |
| if (prepend) { |
| l->body()->prepend_stmt(store); |
| } |
| if (append) { |
| l->body()->append_stmt(Stmt::clone(store)); |
| } |
| |
| j++; |
| } |
| |
| StmtPtr stmt1 = Stmt::clone(l.root_stmt()); |
| |
| std::vector<int> extra1(5, 0); |
| std::vector<int> res1(2 * 3 * 2 * 3 * 2, 0); |
| SimpleIREvaluator cg(stmt1, {c, extra}); |
| cg.call({res1, extra1}); |
| |
| std::vector<int> loopExtents = {2, 3, 2, 3, 2}; |
| |
| int expected_loops = 0; |
| if (prepend) { |
| expected_loops++; |
| } |
| if (append) { |
| expected_loops++; |
| } |
| for (int i = 0; i < 5; ++i) { |
| expected_loops *= loopExtents[i]; |
| ASSERT_EQ(extra1[i], expected_loops); |
| } |
| |
| loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); |
| LoopNest::reorderAxis(loops[index1], loops[index2]); |
| StmtPtr stmt2 = Stmt::clone(l.root_stmt()); |
| |
| std::ostringstream oss, oss2; |
| oss << *stmt1; |
| oss2 << *stmt2; |
| ASSERT_NE(oss.str(), oss2.str()); |
| |
| std::vector<int> extra2(5, 0); |
| std::vector<int> res2(2 * 3 * 2 * 3 * 2, 0); |
| SimpleIREvaluator cg2(stmt2, {c, extra}); |
| cg2.call({res2, extra2}); |
| |
| expected_loops = 0; |
| if (prepend) { |
| expected_loops++; |
| } |
| if (append) { |
| expected_loops++; |
| } |
| |
| for (int i = 0; i < 5; ++i) { |
| expected_loops *= loopExtents[i]; |
| ASSERT_EQ(extra2[i], expected_loops); |
| } |
| |
| for (int i = 0; i < 2 * 3 * 2 * 3 * 2; ++i) { |
| ASSERT_EQ(res2[i], res1[i]); |
| } |
| } |
| |
| TEST(LoopNest, LoopNestReorderLongStringOfPreOrphans) { |
| for (int i = 0; i < 5; ++i) { |
| for (int j = 0; j < 5; ++j) { |
| // skip noops, since we check the loop isn't the same after reordering. |
| if (i != j) { |
| LoopNestReorderTestHelper(true, false, i, j); |
| } |
| } |
| } |
| } |
| |
| TEST(LoopNest, LoopNestReorderLongStringOfPostOrphans) { |
| for (int i = 0; i < 5; ++i) { |
| for (int j = 0; j < 5; ++j) { |
| // skip noops, since we check the loop isn't the same after reordering. |
| if (i != j) { |
| LoopNestReorderTestHelper(false, true, i, j); |
| } |
| } |
| } |
| } |
| |
| TEST(LoopNest, LoopNestReorderLongStringFull) { |
| for (int i = 0; i < 5; ++i) { |
| for (int j = 0; j < 5; ++j) { |
| // skip noops, since we check the loop isn't the same after reordering. |
| if (i != j) { |
| LoopNestReorderTestHelper(true, true, i, j); |
| } |
| } |
| } |
| } |
| |
| TEST(LoopNest, LoopNestReorderInternalLoopNest) { |
| const int M = 4; |
| const int N = 5; |
| const int K = 6; |
| BufHandle a_buf("a", {M, N}, kFloat); |
| BufHandle b_buf("b", {N, K}, kFloat); |
| BufHandle c_buf("c", {M, N}, kFloat); |
| BufHandle d_buf("d", {M, K}, kFloat); |
| |
| Tensor x = Compute( |
| "x", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return a_buf.load(m, n) * b_buf.load(n, k); |
| }); |
| Tensor y = Compute( |
| "y", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); |
| }); |
| Tensor z = Compute( |
| "z", |
| {M, N, K}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return x.load(m, n, k) + y.load(m, n, k); |
| }); |
| |
| LoopNest l({z}, {x, y, z}); |
| ForPtr a = l.getAllLoopNestsWritingToBuf(y.buf())[0][2]; |
| ForPtr b = l.getAllLoopNestsWritingToBuf(y.buf())[0][0]; |
| LoopNest::reorderAxis(a, b); |
| |
| l.prepareForCodegen(); |
| StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); |
| |
| // Check the IR we produced has the 3 nests in the right order, but k and m |
| // swapped in the middle. |
| checkIR(stmt, R"IR( |
| # CHECK: < 4 |
| # CHECK: < 5 |
| # CHECK: < 6 |
| # CHECK: < 6 |
| # CHECK: < 5 |
| # CHECK: < 4 |
| # CHECK: < 4 |
| # CHECK: < 5 |
| # CHECK: < 6)IR"); |
| |
| { |
| PaddedBuffer<float> a_v(M, N); |
| PaddedBuffer<float> b_v(N, K); |
| PaddedBuffer<float> c_v(M, N); |
| PaddedBuffer<float> d_v(M, K); |
| |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < N; j++) { |
| a_v(i, j) = i * i; |
| } |
| } |
| for (int i = 0; i < N; i++) { |
| for (int j = 0; j < K; j++) { |
| b_v(i, j) = j * j; |
| } |
| } |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < N; j++) { |
| c_v(i, j) = i + j; |
| } |
| } |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < K; j++) { |
| d_v(i, j) = i * j; |
| } |
| } |
| |
| PaddedBuffer<float> z_v(M, N, K); |
| PaddedBuffer<float> z_ref(M, N, K); |
| for (int m = 0; m < M; m++) { |
| for (int n = 0; n < N; n++) { |
| for (int k = 0; k < K; k++) { |
| z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); |
| } |
| } |
| } |
| |
| SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); |
| eval(a_v, b_v, c_v, d_v, z_v); |
| ExpectAllNear(z_v, z_ref, 1e-5); |
| } |
| } |
| |
| TEST(LoopNest, OuterLoopVectorization) { |
| Tensor tensor = |
| Compute("f", {8, 8}, [](const VarHandle& x, const VarHandle& y) { |
| return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y; |
| }); |
| LoopNest l({tensor}); |
| |
| ASSERT_TRUE( |
| LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor.buf())[0][0])); |
| |
| StmtPtr root_stmt = l.root_stmt(); |
| BlockPtr outer_block = to<Block>(root_stmt); |
| ASSERT_NE(outer_block, nullptr); |
| while (BlockPtr inner_block = to<Block>(outer_block->front())) { |
| outer_block = inner_block; |
| } |
| |
| // Verify that we have only a single loop level remaining after |
| // vectorization. |
| ASSERT_EQ(outer_block->nstmts(), 1); |
| ForPtr for_loop = to<For>(outer_block->front()); |
| ASSERT_NE(for_loop, nullptr); |
| BlockPtr for_body = for_loop->body(); |
| ASSERT_EQ(for_body->nstmts(), 1); |
| ASSERT_EQ(to<For>(for_body->front()), nullptr); |
| } |
| |
| TEST(LoopNest, VectorizeLoopNotNormalized) { |
| // Input IR: |
| // for (int i = 0; i < 10; i++) { |
| // for (int j = 1; j < 5; j++) { |
| // A[i,j] = i * j; |
| // } |
| // } |
| BufHandle a_buf("A", {10, 5}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
| auto inner_for = For::make(j, 1, 5, for_body); |
| auto outer_for = For::make(i, 0, 10, inner_for); |
| auto block = Block::make({outer_for}); |
| LoopNest l(block, {a_buf.node()}); |
| |
| ASSERT_TRUE(LoopNest::vectorize(inner_for)); |
| ASSERT_EQ(outer_for->body()->nstmts(), 1); |
| ASSERT_EQ(to<For>(outer_for->body()->front()), nullptr); |
| } |
| |
| namespace { |
| |
| std::string constantUpperBoundLoopIR(int upper_bound_val) { |
| ExprHandle upper_bound(upper_bound_val); |
| Tensor A = |
| Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; }); |
| LoopNest l({A}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; |
| StmtPtr unrolled = nullptr; |
| LoopNest::fullUnroll(loops[0], &unrolled); |
| std::ostringstream oss; |
| oss << *unrolled; |
| return oss.str(); |
| } |
| |
| } // namespace |
| |
| TEST(LoopNest, Unroll) { |
| const std::string actual = constantUpperBoundLoopIR(3); |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: A[0] = 0; |
| # CHECK: A[1] = 2; |
| # CHECK: A[2] = 4)IR"; |
| |
| torch::jit::testing::FileCheck().run(verification_pattern, actual); |
| } |
| |
| TEST(LoopNest, UnrollOuter) { |
| ExprHandle outer_bound(3); |
| ExprHandle inner_bound(4); |
| Tensor A = Compute( |
| "A", |
| {outer_bound, inner_bound}, |
| [&](const VarHandle& x, const VarHandle& y) { return x + y; }); |
| LoopNest l({A}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; |
| StmtPtr unrolled = nullptr; |
| LoopNest::fullUnroll(loops[0], &unrolled); |
| checkIR(unrolled, R"IR( |
| # CHECK: for (int i = 0; i < 4; i++) { |
| # CHECK: A[0, i] = i; |
| # CHECK: } |
| # CHECK: for (int i = 0; i < 4; i++) { |
| # CHECK: A[1, i] = i + 1; |
| # CHECK: } |
| # CHECK: for (int i = 0; i < 4; i++) { |
| # CHECK: A[2, i] = i + 2; |
| # CHECK: })IR"); |
| } |
| |
| TEST(LoopNest, UnrollInner) { |
| ExprHandle outer_bound(3); |
| ExprHandle inner_bound(4); |
| Tensor A = Compute( |
| "A", |
| {outer_bound, inner_bound}, |
| [&](const VarHandle& x, const VarHandle& y) { return x + y; }); |
| LoopNest l({A}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; |
| StmtPtr unrolled = nullptr; |
| LoopNest::fullUnroll( |
| static_to<For>(loops[0]->body()->stmts().front()), &unrolled); |
| checkIR(loops[0], R"IR( |
| # CHECK: for (int i = 0; i < 3; i++) { |
| # CHECK: A[i, 0] = i; |
| # CHECK: A[i, 1] = i + 1; |
| # CHECK: A[i, 2] = i + 2; |
| # CHECK: A[i, 3] = i + 3; |
| # CHECK: })IR"); |
| } |
| |
| TEST(LoopNest, UnrollMultipleStatements) { |
| const int kTotalSize = 3; |
| BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); |
| BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); |
| |
| VarHandle x("x", kInt); |
| auto f = For::make( |
| x, |
| 0, |
| kTotalSize, |
| Block::make( |
| {Store::make(a_buf, {x}, x * 2), |
| Store::make(b_buf, {x}, Load::make(a_buf, {x}))})); |
| auto parent_block = Block::make({f}); |
| StmtPtr unrolled = nullptr; |
| LoopNest::fullUnroll(f, &unrolled); |
| checkIR(unrolled, R"IR( |
| # CHECK: A[0] = 0; |
| # CHECK: B[0] = A[0]; |
| # CHECK: A[1] = 2; |
| # CHECK: B[1] = A[1]; |
| # CHECK: A[2] = 4 |
| # CHECK: B[2] = A[2];)IR"); |
| } |
| |
| TEST(LoopNest, UnrollNonLiteralConstantBounds) { |
| // Input IR: |
| // for (int i = 2 - 1; i < 12 / 3; i++) { |
| // for (int j = 0; j < 4; j++) { |
| // A[i,j] = i * j; |
| // } |
| // } |
| BufHandle a_buf("A", {3, 4}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
| auto inner_for = For::make(j, 0, 4, for_body); |
| auto outer_for = For::make( |
| i, |
| IntImm::make(2) - IntImm::make(1), |
| IntImm::make(12) / IntImm::make(3), |
| inner_for); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto b = Block::make({outer_for}); |
| |
| std::vector<ForPtr> loops = {outer_for, inner_for}; |
| StmtPtr unrolled = nullptr; |
| LoopNest::fullUnroll(loops[0], &unrolled); |
| checkIR(unrolled, R"IR( |
| # CHECK: for (int j = 0; j < 4; j++) { |
| # CHECK: A[1, j] = j; |
| # CHECK: } |
| # CHECK: for (int j = 0; j < 4; j++) { |
| # CHECK: A[2, j] = 2 * j; |
| # CHECK: } |
| # CHECK: for (int j = 0; j < 4; j++) { |
| # CHECK: A[3, j] = 3 * j; |
| # CHECK: })IR"); |
| } |
| |
| TEST(LoopNest, UnrollNonConstantBounds) { |
| // Input IR: |
| // for (int i = 0; i < M; i++) { |
| // for (int j = 0; j < N; j++) { |
| // A[i, j] = i * j; |
| // } |
| // } |
| VarHandle M("M", kInt); |
| VarHandle N("N", kInt); |
| BufHandle a_buf("A", {M, N}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
| auto inner_for = For::make(j, 0, N, for_body); |
| auto outer_for = For::make(i, 0, M, inner_for); |
| auto block = Block::make({outer_for}); |
| LoopNest l(block, {a_buf.node()}); |
| |
| LoopNest::unroll(inner_for, 8); |
| l.simplify(); |
| checkIR(l.root_stmt(), R"IR( |
| # CHECK: for (int i = 0; i < M; i++) { |
| # CHECK: for (int j_outer = 0; j_outer < N / 8; j_outer++) { |
| # CHECK: A[i, 8 * j_outer] = |
| # CHECK: A[i, 8 * j_outer + 1] = |
| # CHECK: A[i, 2 * (4 * j_outer + 1)] = |
| # CHECK: A[i, 8 * j_outer + 3] = |
| # CHECK: A[i, 4 * (2 * j_outer + 1)] = |
| # CHECK: A[i, 8 * j_outer + 5] = |
| # CHECK: A[i, 8 * j_outer + 6] = |
| # CHECK: A[i, 8 * j_outer + 7] = |
| # CHECK: } |
| # CHECK: for (int j_tail = 0; j_tail < N % 8; j_tail++) { |
| # CHECK: A[i, 8 * (N / 8) + j_tail] = |
| # CHECK: } |
| # CHECK: } |
| )IR"); |
| } |
| |
| TEST(LoopNest, UnrollByFactorsLessThan2) { |
| // Input IR: |
| // for (int i = 0; i < M; i++) { |
| // for (int j = 0; j < N; j++) { |
| // A[i, j] = i * j; |
| // } |
| // } |
| VarHandle M("M", kInt); |
| VarHandle N("N", kInt); |
| BufHandle a_buf("A", {M, N}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
| auto inner_for = For::make(j, 0, N, for_body); |
| auto outer_for = For::make(i, 0, M, inner_for); |
| auto block = Block::make({outer_for}); |
| LoopNest l(block, {a_buf.node()}); |
| |
| // Unrolling by factor = 1 should do nothing. |
| LoopNest::unroll(inner_for, 1); |
| checkIR(l.root_stmt(), R"IR( |
| # CHECK: for (int i = 0; i < M; i++) { |
| # CHECK: for (int j = 0; j < N; j++) { |
| # CHECK: A[i, j] = |
| # CHECK: } |
| # CHECK: } |
| )IR"); |
| |
| // Unrolling by factor = 0 should do nothing. |
| LoopNest::unroll(inner_for, 0); |
| checkIR(l.root_stmt(), R"IR( |
| # CHECK: for (int i = 0; i < M; i++) { |
| # CHECK: for (int j = 0; j < N; j++) { |
| # CHECK: A[i, j] = |
| # CHECK: } |
| # CHECK: } |
| )IR"); |
| |
| // Unrolling by negative factor should do nothing. |
| LoopNest::unroll(inner_for, -2); |
| checkIR(l.root_stmt(), R"IR( |
| # CHECK: for (int i = 0; i < M; i++) { |
| # CHECK: for (int j = 0; j < N; j++) { |
| # CHECK: A[i, j] = |
| # CHECK: } |
| # CHECK: } |
| )IR"); |
| } |
| |
| TEST(LoopNest, UnrollByFactorEqualToIters) { |
| // Input IR: |
| // for (int i = 0; i < 5; i++) { |
| // A[i] = i * i; |
| // } |
| BufHandle a_buf("A", {5}, kInt); |
| VarHandle i("i", kInt); |
| auto for_body = Block::make({Store::make(a_buf, {i}, i * i)}); |
| auto for_loop = For::make(i, 0, 5, for_body); |
| auto block = Block::make({for_loop}); |
| LoopNest l(block, {a_buf.node()}); |
| |
| LoopNest::unroll(for_loop, 5); |
| checkIR(l.root_stmt(), R"IR( |
| # CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++) |
| # CHECK: A[5 * i_outer] |
| # CHECK: A[5 * i_outer + 1] |
| # CHECK: A[5 * i_outer + 2] |
| # CHECK: A[5 * i_outer + 3] |
| # CHECK: A[5 * i_outer + 4] |
| )IR"); |
| } |
| |
| TEST(LoopNest, UnrollEmpty) { |
| const std::string actual = constantUpperBoundLoopIR(0); |
| const std::string& verification_pattern = R"IR( |
| # CHECK-NOT: A[ |
| )IR"; |
| |
| torch::jit::testing::FileCheck().run(verification_pattern, actual); |
| } |
| |
| TEST(LoopNest, NoUnroll) { |
| VarHandle upper_bound("N", kInt); |
| Tensor A = |
| Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; }); |
| LoopNest l({A}); |
| std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; |
| StmtPtr unrolled = nullptr; |
| ASSERT_THROWS_WITH( |
| LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop"); |
| } |
| |
| TEST(LoopNest, UnrollWithLet) { |
| const int kTotalSize = 3; |
| BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); |
| BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); |
| |
| VarHandle e("e", kInt); |
| VarHandle x("x", kInt); |
| auto f = For::make( |
| x, |
| 0, |
| kTotalSize, |
| Block::make( |
| {Let::make(e, 7), |
| Store::make(a_buf, {x}, e), |
| Store::make(b_buf, {x}, e + 1)})); |
| auto parent_block = Block::make({f}); |
| StmtPtr unrolled = nullptr; |
| LoopNest::fullUnroll(f, &unrolled); |
| std::ostringstream oss; |
| oss << *unrolled; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: int e = 7; |
| # CHECK: A[0] = e; |
| # CHECK: B[0] = e + 1; |
| # CHECK: A[1] = e; |
| # CHECK: B[1] = e + 1; |
| # CHECK: A[2] = e; |
| # CHECK: B[2] = e + 1;)IR"; |
| |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| std::vector<int> a_v(kTotalSize, 0); |
| std::vector<int> b_v(kTotalSize, 0); |
| SimpleIREvaluator eval(unrolled, {a_buf, b_buf}); |
| eval(a_v, b_v); |
| for (int i = 0; i < kTotalSize; ++i) { |
| ASSERT_EQ(a_v[i], 7); |
| ASSERT_EQ(b_v[i], 8); |
| } |
| } |
| |
| TEST(LoopNest, IsNormalized) { |
| // Input IR: |
| // for (int i = 50; i < 100; i++) { |
| // A[i] = B[i]; |
| // } |
| BufHandle a_buf("A", {ExprHandle(100)}, kInt); |
| BufHandle b_buf("B", {ExprHandle(100)}, kInt); |
| VarHandle i("i", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto for_stmt = |
| For::make(i, 50, 100, Store::make(a_buf, {i}, Load::make(b_buf, {i}))); |
| Block::make({for_stmt}); |
| ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); |
| |
| for_stmt->set_start(alloc<IntImm>(0)); |
| ASSERT_TRUE(LoopNest::isNormalized(for_stmt)); |
| |
| VarHandle N("N", kInt); |
| for_stmt->set_start(N.node()); |
| ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); |
| } |
| |
| TEST(LoopNest, NormalizeStartPositive) { |
| // Input IR: |
| // for (int x = 50; x < 100; x++) { |
| // A[x] = B[x]; |
| // B[x] = x * 2; |
| // } |
| const int kTotalSize = 50; |
| BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); |
| BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); |
| VarHandle x("x", kInt); |
| auto for_body = Block::make( |
| {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), |
| Store::make(b_buf, {x}, x * 2)}); |
| auto for_stmt = For::make(x, 50, 100, for_body); |
| Block::make({for_stmt}); |
| |
| LoopNest::normalize(for_stmt); |
| |
| auto result = IRSimplifier::simplify(for_stmt); |
| std::ostringstream oss; |
| oss << *result; |
| const std::string& expected_ir = |
| R"IR( |
| # CHECK: for (int x = 0; x < 50; x++) { |
| # CHECK: A[x + 50] = B[x + 50]; |
| # CHECK: B[x + 50] = 2 * (x + 50); |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
| } |
| |
| TEST(LoopNest, NormalizeStartNegative) { |
| // Input IR: |
| // for (int x = -50; x < 100; x++) { |
| // A[x + 50] = B[x + 50]; |
| // B[x + 50] = x * 2; |
| // } |
| const int kTotalSize = 150; |
| BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); |
| BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); |
| VarHandle x("x", kInt); |
| auto for_body = Block::make( |
| {Store::make(a_buf, {x + 50}, Load::make(kInt, b_buf, {x + 50})), |
| Store::make(b_buf, {x + 50}, x * 2)}); |
| auto for_stmt = For::make(x, -50, 100, for_body); |
| Block::make({for_stmt}); |
| |
| LoopNest::normalize(for_stmt); |
| |
| auto result = IRSimplifier::simplify(for_stmt); |
| std::ostringstream oss; |
| oss << *result; |
| const std::string& expected_ir = |
| R"IR( |
| # CHECK: for (int x = 0; x < 150; x++) { |
| # CHECK: A[x] = B[x]; |
| # CHECK: B[x] = 2 * (x - 50); |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
| } |
| |
| TEST(LoopNest, NormalizeStartZero) { |
| // Input IR: |
| // for (int x = 0; x < 100; x++) { |
| // A[x] = B[x]; |
| // B[x] = x * 2; |
| // } |
| // Should not be modified. |
| |
| const int kTotalSize = 100; |
| BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); |
| BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); |
| VarHandle x("x", kInt); |
| auto for_body = Block::make( |
| {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), |
| Store::make(b_buf, {x}, x * 2)}); |
| auto for_stmt = For::make(x, 0, 100, for_body); |
| Block::make({for_stmt}); |
| |
| LoopNest::normalize(for_stmt); |
| |
| auto result = IRSimplifier::simplify(for_stmt); |
| std::ostringstream oss; |
| oss << *result; |
| const std::string& expected_ir = |
| R"IR( |
| # CHECK: for (int x = 0; x < 100; x++) { |
| # CHECK: A[x] = B[x]; |
| # CHECK: B[x] = 2 * x; |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
| } |
| |
| TEST(LoopNest, NormalizeStartVariable) { |
| // Input IR: |
| // for (int x = y; x < 100; x++) { |
| // A[x] = B[x]; |
| // B[x] = x * 2; |
| // } |
| |
| const int kTotalSize = 100; |
| BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); |
| BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| auto for_body = Block::make( |
| {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), |
| Store::make(b_buf, {x}, x * 2)}); |
| auto for_stmt = For::make(x, y, 100, for_body); |
| auto parent_block = Block::make({for_stmt}); |
| |
| LoopNest::normalize(for_stmt); |
| |
| auto result = IRSimplifier::simplify(for_stmt); |
| std::ostringstream oss; |
| oss << *result; |
| const std::string& expected_ir = |
| R"IR( |
| # CHECK: for (int x = 0; x < 100 - y; x++) { |
| # CHECK: A[x + y] = B[x + y]; |
| # CHECK: B[x + y] = 2 * (x + y); |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
| } |
| |
| TEST(LoopNest, NormalizeOnNestedOuterLoop) { |
| // Input IR: |
| // for (int x = 50; x < 100; x++) { |
| // for (int y = 10; y < 100; y++) { |
| // A[x] = A[x] + B[y] + y * 2; |
| // } |
| // } |
| |
| BufHandle a_buf("A", {ExprHandle(50)}, kInt); |
| BufHandle b_buf("B", {ExprHandle(100)}, kInt); |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| auto inner_for_body = Store::make( |
| a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); |
| auto inner_for = For::make(y, 10, 100, inner_for_body); |
| auto for_stmt = For::make(x, 50, 100, inner_for); |
| Block::make({for_stmt}); |
| |
| LoopNest::normalize(for_stmt); |
| |
| auto result = IRSimplifier::simplify(for_stmt); |
| std::ostringstream oss; |
| oss << *result; |
| const std::string& expected_ir = |
| R"IR( |
| # CHECK: for (int x = 0; x < 50; x++) { |
| # CHECK: for (int y = 10; y < 100; y++) { |
| # CHECK: A[x + 50] = ((A[x + 50]) + (B[y])) + 2 * y; |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
| } |
| |
| TEST(LoopNest, NormalizeOnNestedInnerLoop) { |
| // Input IR: |
| // for (int x = 50; x < 100; x++) { |
| // for (int y = 10; y < 100; y++) { |
| // A[x] = A[x] + B[y] + y * 2; |
| // } |
| // } |
| |
| BufHandle a_buf("A", {ExprHandle(50)}, kInt); |
| BufHandle b_buf("B", {ExprHandle(100)}, kInt); |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| auto inner_for_body = Store::make( |
| a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); |
| auto inner_for = For::make(y, 10, 100, inner_for_body); |
| auto for_stmt = For::make(x, 50, 100, inner_for); |
| Block::make({for_stmt}); |
| |
| LoopNest::normalize(inner_for); |
| |
| auto result = IRSimplifier::simplify(for_stmt); |
| std::ostringstream oss; |
| oss << *result; |
| const std::string& expected_ir = |
| R"IR( |
| # CHECK: for (int x = 50; x < 100; x++) { |
| # CHECK: for (int y = 0; y < 90; y++) { |
| # CHECK: A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20; |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
| } |
| |
| TEST(LoopNest, NormalizeAndSplitWithTail) { |
| // Create a dummy tensor to construct LoopNest. |
| ExprHandle n(100); |
| BufHandle a("a", {n}, kFloat); |
| Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); |
| LoopNest l({b}); |
| |
| // Input IR: |
| // for (int x = 5; x < 10; x++) { |
| // A[x] = x * 2; |
| // } |
| const int kTotalSize = 5; |
| BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); |
| VarHandle x("x", kInt); |
| auto for_stmt = For::make(x, 5, 10, Store::make(a_buf, {x}, x * 2)); |
| auto parent_block = Block::make({for_stmt}); |
| |
| LoopNest::normalize(for_stmt); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr x_inner; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr x_tail; |
| LoopNest::splitWithTail(for_stmt, 10, &x_inner, &x_tail); |
| |
| auto x_outer_result = IRSimplifier::simplify(for_stmt); |
| std::ostringstream oss_outer; |
| oss_outer << *x_outer_result; |
| const std::string& expected_outer_ir = |
| R"IR( |
| # CHECK: { |
| # CHECK: } |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); |
| |
| auto x_tail_result = IRSimplifier::simplify(x_tail); |
| std::ostringstream oss_tail; |
| oss_tail << *x_tail_result; |
| const std::string& expected_tail_ir = |
| R"IR( |
| # CHECK: for (int x_tail = 0; x_tail < 5; x_tail++) { |
| # CHECK: A[x_tail + 5] = 2 * (x_tail + 5); |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); |
| } |
| |
| TEST(LoopNest, NotNormalizeAndSplitWithTail) { |
| // Create a dummy tensor to construct LoopNest. |
| ExprHandle n(100); |
| BufHandle a("a", {n}, kFloat); |
| Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); |
| LoopNest l({b}); |
| |
| // Input IR: |
| // for (int x = 5; x < 15; x++) { |
| // A[x] = x * 2; |
| // } |
| const int kTotalSize = 10; |
| BufHandle a_buf("A", {kTotalSize}, kInt); |
| VarHandle x("x", kInt); |
| auto for_stmt = For::make(x, 5, 15, Store::make(a_buf, {x}, x * 2)); |
| auto parent_block = Block::make({for_stmt}); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr x_inner; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr x_tail; |
| LoopNest::splitWithTail(for_stmt, 8, &x_inner, &x_tail); |
| |
| auto x_outer_result = IRSimplifier::simplify(for_stmt); |
| std::ostringstream oss_outer; |
| oss_outer << *x_outer_result; |
| const std::string& expected_outer_ir = |
| R"IR( |
| # CHECK: { |
| # CHECK: } |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); |
| |
| auto x_tail_result = IRSimplifier::simplify(x_tail); |
| std::ostringstream oss_tail; |
| oss_tail << *x_tail_result; |
| const std::string& expected_tail_ir = |
| R"IR( |
| # CHECK: for (int x_tail = 0; x_tail < 2; x_tail++) { |
| # CHECK: A[x_tail + 13] = 2 * (x_tail + 13); |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); |
| } |
| |
| TEST(LoopNest, FlattenSimpleLoopNest2D) { |
| // Input IR: |
| // for (int i = 0; i < 10; i++) { |
| // for (int j = 0; j < 5; j++) { |
| // A[i,j] = i * j; |
| // } |
| // } |
| BufHandle a_buf("A", {10, 5}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
| auto inner_for = For::make(j, 0, 5, for_body); |
| auto outer_for = For::make(i, 0, 10, inner_for); |
| auto parent_block = Block::make({outer_for}); |
| |
| std::vector<ForPtr> loops = {outer_for, inner_for}; |
| ForPtr flattened = nullptr; |
| ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); |
| ASSERT_EQ(flattened, loops.front()); |
| |
| auto result = IRSimplifier::simplify(flattened); |
| std::ostringstream oss; |
| oss << *result; |
| const std::string& expected_ir = |
| R"IR( |
| # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { |
| # CHECK: A[i_flat / 5, i_flat % 5] = |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
| |
| { |
| SimpleIREvaluator eval1(loops[0], {a_buf}); |
| PaddedBuffer<int> inp1(10, 5); |
| eval1(inp1); |
| SimpleIREvaluator eval2(flattened, {a_buf}); |
| PaddedBuffer<int> inp2(10, 5); |
| eval2(inp2); |
| ExpectAllNear(inp1, inp2, 1e-5); |
| } |
| } |
| |
| TEST(LoopNest, FlattenSimpleLoopNest3D) { |
| // Input IR: |
| // for (int i = 0; i < 10; i++) { |
| // for (int j = 0; j < 5; j++) { |
| // for (int k = 0; k < 7; k++) { |
| // A[i,j,k] = i + j * k; |
| // } |
| // } |
| // } |
| BufHandle a_buf("A", {10, 5, 7}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto for_body = Block::make({Store::make(a_buf, {i, j, k}, i + j * k)}); |
| auto for1 = For::make(k, 0, 7, for_body); |
| auto for2 = For::make(j, 0, 5, for1); |
| auto for3 = For::make(i, 0, 10, for2); |
| auto parent_block = Block::make({for3}); |
| |
| std::vector<ForPtr> loops = {for3, for2, for1}; |
| ForPtr flattened = nullptr; |
| ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); |
| ASSERT_EQ(flattened, loops.front()); |
| |
| auto result = IRSimplifier::simplify(flattened); |
| std::ostringstream oss; |
| oss << *result; |
| const std::string& expected_ir = |
| R"IR( |
| # CHECK: for (int i_flat = 0; i_flat < 350; i_flat++) { |
| # CHECK: A[i_flat / 35, (i_flat / 7) % 5, i_flat % 7] = |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
| |
| { |
| SimpleIREvaluator eval1(loops[0], {a_buf}); |
| PaddedBuffer<int> inp1(10, 5, 7); |
| eval1(inp1); |
| SimpleIREvaluator eval2(flattened, {a_buf}); |
| PaddedBuffer<int> inp2(10, 5, 7); |
| eval2(inp2); |
| ExpectAllNear(inp1, inp2, 1e-5); |
| } |
| } |
| |
| TEST(LoopNest, FlattenLoopNestAfterNormalize) { |
| // Input IR: |
| // for (int i = 2; i < 10; i++) { |
| // for (int j = 3; j < 15; j++) { |
| // A[i - 2,j - 3] = i * j; |
| // } |
| // } |
| BufHandle a_buf("A", {8, 12}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)}); |
| auto inner_for = For::make(j, 3, 15, for_body); |
| auto outer_for = For::make(i, 2, 10, inner_for); |
| auto parent_block = Block::make({outer_for}); |
| |
| std::vector<ForPtr> loops = {outer_for, inner_for}; |
| ForPtr flattened = nullptr; |
| ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); |
| ASSERT_EQ(flattened, loops.front()); |
| |
| auto result = IRSimplifier::simplify(flattened); |
| std::ostringstream oss; |
| oss << *result; |
| const std::string& expected_ir = |
| R"IR( |
| # CHECK: for (int i_flat = 0; i_flat < 96; i_flat++) { |
| # CHECK: A[i_flat / 12, i_flat % 12] = |
| )IR"; |
| torch::jit::testing::FileCheck().run(expected_ir, oss.str()); |
| |
| { |
| SimpleIREvaluator eval1(loops[0], {a_buf}); |
| PaddedBuffer<int> inp1(8, 12); |
| eval1(inp1); |
| SimpleIREvaluator eval2(flattened, {a_buf}); |
| PaddedBuffer<int> inp2(8, 12); |
| eval2(inp2); |
| ExpectAllNear(inp1, inp2, 1e-5); |
| } |
| } |
| |
| TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) { |
| // Input IR: |
| // for (int i = 0; i < 15-5; i++) { |
| // for (int j = 0; j < 20/4; j++) { |
| // A[i,j] = i * j; |
| // } |
| // } |
| BufHandle a_buf("A", {10, 5}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
| auto inner_for = |
| For::make(j, 0, IntImm::make(20) / IntImm::make(4), for_body); |
| auto outer_for = |
| For::make(i, 0, IntImm::make(15) - IntImm::make(5), inner_for); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto b = Block::make({outer_for}); |
| |
| std::vector<ForPtr> loops = {outer_for, inner_for}; |
| ForPtr flattened = nullptr; |
| ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); |
| ASSERT_EQ(flattened, loops.front()); |
| |
| auto result = IRSimplifier::simplify(flattened); |
| checkIR(result, R"IR( |
| # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { |
| # CHECK: A[i_flat / 5, i_flat % 5] = |
| )IR"); |
| |
| { |
| SimpleIREvaluator eval1(loops[0], {a_buf}); |
| PaddedBuffer<int> inp1(10, 5); |
| eval1(inp1); |
| SimpleIREvaluator eval2(flattened, {a_buf}); |
| PaddedBuffer<int> inp2(10, 5); |
| eval2(inp2); |
| ExpectAllNear(inp1, inp2, 1e-5); |
| } |
| } |
| |
| TEST(LoopNest, FlattenImperfectLoopNest) { |
| // Input IR: |
| // for (int i = 0; i < 10; i++) { |
| // A[i, i] = 0; |
| // for (int j = 0; j < 15; j++) { |
| // A[i,j] = i * j; |
| // } |
| // } |
| // Do not flatten. |
| |
| BufHandle a_buf("A", {10, 15}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
| auto inner_for = For::make(j, 0, 15, for_body); |
| auto outer_for = For::make( |
| i, 0, 10, Block::make({Store::make(a_buf, {i, i}, 0), inner_for})); |
| auto par = Block::make({outer_for}); |
| HashProvider hasher; |
| auto hash_before = hasher.hash(par); |
| |
| std::vector<ForPtr> loops = {outer_for, inner_for}; |
| ForPtr flattened = nullptr; |
| ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); |
| ASSERT_EQ(flattened, nullptr); |
| auto hash_after = hasher.hash(par); |
| ASSERT_EQ(hash_before, hash_after); |
| } |
| |
| TEST(LoopNest, FlattenReductionLoopNest) { |
| // Input IR: |
| // for (int i = 0; i < 10; i++) { |
| // S[i] = 0; |
| // for (int j = 0; j < 15; j++) { |
| // S[i] = S[i] + A[i,j]; |
| // } |
| // } |
| // Do not flatten. |
| |
| BufHandle a_buf("A", {10, 15}, kInt); |
| BufHandle s_buf("S", {10}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto for_body = Block::make({Store::make( |
| s_buf, {i}, Load::make(s_buf, {i}) + Load::make(a_buf, {i, j}))}); |
| auto inner_for = For::make(j, 0, 15, for_body); |
| auto outer_for = |
| For::make(i, 0, 10, Block::make({Store::make(s_buf, {i}, 0), inner_for})); |
| auto par = Block::make({outer_for}); |
| HashProvider hasher; |
| auto hash_before = hasher.hash(par); |
| |
| std::vector<ForPtr> loops = {outer_for, inner_for}; |
| ForPtr flattened = nullptr; |
| ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); |
| ASSERT_EQ(flattened, nullptr); |
| auto hash_after = hasher.hash(par); |
| ASSERT_EQ(hash_before, hash_after); |
| } |
| |
| TEST(LoopNest, FlattenReductionLoopNestFromTensor) { |
| const int M = 3; |
| const int N = 7; |
| VarHandle m("m", kInt); |
| VarHandle n("n", kInt); |
| BufHandle b("b", {m, n}, kFloat); |
| Tensor c = Reduce("sum", {M}, Sum(), b, {N}); |
| LoopNest loop({c}); |
| HashProvider hasher; |
| auto hash_before = hasher.hash(loop.root_stmt()); |
| |
| auto loops = loop.getAllLoopNestsWritingToBuf(c.buf())[1]; |
| ForPtr flattened = nullptr; |
| ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); |
| ASSERT_EQ(flattened, nullptr); |
| auto hash_after = hasher.hash(loop.root_stmt()); |
| ASSERT_EQ(hash_before, hash_after); |
| } |
| |
| TEST(LoopNest, FlattenIncorrectLoopsAsInput) { |
| // Input IR: |
| // for (int i = 0; i < 10; i++) { |
| // for (int j = 0; j < 5; j++) { |
| // A[i,j] = i * j; |
| // } |
| // } |
| // for (int x = 0; x < 10; x++) { |
| // for (int y = 0; y < 5; y++) { |
| // A[x,y] = A[x,y] + x + y; |
| // } |
| // } |
| // Flatten({For_i, For_y}) => should not succeed |
| |
| BufHandle a_buf("A", {10, 5}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
| auto inner_for1 = For::make(j, 0, 5, for_body1); |
| auto outer_for1 = For::make(i, 0, 10, inner_for1); |
| auto for_body2 = Block::make( |
| {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); |
| auto inner_for2 = For::make(y, 0, 5, for_body2); |
| auto outer_for2 = For::make(x, 0, 10, inner_for2); |
| auto par = Block::make({outer_for1, outer_for2}); |
| HashProvider hasher; |
| auto hash_before = hasher.hash(par); |
| |
| std::vector<ForPtr> loops = {outer_for1, inner_for2}; |
| ForPtr flattened = nullptr; |
| ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); |
| ASSERT_EQ(flattened, nullptr); |
| auto hash_after = hasher.hash(par); |
| ASSERT_EQ(hash_before, hash_after); |
| } |
| |
| TEST(LoopNest, DetectInlineRankMismatch) { |
| const int kTotalSize = 8; |
| |
| BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); |
| Tensor a = Compute( |
| "a", {kTotalSize}, [&](const VarHandle& i) { return a_buf.load(i); }); |
| Tensor reshape = Compute( |
| "reshape", |
| {kTotalSize / 2, 2}, |
| [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j); }); |
| LoopNest l({reshape}, {a, reshape}); |
| ASSERT_FALSE(l.computeInline(l.getLoopBodyFor(a))); |
| } |
| |
TEST(LoopNest, CacheReadsSimple) {
  // Caches the loads of A inside B's second loop level into a local buffer
  // A_local. Checks the complete lowered IR, then evaluates B and C and
  // compares them with values computed on the host.
  // A[i, j] = i * j over 64x64; B reads one element of A, C reads two.
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });
  Tensor B =
      Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 30, j + 3);
      });
  Tensor C =
      Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
      });

  // A is not in the first (output) list; the checked IR below shows it gets
  // its own Allocate/Free, i.e. it is an intermediate.
  LoopNest l({B, C}, {A, B, C});
  // [0][1]: loop at depth 1 (the j loop) of the first nest writing B.
  StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1];
  LoopNest::cacheAccesses(A.buf(), "A_local", j_loop);

  l.prepareForCodegen();
  // sanitizeNames produces the _1/_2/_3 suffixed loop variables matched below.
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {B, C});
  result = cg.stmt();

  // just this once: verify the whole thing.
  // A_local has dims [1, 10]: one row of B's 10 columns per i iteration.
  checkIR(result, R"IR(
#CHECK: Allocate(A); // dtype=int, dims=[64, 64]
#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10]
#CHECK: for (int i
#CHECK: for (int j
#CHECK: A[
#CHECK: }
#CHECK: }
#CHECK: for (int i_1
#CHECK: for (int j_1
#CHECK: A_local[j_1] = A[
#CHECK: }
#CHECK: for (int j_2
#CHECK: B[j_2 + 10 * i_1] = A_local[j_2];
#CHECK: }
#CHECK: }
#CHECK: for (int i_2
#CHECK: for (int j_3
#CHECK: C[
#CHECK: }
#CHECK: }
#CHECK: Free(A_local);
#CHECK: Free(A);
)IR");

  std::vector<int> b_data(200, 0);
  std::vector<int> c_data(200, 0);
  cg.call({b_data, c_data});

  // Host-side reference values for B and C (row-major 20x10).
  std::vector<int> b_ref(200, 0);
  std::vector<int> c_ref(200, 0);

  for (int i = 0; i < 20; ++i) {
    for (int j = 0; j < 10; ++j) {
      b_ref[i * 10 + j] = (i + 30) * (j + 3);
      c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
    }
  }

  assertAllEqual(b_data, b_ref);
  assertAllEqual(c_data, c_ref);
}
| |
TEST(LoopNest, CacheReadsOuter) {
  // Caches the loads of A at the level of B's outermost (i) loop, so the
  // cache must cover B's whole iteration space. Checks the cache shape and
  // indexing, then compares B and C with host-computed values.
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });
  Tensor B =
      Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 30, j + 40) + A.load(i + 31, j + 41);
      });
  Tensor C =
      Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
      });

  LoopNest l({B, C}, {A, B, C});
  // [0][0]: outermost loop (i) of the first nest writing B.
  StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][0];
  LoopNest::cacheAccesses(A.buf(), "A_local", i_loop);

  l.prepareForCodegen();
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {B, C});
  result = cg.stmt();

  // Over i in [0, 20) and j in [0, 10), B reads rows i+30..i+31 and columns
  // j+40..j+41 of A: 21 rows x 11 columns, hence dims=[21, 11].
  checkIR(result, R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11]
#CHECK: A_local[j_1 + 11 * i_1] =
#CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]);
)IR");

  std::vector<int> b_data(200, 0);
  std::vector<int> c_data(200, 0);
  cg.call({b_data, c_data});

  // Host-side reference values for B and C (row-major 20x10).
  std::vector<int> b_ref(200, 0);
  std::vector<int> c_ref(200, 0);

  for (int i = 0; i < 20; ++i) {
    for (int j = 0; j < 10; ++j) {
      b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41);
      c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
    }
  }

  assertAllEqual(b_data, b_ref);
  assertAllEqual(c_data, c_ref);
}
| |
TEST(LoopNest, CacheReadsInternal) {
  // Caches the loads of A at the level of B's inner (j) loop, so the cache
  // only needs to cover one i iteration. Checks the cache shape and
  // indexing, then compares B and C with host-computed values.
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });
  Tensor B =
      Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 30, j + 40) + A.load(i + 31, j + 41);
      });
  Tensor C =
      Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
      });

  LoopNest l({B, C}, {A, B, C});
  // [0][1]: loop at depth 1 (the j loop) of the first nest writing B.
  StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1];
  LoopNest::cacheAccesses(A.buf(), "A_local", j_loop);
  l.prepareForCodegen();
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {B, C});
  result = cg.stmt();

  // For a fixed i, B reads rows i+30..i+31 (2) and columns j+40..j+41 over
  // j in [0, 10) (11), hence dims=[2, 11].
  checkIR(result, R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11]
#CHECK: A_local[k + 11 * j_1] =
#CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]);
)IR");

  std::vector<int> b_data(200, 0);
  std::vector<int> c_data(200, 0);
  cg.call({b_data, c_data});

  // Host-side reference values for B and C (row-major 20x10).
  std::vector<int> b_ref(200, 0);
  std::vector<int> c_ref(200, 0);

  for (int i = 0; i < 20; ++i) {
    for (int j = 0; j < 10; ++j) {
      b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41);
      c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
    }
  }

  assertAllEqual(b_data, b_ref);
  assertAllEqual(c_data, c_ref);
}
| |
TEST(LoopNest, CacheReadsInner) {
  // Caches the loads of A inside B's innermost body, so the cache only covers
  // the elements a single (i, j) point reads. Checks the cache shape and the
  // constant indices, then compares B and C with host-computed values.
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });
  // Note: I'm changing the offset of the first argument of the first load
  // from A (i + 34 instead of i + 30).
  Tensor B =
      Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 34, j + 40) + A.load(i + 30, j + 41);
      });
  Tensor C =
      Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
      });

  LoopNest l({B, C}, {A, B, C});
  StmtPtr body = l.getLoopBodyFor(B);
  LoopNest::cacheAccesses(A.buf(), "A_local", body);
  l.prepareForCodegen();
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {B, C});
  result = cg.stmt();

  // The two loads span rows i+30..i+34 (5) and columns j+40..j+41 (2), hence
  // dims=[5, 2]. (i+34, j+40) lands at 4 * 2 + 0 = 8, and (i+30, j+41) at
  // 0 * 2 + 1 = 1.
  checkIR(result, R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2]
#CHECK: A_local[l + 2 * k] =
#CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]);
)IR");

  std::vector<int> b_data(200, 0);
  std::vector<int> c_data(200, 0);
  cg.call({b_data, c_data});

  // Host-side reference values for B and C (row-major 20x10).
  std::vector<int> b_ref(200, 0);
  std::vector<int> c_ref(200, 0);

  for (int i = 0; i < 20; ++i) {
    for (int j = 0; j < 10; ++j) {
      b_ref[i * 10 + j] = (i + 34) * (j + 40) + (i + 30) * (j + 41);
      c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
    }
  }

  assertAllEqual(b_data, b_ref);
  assertAllEqual(c_data, c_ref);
}
| |
TEST(LoopNest, CacheWritesSimple) {
  // Caches the stores to A inside A's own inner loop. The checked IR shows
  // values first written to A_local and then copied back into A. B and C
  // (which read A) are compared with host-computed values.
  Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) {
    return i * j;
  });
  Tensor B =
      Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 30, j + 40) + A.load(i + 31, j + 41);
      });
  Tensor C =
      Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) {
        return A.load(i + 10, j + 20) + A.load(i + 30, j + 40);
      });

  LoopNest l({B, C}, {A, B, C});
  // [0][1]: loop at depth 1 (the j loop) of the first nest writing A.
  StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A.buf())[0][1];
  LoopNest::cacheAccesses(A.buf(), "A_local", a_loop);

  l.prepareForCodegen();
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {B, C});
  result = cg.stmt();

  // One row of A (64 columns) is cached per i iteration: dims=[1, 64]. No use
  // of A_local may appear after its Free.
  checkIR(result, R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[1, 64]
#CHECK: for (int j = 0; j < 64
#CHECK: A_local[j] = i * j;
#CHECK: for (int j_1 = 0; j_1 < 64
#CHECK: A[j_1 + 64 * i] = A_local[
#CHECK: Free(A_local);
#CHECK-NOT: A_local
)IR");

  std::vector<int> b_data(200, 0);
  std::vector<int> c_data(200, 0);
  cg.call({b_data, c_data});

  // Host-side reference values for B and C (row-major 20x10).
  std::vector<int> b_ref(200, 0);
  std::vector<int> c_ref(200, 0);

  for (int i = 0; i < 20; ++i) {
    for (int j = 0; j < 10; ++j) {
      b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41);
      c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40);
    }
  }

  assertAllEqual(b_data, b_ref);
  assertAllEqual(c_data, c_ref);
}
| |
TEST(LoopNest, DeadStoreElimination) {
  // One loop nest writes both f and g. eliminateDeadStores() should drop the
  // stores to g when only f is an output, and keep both when both are.
  VarHandle y("y", kInt);
  VarHandle x("x_tail", kInt);
  BufHandle f("f", {26, 5}, kInt);
  BufHandle g("g", {26, 5}, kInt);
  // x_2 is built unfolded as x + 5 * 4; the IR checks below match that exact
  // printed form.
  ExprHandle x_outer_end = 5;
  ExprHandle x_2 = x + x_outer_end * 4;
  ForPtr stmt1 = For::make(
      x,
      0,
      5,
      For::make(
          y,
          0,
          5,
          Block::make({
              Store::make(f, {x_2, y}, (x_2 + y)),
              Store::make(g, {x_2, y}, (x_2 * y)),
          })));
  StmtPtr stmt = Block::make({stmt1});

  // Will eliminate if not used by an output.
  // Works on a clone so `stmt` is still intact for the second case.
  LoopNest loop(Stmt::clone(stmt), {f.node()});
  loop.eliminateDeadStores();

  checkIR(loop.root_stmt(), R"IR(
#CHECK: f[x_tail + 5 * 4, y]
#CHECK-NOT: g[x_tail + 5 * 4, y]
)IR");

  // But won't eliminate if used by different outputs.
  LoopNest loop2(stmt, {f.node(), g.node()});
  loop2.eliminateDeadStores();

  checkIR(loop2.root_stmt(), R"IR(
#CHECK: f[x_tail + 5 * 4, y]
#CHECK: g[x_tail + 5 * 4, y]
)IR");
}
| |
| TEST(LoopNest, DeadStoreEliminationWithIntermediates) { |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| VarHandle z("z", kInt); |
| BufHandle f("f", {26 * 5}, kInt); |
| BufHandle g("g", {26 * 5}, kInt); |
| BufHandle h("h", {26, 5}, kInt); |
| ExprHandle x_outer_end = 5; |
| ExprHandle x_2 = x + x_outer_end * 4; |
| ForPtr stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x)); |
| ForPtr stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1)); |
| ForPtr stmt3 = For::make( |
| x, |
| 0, |
| 5, |
| For::make( |
| y, |
| 0, |
| 5, |
| Block::make({ |
| Store::make(h, {x, y}, Load::make(f, {x * y})), |
| }))); |
| StmtPtr stmt = Block::make({stmt1, stmt2, stmt3}); |
| |
| // Will eliminate the write to g, but not f since it used by the producer of |
| // h. |
| LoopNest loop(Stmt::clone(stmt), {h.node()}); |
| loop.eliminateDeadStores(); |
| |
| checkIR(loop.root_stmt(), R"IR( |
| #CHECK: f[x] = x; |
| #CHECK-NOT: g[z] = |
| #CHECK: h[x, y] = f[x * y]; |
| )IR"); |
| |
| // Sanity check won't eliminate if g is an output. |
| LoopNest loop2(stmt, {h.node(), g.node()}); |
| loop2.eliminateDeadStores(); |
| |
| checkIR(loop2.root_stmt(), R"IR( |
| #CHECK: f[x] = x; |
| #CHECK: g[z] = z + 1; |
| #CHECK: h[x, y] = f[x * y]; |
| )IR"); |
| } |
| |
| TEST(LoopNest, CompoundTensorSimple) { |
| BufHandle a_buf("A", {10, 5}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
| auto inner_for1 = For::make(j, 0, 5, for_body1); |
| auto outer_for1 = For::make(i, 0, 10, inner_for1); |
| auto for_body2 = Block::make( |
| {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); |
| auto inner_for2 = For::make(y, 0, 5, for_body2); |
| auto outer_for2 = For::make(x, 0, 10, inner_for2); |
| BlockPtr body = Block::make({outer_for1, outer_for2}); |
| |
| Tensor A = Tensor(a_buf.node(), body); |
| |
| LoopNest l({A}); |
| l.prepareForCodegen(); |
| |
| std::vector<int> a_data(50, 0); |
| |
| StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
| SimpleIREvaluator cg(s, {A}); |
| |
| std::vector<int> a_ref(50, 0); |
| |
| for (int i = 0; i < 10; ++i) { |
| for (int j = 0; j < 5; ++j) { |
| a_ref[i * 5 + j] = (i * j) + i + j; |
| } |
| } |
| cg.call({a_data}); |
| |
| assertAllEqual(a_data, a_ref); |
| } |
| |
| TEST(LoopNest, InlineConstantIndex) { |
| const int N = 10; |
| BufHandle x_buf("a", {1, N, 1}, kFloat); |
| Tensor y = Compute( |
| "f", |
| {1, N, 1}, |
| [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { |
| return x_buf.load(m, n, o); |
| }); |
| Tensor z = Compute( |
| "f", |
| {1, N, 1}, |
| [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { |
| return y.load(m, n, o); |
| }); |
| |
| LoopNest l({z}, {y, z}); |
| l.simplify(); |
| ASSERT_TRUE(l.computeInline(y.buf())); |
| } |
| |
| TEST(LoopNest, CompoundTensorUsed) { |
| BufHandle a_buf("A", {10, 5}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); |
| auto inner_for1 = For::make(j, 0, 5, for_body1); |
| auto outer_for1 = For::make(i, 0, 10, inner_for1); |
| auto for_body2 = Block::make( |
| {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); |
| auto inner_for2 = For::make(y, 0, 5, for_body2); |
| auto outer_for2 = For::make(x, 0, 10, inner_for2); |
| BlockPtr body = Block::make({outer_for1, outer_for2}); |
| |
| Tensor A = Tensor(a_buf.node(), body); |
| Tensor B = Compute("B", {10, 3}, [&](const VarHandle& i, const VarHandle& j) { |
| return A.load(i, j + 1) + A.load(i, j + 2); |
| }); |
| |
| LoopNest l({B}, {A, B}); |
| ASSERT_FALSE(l.computeInline(A.buf())); |
| l.prepareForCodegen(); |
| |
| std::vector<int> a_data(50, 0); |
| std::vector<int> b_data(50, 0); |
| |
| StmtPtr s = IRSimplifier::simplify(l.root_stmt()); |
| SimpleIREvaluator cg(s, {B}); |
| |
| std::vector<int> b_ref(50, 0); |
| |
| auto AT = [](int i, int j) { return i * j + i + j; }; |
| for (int i = 0; i < 10; ++i) { |
| for (int j = 0; j < 3; ++j) { |
| b_ref[i * 3 + j] = AT(i, j + 1) + AT(i, j + 2); |
| } |
| } |
| cg.call({b_data}); |
| |
| assertAllEqual(b_data, b_ref); |
| } |
| |
| TEST(LoopNest, InlineFromLoad) { |
| constexpr int N = 1024; |
| BufHandle a("A", {N}, kInt); |
| BufHandle b("B", {N}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto store_a = For::make(i, 0, N, Store::make(a, {i}, i)); |
| auto store_b = For::make(j, 0, N, Store::make(b, {j}, Load::make(a, {j}))); |
| LoopNest l(Block::make({store_a, store_b}), {b.node()}); |
| |
| l.computeInline(a.node()); |
| |
| // Check that A[j] is replaced with j after inlining |
| std::ostringstream oss; |
| oss << *l.root_stmt(); |
| torch::jit::testing::FileCheck().run( |
| R"IR( |
| # CHECK: for (int j |
| # CHECK-NOT: B[j] = A[j] |
| # CHECK-NEXT: B[j] = j |
| )IR", |
| oss.str()); |
| } |
| |
| TEST(LoopNest, OptimizeConditionalsSimple) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) |
| // } |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle a_buf("A", {20}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle b_buf("B", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle c_buf("C", {15}, kInt); |
| VarHandle i("i", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto store = Store::make( |
| a_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(i, 5, kLT), |
| Load::make(b_buf, {i}), |
| Load::make(c_buf, {i - 5}))); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto forI = For::make(i, 0, 20, store); |
| auto par = Block::make({forI}); |
| |
| LoopNest nest(par, {a_buf.node()}); |
| nest.optimizeConditionals(); |
| |
| std::ostringstream oss; |
| oss << *nest.root_stmt(); |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i = 0; i < 5 |
| # CHECK-NEXT: A[i] = B[i] |
| # CHECK: for (int i = 0; i < 15 |
| # CHECK-NEXT: A[i + 5] = C[i] |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| } |
| |
| TEST(LoopNest, OptimizeConditionalsNestedConditions) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) |
| // } |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle a_buf("A", {20}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle b_buf("B", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle c_buf("C", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle d_buf("D", {10}, kInt); |
| VarHandle i("i", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto store = Store::make( |
| a_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(i, 10, kLT), |
| IfThenElse::make( |
| CompareSelect::make(i, 5, kLT), |
| Load::make(b_buf, {i}), |
| Load::make(c_buf, {i - 5})), |
| Load::make(d_buf, {i - 10}))); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto forI = For::make(i, 0, 20, store); |
| auto par = Block::make({forI}); |
| |
| LoopNest nest(par, {a_buf.node()}); |
| nest.optimizeConditionals(); |
| |
| std::ostringstream oss; |
| oss << *nest.root_stmt(); |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i = 0; i < 5 |
| # CHECK-NEXT: A[i] = B[i] |
| # CHECK: for (int i = 0; i < 5 |
| # CHECK-NEXT: A[i + 5] = C[i] |
| # CHECK: for (int i = 0; i < 10 |
| # CHECK-NEXT: A[i + 10] = D[i] |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| } |
| |
| TEST(LoopNest, OptimizeConditionalsMultipleStores) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) |
| // } |
| // for (int j = 0; j < 100; j++) { |
| // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) |
| // } |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle a_buf("A", {20}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle b_buf("B", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle c_buf("C", {100}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle d_buf("D", {100}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto storeA = Store::make( |
| a_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(i, 5, kLT), |
| Load::make(b_buf, {i}), |
| Load::make(c_buf, {i - 5}))); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto forI = For::make(i, 0, 20, storeA); |
| auto storeB = Store::make( |
| b_buf, |
| {j}, |
| IfThenElse::make( |
| CompareSelect::make(j, 30, kLT), |
| Load::make(c_buf, {j}), |
| Load::make(d_buf, {j}))); |
| auto forJ = For::make(j, 0, 100, storeB); |
| auto par = Block::make({forI, forJ}); |
| |
| LoopNest nest(par, {a_buf.node()}); |
| nest.optimizeConditionals(); |
| |
| std::ostringstream oss; |
| oss << *nest.root_stmt(); |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i = 0; i < 5 |
| # CHECK-NEXT: A[i] = B[i] |
| # CHECK: for (int i = 0; i < 15 |
| # CHECK-NEXT: A[i + 5] = C[i] |
| # CHECK: for (int j = 0; j < 30 |
| # CHECK-NEXT: B[j] = C[j] |
| # CHECK: for (int j = 0; j < 70 |
| # CHECK-NEXT: B[j + 30] = D[j + 30] |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| } |
| |
| TEST(LoopNest, OptimizeConditionalsMultipleStoresInOneLoop) { |
| // Input IR: |
| // for (int i = 0; i < 50; i++) { |
| // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) |
| // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) |
| // } |
| // Only the first conditional, in the write to A, will be optimized. |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle a_buf("A", {100}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle b_buf("B", {100}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle c_buf("C", {100}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle d_buf("D", {100}, kInt); |
| VarHandle i("i", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto storeA = Store::make( |
| a_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(i, 5, kLT), |
| Load::make(b_buf, {i}), |
| Load::make(c_buf, {i - 5}))); |
| auto storeB = Store::make( |
| b_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(i, 30, kLT), |
| Load::make(c_buf, {i}), |
| Load::make(d_buf, {i}))); |
| auto forI = For::make(i, 0, 50, Block::make({storeA, storeB})); |
| auto par = Block::make({forI}); |
| |
| LoopNest nest(par, {a_buf.node()}); |
| nest.optimizeConditionals(); |
| |
| std::ostringstream oss; |
| oss << *nest.root_stmt(); |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i = 0; i < 5 |
| # CHECK-NEXT: A[i] = B[i] |
| # CHECK-NEXT: B[i] = C[i] |
| # CHECK: for (int i = 0; i < 45 |
| # CHECK-NEXT: A[i + 5] = C[i] |
| # CHECK-NEXT: B[i + 5] = IfThenElse(i + 5<30 ? 1 : 0, C[i + 5], D[i + 5]) |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| } |
| |
| TEST(LoopNest, OptimizeConditionalsOuterLoopVar) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 100; j++) { |
| // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) |
| // } |
| // } |
| // Currently, this case where the condition variable `i` is not the |
| // inner-most loop variable, is not optimized. |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle a_buf("A", {20}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle b_buf("B", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle c_buf("C", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle d_buf("D", {10}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto store = Store::make( |
| a_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(i, 10, kLT), |
| IfThenElse::make( |
| CompareSelect::make(i, 5, kLT), |
| Load::make(b_buf, {i}), |
| Load::make(c_buf, {i - 5})), |
| Load::make(d_buf, {i - 10}))); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto forI = For::make(i, 0, 20, For::make(j, 0, 100, store)); |
| auto par = Block::make({forI}); |
| LoopNest nest(par, {a_buf.node()}); |
| |
| HashProvider hasher; |
| auto hash_before = hasher.hash(nest.root_stmt()); |
| nest.optimizeConditionals(); |
| auto hash_after = hasher.hash(nest.root_stmt()); |
| ASSERT_EQ(hash_before, hash_after); |
| } |
| |
| TEST(LoopNest, OptimizeConditionalsCompValuesNotOrdered) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = IfThenElse(i<5, IfThenElse(i<10, B[i], C[i-5]), D[i-10]) |
| // } |
| // No optimization should be done here because one of the conditions use '>'. |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle a_buf("A", {20}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle b_buf("B", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle c_buf("C", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle d_buf("D", {10}, kInt); |
| VarHandle i("i", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto store = Store::make( |
| a_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(i, 5, kLT), |
| IfThenElse::make( |
| CompareSelect::make(i, 10, kLT), |
| Load::make(b_buf, {i}), |
| Load::make(c_buf, {i - 5})), |
| Load::make(d_buf, {i - 10}))); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto forI = For::make(i, 0, 20, store); |
| auto par = Block::make({forI}); |
| LoopNest nest(par, {a_buf.node()}); |
| |
| HashProvider hasher; |
| auto hash_before = hasher.hash(nest.root_stmt()); |
| nest.optimizeConditionals(); |
| auto hash_after = hasher.hash(nest.root_stmt()); |
| ASSERT_EQ(hash_before, hash_after); |
| } |
| |
| TEST(LoopNest, OptimizeConditionalsCompValuesNotConstants) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = IfThenElse(i<N, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) |
| // } |
| // No optimization should be done here because one of the conditions use '>'. |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle a_buf("A", {20}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle b_buf("B", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle c_buf("C", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle d_buf("D", {10}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle N("N", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto store = Store::make( |
| a_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(i, N, kLT), |
| IfThenElse::make( |
| CompareSelect::make(i, 5, kLT), |
| Load::make(b_buf, {i}), |
| Load::make(c_buf, {i - 5})), |
| Load::make(d_buf, {i - 10}))); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto forI = For::make(i, 0, 20, store); |
| auto par = Block::make({forI}); |
| LoopNest nest(par, {a_buf.node()}); |
| |
| HashProvider hasher; |
| auto hash_before = hasher.hash(nest.root_stmt()); |
| nest.optimizeConditionals(); |
| auto hash_after = hasher.hash(nest.root_stmt()); |
| ASSERT_EQ(hash_before, hash_after); |
| } |
| |
| TEST(LoopNest, OptimizeConditionalsInvalidCondition) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = IfThenElse(i<10, IfThenElse(i>5, B[i], C[i-5]), D[i-10]) |
| // } |
| // No optimization should be done here because one of the conditions use '>'. |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle a_buf("A", {20}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle b_buf("B", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle c_buf("C", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle d_buf("D", {10}, kInt); |
| VarHandle i("i", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto store = Store::make( |
| a_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(i, 10, kLT), |
| IfThenElse::make( |
| CompareSelect::make(i, 5, kGT), |
| Load::make(b_buf, {i}), |
| Load::make(c_buf, {i - 5})), |
| Load::make(d_buf, {i - 10}))); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto forI = For::make(i, 0, 20, store); |
| auto par = Block::make({forI}); |
| LoopNest nest(par, {a_buf.node()}); |
| |
| HashProvider hasher; |
| auto hash_before = hasher.hash(nest.root_stmt()); |
| nest.optimizeConditionals(); |
| auto hash_after = hasher.hash(nest.root_stmt()); |
| ASSERT_EQ(hash_before, hash_after); |
| } |
| |
| TEST(LoopNest, OptimizeConditionalsInvalidCondition2) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = IfThenElse(10<i, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) |
| // } |
| // No optimization should be done here because of the invalid condition: |
| // "10 < i". |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle a_buf("A", {20}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle b_buf("B", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle c_buf("C", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle d_buf("D", {10}, kInt); |
| VarHandle i("i", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto store = Store::make( |
| a_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(10, i, kLT), |
| IfThenElse::make( |
| CompareSelect::make(i, 5, kLT), |
| Load::make(b_buf, {i}), |
| Load::make(c_buf, {i - 5})), |
| Load::make(d_buf, {i - 10}))); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto forI = For::make(i, 0, 20, store); |
| auto par = Block::make({forI}); |
| LoopNest nest(par, {a_buf.node()}); |
| |
| HashProvider hasher; |
| auto hash_before = hasher.hash(nest.root_stmt()); |
| nest.optimizeConditionals(); |
| auto hash_after = hasher.hash(nest.root_stmt()); |
| ASSERT_EQ(hash_before, hash_after); |
| } |
| |
| TEST(LoopNest, OptimizeConditionalsInvalidCondition3) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = IfThenElse(i<10, IfThenElse(k<5, B[i], C[i-5]), D[i-10]) |
| // } |
| // No optimization should be done here because the conditions use different |
| // variables: "i < 10" and "k < 5" |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle a_buf("A", {20}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle b_buf("B", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle c_buf("C", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle d_buf("D", {10}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle k("k", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto store = Store::make( |
| a_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(i, 10, kLT), |
| IfThenElse::make( |
| CompareSelect::make(k, 5, kLT), |
| Load::make(b_buf, {i}), |
| Load::make(c_buf, {i - 5})), |
| Load::make(d_buf, {i - 10}))); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto forI = For::make(i, 0, 20, store); |
| auto par = Block::make({forI}); |
| LoopNest nest(par, {a_buf.node()}); |
| |
| HashProvider hasher; |
| auto hash_before = hasher.hash(nest.root_stmt()); |
| nest.optimizeConditionals(); |
| auto hash_after = hasher.hash(nest.root_stmt()); |
| ASSERT_EQ(hash_before, hash_after); |
| } |
| |
| TEST(LoopNest, OptimizeConditionalsInvalidCondition4) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = IfThenElse(k<10, IfThenElse(k<5, B[i], C[i-5]), D[i-10]) |
| // } |
| // No optimization should be done here because the conditions use the |
| // variable 'k' which is not a loop variable. |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle a_buf("A", {20}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle b_buf("B", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle c_buf("C", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle d_buf("D", {10}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle k("k", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto store = Store::make( |
| a_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(k, 10, kLT), |
| IfThenElse::make( |
| CompareSelect::make(k, 5, kLT), |
| Load::make(b_buf, {i}), |
| Load::make(c_buf, {i - 5})), |
| Load::make(d_buf, {i - 10}))); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto forI = For::make(i, 0, 20, store); |
| auto par = Block::make({forI}); |
| LoopNest nest(par, {a_buf.node()}); |
| |
| HashProvider hasher; |
| auto hash_before = hasher.hash(nest.root_stmt()); |
| nest.optimizeConditionals(); |
| auto hash_after = hasher.hash(nest.root_stmt()); |
| ASSERT_EQ(hash_before, hash_after); |
| } |
| |
| TEST(LoopNest, OptimizeConditionalsNotNormalized) { |
| // Input IR: |
| // for (int i = 2; i < 20; i++) { |
| // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) |
| // } |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle a_buf("A", {20}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle b_buf("B", {5}, kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| BufHandle c_buf("C", {15}, kInt); |
| VarHandle i("i", kInt); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto store = Store::make( |
| a_buf, |
| {i}, |
| IfThenElse::make( |
| CompareSelect::make(i, 5, kLT), |
| Load::make(b_buf, {i}), |
| Load::make(c_buf, {i - 5}))); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| auto forI = For::make(i, 2, 20, store); |
| auto par = Block::make({forI}); |
| LoopNest nest(par, {a_buf.node()}); |
| |
| HashProvider hasher; |
| auto hash_before = hasher.hash(nest.root_stmt()); |
| nest.optimizeConditionals(); |
| auto hash_after = hasher.hash(nest.root_stmt()); |
| ASSERT_EQ(hash_before, hash_after); |
| } |
| |
| static std::pair<BufHandle, Tensor> colReduce(int M, int N) { |
| BufHandle a("a", {M, N}, kFloat); |
| Tensor t = Reduce( |
| "b", |
| {N}, |
| Sum(), |
| [&](const VarHandle& n, const VarHandle& m) { return a.load(m, n); }, |
| {M}); |
| return {a, Tensor(t.buf(), LoopNest::sanitizeNames(t.stmt()))}; |
| } |
| |
| static StmtPtr splitTailReorder(Tensor b) { |
| constexpr int kVectorWidth = 8; |
| LoopNest nest({b}); |
| auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; |
| nest.splitWithTail(loops[0], kVectorWidth); |
| // Now the loopnests will look like: |
| // |
| // for (int i_outer = 0; ... |
| // for (int i_inner = 0; ... |
| // b[i_outer * 8 + i_inner] = float(0); |
| // for (int j = 0; ... |
| // b[i_outer * 8 + i_inner] = ReduceOp(...); |
| // |
| // for (int i_tail = 0; ... |
| // b[i_tail + ((100 - 0) / 8) * 8] = float(0); |
| // for (int j = 0; ... |
| // b[i_tail + ((100 - 0) / 8) * 8] = ReduceOp(...); |
| // |
| // Since there are 4 writes to b, we will get 4 loopnests from the |
| // call to `getAllLoopNestsWritingToBuf` below. |
| // |
| // Write #2: "b[i_outer * 8 + i_inner] = ReduceOp(...)" |
| // Loopnest #2: {i_outer, i_inner, j}; |
| // We will have to reorder i_inner and j. |
| auto loopnests = nest.getAllLoopNestsWritingToBuf(b.buf()); |
| LoopNest::reorderAxis(loopnests[1][1], loopnests[1][2]); |
| nest.prepareForCodegen(); |
| return nest.root_stmt(); |
| } |
| |
| static StmtPtr splitMaskReorder(Tensor b) { |
| constexpr int kVectorWidth = 8; |
| LoopNest nest({b}); |
| auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; |
| nest.splitWithMask(loops[0], kVectorWidth); |
| loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; |
| LoopNest::reorderAxis(loops[1], loops[2]); |
| nest.prepareForCodegen(); |
| return nest.root_stmt(); |
| } |
| |
| static void checkColReduce(StmtPtr s, BufHandle p, Tensor t) { |
| int M = immediateAs<int>(p.dim(0)); |
| int N = immediateAs<int>(p.dim(1)); |
| PaddedBuffer<float> a(M, N); |
| PaddedBuffer<float> b(N); |
| PaddedBuffer<float> ref(N); |
| for (int i = 0; i < M; i++) { |
| for (int j = 0; j < N; j++) { |
| a(i, j) = 1.0f; |
| } |
| } |
| for (int i = 0; i < N; i++) { |
| b(i) = 0.0f; |
| } |
| for (int i = 0; i < N; i++) { |
| ref(i) = 76.0f; |
| } |
| SimpleIREvaluator(s, {p, t}).call({a, b}); |
| ExpectAllNear(b, ref, 1e-5); |
| } |
| |
| TEST(LoopNest, ColReduceSplitTailEvenReorder) { |
| constexpr int M = 76, N = 128; |
| auto p = colReduce(M, N); |
| StmtPtr s = splitTailReorder(p.second); |
| |
| std::ostringstream oss; |
| oss << *s; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i_outer |
| # CHECK-NEXT: for (int i_inner |
| # CHECK-NEXT: b[ |
| # CHECK: for (int j |
| # CHECK-NEXT: for (int i_inner |
| # CHECK-NEXT: b[ |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| checkColReduce(s, p.first, p.second); |
| } |
| |
| TEST(LoopNest, ColReduceSplitTailUnevenReorder) { |
| constexpr int M = 76, N = 100; |
| auto p = colReduce(M, N); |
| StmtPtr s = splitTailReorder(p.second); |
| |
| std::ostringstream oss; |
| oss << *s; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i_outer |
| # CHECK-NEXT: for (int i_inner |
| # CHECK-NEXT: b[ |
| # CHECK: for (int j |
| # CHECK-NEXT: for (int i_inner |
| # CHECK-NEXT: b[ |
| # CHECK: for (int i_tail |
| # CHECK-NEXT: b[ |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: b[ |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| checkColReduce(s, p.first, p.second); |
| } |
| |
| TEST(LoopNest, ColReduceSplitMaskEvenReorder) { |
| constexpr int M = 76, N = 128; |
| auto p = colReduce(M, N); |
| StmtPtr s = splitMaskReorder(p.second); |
| checkColReduce(s, p.first, p.second); |
| } |
| |
| TEST(LoopNest, ColReduceSplitMaskUnevenReorder) { |
| constexpr int M = 76, N = 100; |
| auto p = colReduce(M, N); |
| StmtPtr s = splitMaskReorder(p.second); |
| checkColReduce(s, p.first, p.second); |
| } |
| |
| TEST(LoopNest, ReorderAxisWithMultipleConds) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // if i > 5 { |
| // if i < 10 { |
| // for (int j = 0; j < 100; j++) { |
| // A[i] = i * j; |
| // } |
| // } |
| // } |
| // } |
| BufHandle a_buf("A", {20}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i}, Mul::make(i, j))); |
| auto inner_cond = Cond::make(CompareSelect::make(i, 10, kLT), forJ, nullptr); |
| auto outer_cond = |
| Cond::make(CompareSelect::make(i, 5, kGT), inner_cond, nullptr); |
| auto forI = For::make(i, 0, 20, outer_cond); |
| StmtPtr par = Block::make({forI}); |
| LoopNest l(par, {a_buf.node()}); |
| LoopNest::reorderAxis(forI, forJ); |
| ASSERT_EQ(par, l.root_stmt()); |
| par = IRSimplifier::simplify(par); |
| |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int j |
| # CHECK-NEXT: for (int i |
| # CHECK-NEXT: if (i>5 |
| # CHECK-NEXT: if (i<10 |
| # CHECK-NEXT: A[i] = i * j |
| # CHECK-NOT: for ( |
| )IR"; |
| std::ostringstream oss; |
| oss << *par; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| } |
| |
| TEST(LoopNest, VectorizeUse) { |
| constexpr int N = 8; |
| BufHandle a("a", {N}, kFloat); |
| Tensor b = |
| Compute("b", {N}, [&](const VarHandle& n) { return a.load(n) + 1.0f; }); |
| Tensor c = |
| Compute("c", {N}, [&](const VarHandle& n) { return b.load(n) + 2.0f; }); |
| LoopNest nest({c}, {b, c}); |
| auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; |
| ASSERT_TRUE(LoopNest::vectorize(loops[0])); |
| loops = nest.getAllLoopNestsWritingToBuf(c.buf())[0]; |
| ASSERT_TRUE(LoopNest::vectorize(loops[0])); |
| nest.prepareForCodegen(); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| StmtPtr s = nest.root_stmt(); |
| std::ostringstream oss; |
| oss << *nest.root_stmt(); |
| torch::jit::testing::FileCheck().run( |
| R"IR( |
| # CHECK: c[Ramp |
| )IR", |
| oss.str()); |
| } |
| |
// FileCheck pattern shared by the Int64Direct and Int64Compute tests. Both
// must print a single int64_t loop over 12 elements whose body stores
// (a[i]) + 1ll into b[i].
// constexpr makes the pointer immutable and gives it internal linkage, so it
// cannot be reassigned or collide with a same-named global elsewhere.
constexpr const char* int64Loop = R"IR(
# CHECK: for (int64_t i = 0ll; i < 12ll; i++) {
# CHECK: b[i] = (a[i]) + 1ll;
# CHECK: }
)IR";
| |
| TEST(LoopNest, Int64Direct) { |
| constexpr int64_t N = 12; |
| BufHandle a("a", {N}, kLong); |
| BufHandle b("b", {N}, kLong); |
| VarHandle n("i", kLong); |
| StmtPtr s = For::make( |
| n, LongImm::make(0l), N, b.store({n}, a.load({n}) + LongImm::make(1l))); |
| s = IRSimplifier::simplify(s); |
| std::ostringstream oss; |
| oss << *s; |
| torch::jit::testing::FileCheck().run(int64Loop, oss.str()); |
| } |
| |
| TEST(LoopNest, Int64Compute) { |
| constexpr int64_t N = 12; |
| BufHandle a("a", {N}, kLong); |
| Tensor b = Compute("b", {N}, [&](const VarHandle& n) { |
| return a.load(n) + LongImm::make(1l); |
| }); |
| LoopNest nest({b}); |
| nest.prepareForCodegen(); |
| nest.simplify(); |
| std::ostringstream oss; |
| oss << *nest.root_stmt(); |
| torch::jit::testing::FileCheck().run(int64Loop, oss.str()); |
| } |
| |
| TEST(LoopNest, DistributeLoopWithAllStmtsAsPivots) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = 0; |
| // for (int j = 0; j < 100; j++) { |
| // A[i] = A[i] + i * j; |
| // } |
| // B[i] = A[i]; |
| // for (int k = 0; k < 50; k++) { |
| // B[i] = B[i] + i * k; |
| // } |
| // } |
| BufHandle a_buf("A", {20}, kInt); |
| BufHandle b_buf("B", {20}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto initA = Store::make(a_buf, {i}, 0); |
| auto forJ = For::make( |
| j, |
| 0, |
| 100, |
| Store::make( |
| a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); |
| auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); |
| auto forK = For::make( |
| k, |
| 0, |
| 50, |
| Store::make( |
| b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); |
| auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); |
| auto par = Block::make({forI}); |
| |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: A[i] = 0 |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[i] = |
| # CHECK: for (int i |
| # CHECK-NEXT: B[i] = A[i] |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int k |
| # CHECK-NEXT: B[i] = |
| # CHECK-NOT: for ( |
| )IR"; |
| |
| LoopNest nest(par, {a_buf.node(), b_buf.node()}); |
| auto new_loops = LoopNest::distributeLoop(forI, {initA, forJ, initB}); |
| |
| std::ostringstream oss; |
| oss << *par; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The first loop after distribution must be same as the original For. |
| ASSERT_EQ(new_loops.front(), forI); |
| } |
| |
| TEST(LoopNest, DistributeLoopWithOneStmtAsPivot) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = 0; |
| // for (int j = 0; j < 100; j++) { |
| // A[i] = A[i] + i * j; |
| // } |
| // B[i] = A[i]; |
| // for (int k = 0; k < 50; k++) { |
| // B[i] = B[i] + i * k; |
| // } |
| // } |
| BufHandle a_buf("A", {20}, kInt); |
| BufHandle b_buf("B", {20}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto initA = Store::make(a_buf, {i}, 0); |
| auto forJ = For::make( |
| j, |
| 0, |
| 100, |
| Store::make( |
| a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); |
| auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); |
| auto forK = For::make( |
| k, |
| 0, |
| 50, |
| Store::make( |
| b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); |
| auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); |
| auto par = Block::make({forI}); |
| |
| LoopNest nest(par, {a_buf.node(), b_buf.node()}); |
| auto new_loops = LoopNest::distributeLoop(forI, {forJ}); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: A[i] = 0 |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[i] = |
| # CHECK: for (int i |
| # CHECK-NEXT: B[i] = A[i] |
| # CHECK-NEXT: for (int k |
| # CHECK-NEXT: B[i] = |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The first loop after distribution must be same as the original For. |
| ASSERT_EQ(new_loops.front(), forI); |
| } |
| |
| TEST(LoopNest, DistributeLoopWithoutAnyPivot) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = 0; |
| // for (int j = 0; j < 100; j++) { |
| // A[i] = A[i] + i * j; |
| // } |
| // B[i] = A[i]; |
| // for (int k = 0; k < 50; k++) { |
| // B[i] = B[i] + i * k; |
| // } |
| // } |
| BufHandle a_buf("A", {20}, kInt); |
| BufHandle b_buf("B", {20}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto initA = Store::make(a_buf, {i}, 0); |
| auto forJ = For::make( |
| j, |
| 0, |
| 100, |
| Store::make( |
| a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); |
| auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); |
| auto forK = For::make( |
| k, |
| 0, |
| 50, |
| Store::make( |
| b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); |
| auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); |
| auto par = Block::make({forI}); |
| |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: A[i] = 0 |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[i] = |
| # CHECK: for (int i |
| # CHECK-NEXT: B[i] = A[i] |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int k |
| # CHECK-NEXT: B[i] = |
| # CHECK-NOT: for ( |
| )IR"; |
| |
| LoopNest nest(par, {a_buf.node(), b_buf.node()}); |
| auto new_loops = LoopNest::distributeLoop(forI); |
| |
| std::ostringstream oss; |
| oss << *par; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The first loop after distribution must be same as the original For. |
| ASSERT_EQ(new_loops.front(), forI); |
| } |
| |
| TEST(LoopNest, DistributeLoopOverInnerLoops) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = 0; |
| // for (int j = 0; j < 100; j++) { |
| // A[i] = A[i] + i * j; |
| // } |
| // B[i] = A[i]; |
| // for (int k = 0; k < 50; k++) { |
| // B[i] = B[i] + i * k; |
| // } |
| // } |
| BufHandle a_buf("A", {20}, kInt); |
| BufHandle b_buf("B", {20}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto initA = Store::make(a_buf, {i}, 0); |
| auto forJ = For::make( |
| j, |
| 0, |
| 100, |
| Store::make( |
| a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); |
| auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); |
| auto forK = For::make( |
| k, |
| 0, |
| 50, |
| Store::make( |
| b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); |
| auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); |
| auto par = Block::make({forI}); |
| |
| LoopNest nest(par, {a_buf.node(), b_buf.node()}); |
| auto new_loops = LoopNest::distributeLoopOverInnerLoops(forI); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: A[i] = 0 |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[i] = |
| # CHECK: for (int i |
| # CHECK-NEXT: B[i] = A[i] |
| # CHECK-NEXT: for (int k |
| # CHECK-NEXT: B[i] = |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The first loop after distribution must be same as the original For. |
| ASSERT_EQ(new_loops.front(), forI); |
| } |
| |
| TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) { |
| // Input IR: |
| // for (int m = 0; m < 50; m++) { |
| // for (int i = 0; i < 20; i++) { |
| // A[m,i] = 0; |
| // for (int j = 0; j < 100; j++) { |
| // A[m,i] = A[m,i] + i * j; |
| // } |
| // B[m,i] = A[m,i]; |
| // for (int k = 0; k < 50; k++) { |
| // B[m,i] = B[m,i] + i * k; |
| // } |
| // } |
| // } |
| BufHandle a_buf("A", {100, 100}, kInt); |
| BufHandle b_buf("B", {100, 100}, kInt); |
| VarHandle m("m", kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto initA = Store::make(a_buf, {m, i}, 0); |
| auto forJ = For::make( |
| j, |
| 0, |
| 100, |
| Store::make( |
| a_buf, |
| {m, i}, |
| Add::make(Load::make(a_buf, {m, i}), Mul::make(i, j)))); |
| auto initB = Store::make(b_buf, {m, i}, Load::make(a_buf, {m, i})); |
| auto forK = For::make( |
| k, |
| 0, |
| 50, |
| Store::make( |
| b_buf, |
| {m, i}, |
| Add::make(Load::make(b_buf, {m, i}), Mul::make(i, k)))); |
| auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); |
| |
| { |
| // Check the case of distributing loop and its parents over all the |
| // statements in the loop. |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int m |
| # CHECK-NEXT: for (int i |
| # CHECK-NEXT: A[m, i] = 0 |
| # CHECK: for (int m |
| # CHECK-NEXT: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[m, i] = |
| # CHECK: for (int m |
| # CHECK-NEXT: for (int i |
| # CHECK-NEXT: B[m, i] = A[m, i] |
| # CHECK: for (int m |
| # CHECK-NEXT: for (int i |
| # CHECK-NEXT: for (int k |
| # CHECK-NEXT: B[m, i] = |
| # CHECK-NOT: for ( |
| )IR"; |
| |
| auto newForI = to<For>(Stmt::clone(forI)); |
| auto forM = For::make(m, 0, 50, newForI); |
| auto par = Block::make({forM}); |
| LoopNest nest(par, {a_buf.node(), b_buf.node()}); |
| auto newLoops = LoopNest::distributeLoopAndParents(newForI); |
| |
| std::ostringstream oss; |
| oss << *par; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The first loop after distribution must be same as the original For. |
| ASSERT_EQ(newLoops.front(), forM); |
| } |
| |
| { |
| // Check the case of distributing loop and its parents over all the inner |
| // loops. |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int m |
| # CHECK-NEXT: for (int i |
| # CHECK-NEXT: A[m, i] = 0 |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[m, i] = |
| # CHECK: for (int m |
| # CHECK-NEXT: for (int i |
| # CHECK-NEXT: B[m, i] = A[m, i] |
| # CHECK-NEXT: for (int k |
| # CHECK-NEXT: B[m, i] = |
| # CHECK-NOT: for ( |
| )IR"; |
| |
| auto newForI = to<For>(Stmt::clone(forI)); |
| auto forM = For::make(m, 0, 50, newForI); |
| auto par = Block::make({forM}); |
| LoopNest nest(par, {a_buf.node(), b_buf.node()}); |
| auto newLoops = LoopNest::distributeLoopAndParentsOverInnerLoops(newForI); |
| |
| std::ostringstream oss; |
| oss << *par; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The first loop after distribution must be same as the original For. |
| ASSERT_EQ(newLoops.front(), forM); |
| } |
| } |
| |
| TEST(LoopNest, fuseLoopsSimple) { |
| // Input IR: |
| // for (int j = 0; j < 100; j++) { |
| // A[j] = 10 * j; |
| // } |
| // for (int k = 0; k < 100; k++) { |
| // B[k] = 20 * k; |
| // } |
| BufHandle a_buf("A", {100}, kInt); |
| BufHandle b_buf("B", {100}, kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
| auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); |
| auto par = Block::make({forJ, forK}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int j |
| # CHECK-NEXT: A[j] = |
| # CHECK-NEXT: B[j] = |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forJ); |
| } |
| |
| TEST(LoopNest, fuseLoopsMultiple) { |
| // Input IR: |
| // for (int i = 0; i < 100; i++) { |
| // A[i+100] = 20 + i; |
| // } |
| // for (int j = 0; j < 100; j++) { |
| // A[j] = 10 * j; |
| // } |
| // for (int k = 0; k < 100; k++) { |
| // B[k] = 20 * k; |
| // } |
| BufHandle a_buf("A", {200}, kInt); |
| BufHandle b_buf("B", {100}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto forI = |
| For::make(i, 0, 100, Store::make(a_buf, {i + 100}, Add::make(20, i))); |
| auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
| auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); |
| auto par = Block::make({forI, forJ, forK}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: A[i + 100] = |
| # CHECK-NEXT: A[i] = |
| # CHECK-NEXT: B[i] = |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forI); |
| } |
| |
| TEST(LoopNest, fuseLoopsNested) { |
| // Input IR: |
| // for (int m = 0; m < 20; m++) { |
| // A[m] = 0; |
| // for (int j = 0; j < 100; j++) { |
| // A[m] = A[m] + m * j; |
| // } |
| // } |
| // for (int n = 0; n < 20; n++) { |
| // B[n] = A[n]; |
| // for (int k = 0; k < 50; k++) { |
| // B[n] = B[n] + n * k; |
| // } |
| // } |
| BufHandle a_buf("A", {20, 100}, kInt); |
| BufHandle b_buf("B", {20, 100}, kInt); |
| VarHandle m("m", kInt); |
| VarHandle n("n", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto initA = Store::make(a_buf, {m}, 0); |
| auto forJ = For::make( |
| j, |
| 0, |
| 100, |
| Store::make( |
| a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); |
| auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n})); |
| auto forK = For::make( |
| k, |
| 0, |
| 50, |
| Store::make( |
| b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); |
| auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); |
| auto forN = For::make(n, 0, 20, Block::make({initB, forK})); |
| auto par = Block::make({forM, forN}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int m |
| # CHECK-NEXT: A[m] = 0 |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[m] = |
| # CHECK: B[m] = A[m] |
| # CHECK-NEXT: for (int k |
| # CHECK-NEXT: B[m] = |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forM); |
| } |
| |
| TEST(LoopNest, fuseLoopsNested2D) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 100; j++) { |
| // A[i,j] = i * j * 500; |
| // } |
| // } |
| // for (int m = 0; m < 20; m++) { |
| // for (int n = 0; n < 50; n++) { |
| // B[m,n] = m + n * 100; |
| // } |
| // } |
| BufHandle a_buf("A", {20, 100}, kInt); |
| BufHandle b_buf("B", {20, 100}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle m("m", kInt); |
| VarHandle n("n", kInt); |
| auto forI = For::make( |
| i, |
| 0, |
| 20, |
| For::make( |
| j, |
| 0, |
| 100, |
| Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); |
| auto forM = For::make( |
| m, |
| 0, |
| 20, |
| For::make( |
| n, |
| 0, |
| 50, |
| Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))))); |
| auto par = Block::make({forI, forM}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[i, j] = |
| # CHECK: for (int n |
| # CHECK-NEXT: B[i, n] = |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forI); |
| } |
| |
| TEST(LoopNest, fuseLoopsNested2DInner) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 100; j++) { |
| // A[i,j] = i * j * 500; |
| // } |
| // for (int n = 0; n < 100; n++) { |
| // B[i,n] = m + n * 100; |
| // } |
| // } |
| BufHandle a_buf("A", {20, 100}, kInt); |
| BufHandle b_buf("B", {20, 100}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle n("n", kInt); |
| auto forJ = For::make( |
| j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); |
| auto forN = For::make( |
| n, 0, 100, Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100)))); |
| auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *forI; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[i, j] = |
| # CHECK-NEXT: B[i, j] = |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forJ); |
| } |
| |
| TEST(LoopNest, fuseLoopsDifferentStopBounds) { |
| // Input IR: |
| // for (int j = 0; j < 100; j++) { |
| // A[j] = 10 * j; |
| // } |
| // for (int k = 0; k < 50; k++) { |
| // B[k] = 20 * k; |
| // } |
| BufHandle a_buf("A", {100}, kInt); |
| BufHandle b_buf("B", {100}, kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
| auto forK = For::make(k, 0, 50, Store::make(b_buf, {j}, Mul::make(20, k))); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forJ, forK}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, fuseLoopsDifferentStartBounds) { |
| // Input IR: |
| // for (int j = 0; j < 100; j++) { |
| // A[j] = 10 * j; |
| // } |
| // for (int k = 50; k < 100; k++) { |
| // B[k] = 20 * k; |
| // } |
| BufHandle a_buf("A", {100}, kInt); |
| BufHandle b_buf("B", {100}, kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
| auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forJ, forK}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, fuseLoopsNotContiguous) { |
| // Input IR: |
| // for (int j = 0; j < 100; j++) { |
| // A[j] = 10 * j; |
| // } |
| // B[0] = 0; |
| // for (int k = 0; k < 100; k++) { |
| // B[k] = 20 * k; |
| // } |
| BufHandle a_buf("A", {100}, kInt); |
| BufHandle b_buf("B", {100}, kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
| auto initB = Store::make(b_buf, {0}, 0); |
| auto forK = For::make(k, 0, 100, Store::make(b_buf, {j}, Mul::make(20, k))); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forJ, initB, forK}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, fuseLoopsWithDifferentParents) { |
| // Input IR: |
| // for (int i = 0; i < 50; i++) { |
| // for (int j = 0; j < 100; j++) { |
| // A[i,j] = i * j; |
| // } |
| // } |
| // B[0] = 0; |
| // for (int k = 50; k < 100; k++) { |
| // B[k] = 20 * k; |
| // } |
| BufHandle a_buf("A", {50, 100}, kInt); |
| BufHandle b_buf("B", {100}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(i, j))); |
| auto forI = For::make(i, 0, 50, forJ); |
| auto initB = Store::make(b_buf, {0}, 0); |
| auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forI, initB, forK}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, fuseLoopsWithVariableBounds) { |
| // Input IR: |
| // for (int j = 0; j < N; j++) { |
| // A[j] = 10 * j; |
| // } |
| // for (int k = 0; k < N; k++) { |
| // B[k] = 20 * k; |
| // } |
| BufHandle a_buf("A", {20}, kInt); |
| BufHandle b_buf("B", {20}, kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| VarHandle N("N", kInt); |
| auto forJ = For::make(j, 0, N, Store::make(a_buf, {j}, Mul::make(10, j))); |
| // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) |
| auto forK = For::make(k, 0, N, Store::make(b_buf, {j}, Mul::make(20, k))); |
| auto par = Block::make({forJ, forK}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int j |
| # CHECK-NEXT: A[j] = |
| # CHECK-NEXT: B[j] = |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forJ); |
| } |
| |
| TEST(LoopNest, fuseLoopsWithExprBounds) { |
| // Input IR: |
| // for (int j = 0; j < M + N; j++) { |
| // A[j] = 10 * j; |
| // } |
| // for (int k = 0; k < M + N; k++) { |
| // B[k] = 20 * k; |
| // } |
| BufHandle a_buf("A", {20}, kInt); |
| BufHandle b_buf("B", {20}, kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| VarHandle M("M", kInt); |
| VarHandle N("N", kInt); |
| auto forJ = For::make(j, 0, M + N, Store::make(a_buf, {j}, Mul::make(10, j))); |
| auto forK = For::make(k, 0, M + N, Store::make(b_buf, {j}, Mul::make(20, k))); |
| auto par = Block::make({forJ, forK}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int j |
| # CHECK-NEXT: A[j] = |
| # CHECK-NEXT: B[j] = |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forJ); |
| } |
| |
| TEST(LoopNest, fuseLoopsWithDifferentExprBounds) { |
| // Input IR: |
| // for (int j = M; j < N * 2; j++) { |
| // A[j] = 10 * j; |
| // } |
| // for (int k = M; k < N + N; k++) { |
| // B[k] = 20 * k; |
| // } |
| BufHandle a_buf("A", {20}, kInt); |
| BufHandle b_buf("B", {20}, kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| VarHandle M("M", kInt); |
| VarHandle N("N", kInt); |
| auto forJ = For::make(j, M, N * 2, Store::make(a_buf, {j}, Mul::make(10, j))); |
| // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) |
| auto forK = For::make(k, M, N + N, Store::make(b_buf, {j}, Mul::make(20, k))); |
| auto par = Block::make({forJ, forK}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int j |
| # CHECK-NEXT: A[j] = |
| # CHECK-NEXT: B[j] = |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forJ); |
| } |
| |
| TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) { |
| // Input IR: |
| // for (int j = 10; j < 100; j++) { |
| // A[j] = 10 * j; |
| // } |
| // for (int k = 10; k < 100; k++) { |
| // A[k+100] = 30 * k |
| // } |
| BufHandle a_buf("A", {200}, kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
| auto forK = |
| For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(30, k))); |
| auto par = Block::make({forJ, forK}); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int j |
| # CHECK-NEXT: A[j] = |
| # CHECK-NEXT: A[j + 100] = |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forJ); |
| } |
| |
| TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 100; j++) { |
| // A[i,j] = i * j * 500; |
| // } |
| // } |
| // for (int m = 0; m < 20; m++) { |
| // for (int n = 0; n < 50; n++) { |
| // A[m+20,n+100] = m + n * 100; |
| // } |
| // } |
| BufHandle a_buf("A", {20, 100}, kInt); |
| BufHandle b_buf("B", {20, 50}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle m("m", kInt); |
| VarHandle n("n", kInt); |
| auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); |
| auto forJ = For::make(j, 0, 100, storeA1); |
| auto forI = For::make(i, 0, 20, forJ); |
| auto storeA2 = |
| Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); |
| auto forN = For::make(n, 0, 50, storeA2); |
| auto forM = For::make(m, 0, 20, forN); |
| auto par = Block::make({forI, forM}); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[i, j] = |
| # CHECK: for (int n |
| # CHECK-NEXT: A[i + 20, n + 100] = |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forI); |
| } |
| |
| TEST(LoopNest, fuseLoopsWithReductions) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // A[i] = 0 |
| // for (int j = 0; j < 100; j++) { |
| // A[i] = A[i] + B[i,j]; |
| // } |
| // } |
| // for (int m = 0; m < 20; m++) { |
| // C[m] = A[m]; |
| // } |
| BufHandle a_buf("A", {20}, kInt); |
| BufHandle b_buf("B", {20, 100}, kInt); |
| BufHandle c_buf("C", {20}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle m("m", kInt); |
| auto initA = Store::make(a_buf, {i}, 0); |
| auto sumA = Store::make( |
| a_buf, {i}, Add::make(Load::make(a_buf, {i}), Load::make(b_buf, {i, j}))); |
| auto forJ = For::make(j, 0, 100, sumA); |
| auto forI = For::make(i, 0, 20, Block::make({initA, forJ})); |
| auto forM = |
| For::make(m, 0, 20, Store::make(c_buf, {m}, Load::make(a_buf, {m}))); |
| auto par = Block::make({forI, forM}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: A[i] = |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[i] = (A[i]) + |
| # CHECK-NOT: for ( |
| # CHECK: C[i] = A[i] |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forI); |
| } |
| |
| TEST(LoopNest, fuseLoopsWith2DReductions) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 50; j++) { |
| // A[i,j] = 0 |
| // for (int k = 0; k < 100; k++) { |
| // A[i,j] = A[i,j] + B[i,j,k]; |
| // } |
| // } |
| // } |
| // for (int m = 0; m < 20; m++) { |
| // for (int n = 0; n < 40; n++) { |
| // C[m,n] = A[m,n]; |
| // } |
| // } |
| BufHandle a_buf("A", {20, 50}, kInt); |
| BufHandle b_buf("B", {20, 50, 100}, kInt); |
| BufHandle c_buf("C", {20, 40}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| VarHandle m("m", kInt); |
| VarHandle n("n", kInt); |
| auto initA = Store::make(a_buf, {i, j}, 0); |
| auto sumA = Store::make( |
| a_buf, |
| {i, j}, |
| Add::make(Load::make(a_buf, {i, j}), Load::make(b_buf, {i, j, k}))); |
| auto forK = For::make(k, 0, 100, sumA); |
| auto forJ = For::make(j, 0, 50, Block::make({initA, forK})); |
| auto forI = For::make(i, 0, 20, forJ); |
| auto storeC = Store::make(c_buf, {m, n}, Load::make(a_buf, {m, n})); |
| auto forM = For::make(m, 0, 20, For::make(n, 0, 40, storeC)); |
| auto par = Block::make({forI, forM}); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[i, j] = |
| # CHECK-NEXT: for (int k |
| # CHECK-NEXT: A[i, j] = (A[i, j]) + |
| # CHECK: for (int n |
| # CHECK-NEXT: C[i, n] = A[i, n] |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forI); |
| } |
| |
| TEST(LoopNest, fuseLoopsWithComplexIndices) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 20; j++) { |
| // A[i,j*20+j+2] = i + j; |
| // } |
| // } |
| // for (int m = 0; m < 20; m++) { |
| // for (int n = 0; n < 20; n++) { |
| // B[m,n] = A[m,n*20+n+2]; |
| // } |
| // } |
| BufHandle a_buf("A", {20, 400}, kInt); |
| BufHandle b_buf("B", {20, 400}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle m("m", kInt); |
| VarHandle n("n", kInt); |
| auto writeA = Store::make(a_buf, {i, j * 20 + j + 2}, i + j); |
| auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); |
| auto storeB = |
| Store::make(b_buf, {m, n}, Load::make(a_buf, {m, n * 20 + n + 2})); |
| auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); |
| auto par = Block::make({forI, forM}); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[i, (j * 20 + j) + 2] = i + j |
| # CHECK: for (int n |
| # CHECK-NEXT: B[i, n] = A[i, (n * 20 + n) + 2] |
| # CHECK-NOT: for ( |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| // The fused loop must be the same as the first loop. |
| ASSERT_EQ(fused_loop, forI); |
| } |
| |
| TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 20; j++) { |
| // A[i,i*20+j] = i + j; |
| // } |
| // } |
| // for (int m = 0; m < 20; m++) { |
| // for (int n = 0; n < 20; n++) { |
| // B[m,n] = A[m,m*20+n]; // Both indices of A use m |
| // } |
| // } |
| BufHandle a_buf("A", {20, 500}, kInt); |
| BufHandle b_buf("B", {20, 500}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle m("m", kInt); |
| VarHandle n("n", kInt); |
| auto writeA = Store::make(a_buf, {i, i * 20 + j}, i + j); |
| auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); |
| auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {m, m * 20 + n})); |
| auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); |
| auto par = Block::make({forI, forM}); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, fuseLoopsWithTranspose) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 20; j++) { |
| // A[i,j] = i + j; |
| // } |
| // } |
| // for (int m = 0; m < 20; m++) { |
| // for (int n = 0; n < 20; n++) { |
| // B[m,n] = A[n,m]; // Transpose |
| // } |
| // } |
| BufHandle a_buf("A", {20, 20}, kInt); |
| BufHandle b_buf("B", {20, 20}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle m("m", kInt); |
| VarHandle n("n", kInt); |
| auto writeA = Store::make(a_buf, {i, j}, i + j); |
| auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); |
| auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {n, m})); |
| auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); |
| auto par = Block::make({forI, forM}); |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, fuseLoopsThatViolateDependencies1) { |
| // Input IR: |
| // for (int j = 10; j < 100; j++) { |
| // A[j] = 10 * j; |
| // } |
| // for (int k = 10; k < 100; k++) { |
| // A[k-1] = 20 * k; |
| // } |
| BufHandle a_buf("A", {100}, kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
| auto forK = |
| For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forJ, forK}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, fuseLoopsThatViolateDependencies2) { |
| // Input IR: |
| // for (int j = 10; j < 100; j++) { |
| // A[j] = 10 * j; |
| // } |
| // for (int k = 10; k < 100; k++) { |
| // A[k+50] = 20 * k; |
| // } |
| BufHandle a_buf("A", {150}, kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
| auto forK = |
| For::make(k, 10, 100, Store::make(a_buf, {k + 50}, Mul::make(20, k))); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forJ, forK}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, fuseLoopsThatViolateDependencies3) { |
| // Input IR: |
| // for (int m = 0; m < 20; m++) { |
| // A[m] = 0; |
| // for (int j = 0; j < 100; j++) { |
| // A[m] = A[m] + m * j; |
| // } |
| // } |
| // for (int n = 0; n < 20; n++) { |
| // B[n] = A[n+1]; |
| // for (int k = 0; k < 50; k++) { |
| // B[n] = B[n] + n * k; |
| // } |
| // } |
| BufHandle a_buf("A", {25, 100}, kInt); |
| BufHandle b_buf("B", {20, 50}, kInt); |
| VarHandle m("m", kInt); |
| VarHandle n("n", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto initA = Store::make(a_buf, {m}, 0); |
| auto forJ = For::make( |
| j, |
| 0, |
| 100, |
| Store::make( |
| a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); |
| auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n + 1})); |
| auto forK = For::make( |
| k, |
| 0, |
| 50, |
| Store::make( |
| b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); |
| auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); |
| auto forN = For::make(n, 0, 20, Block::make({initB, forK})); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forM, forN}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, fuseLoopsThatViolateDependencies4) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 100; j++) { |
| // A[i,j] = i * j * 500; |
| // } |
| // } |
| // for (int m = 0; m < 20; m++) { |
| // for (int n = 0; n < 50; n++) { |
| // A[m+1,n] = m + n * 100; |
| // } |
| // } |
| BufHandle a_buf("A", {30, 100}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle m("m", kInt); |
| VarHandle n("n", kInt); |
| auto forI = For::make( |
| i, |
| 0, |
| 20, |
| For::make( |
| j, |
| 0, |
| 100, |
| Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); |
| auto forM = For::make( |
| m, |
| 0, |
| 20, |
| For::make( |
| n, |
| 0, |
| 50, |
| Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))))); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forI, forM}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, fuseLoopsThatViolateDependencies5) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 100; j++) { |
| // A[i,j] = i * j * 500; |
| // } |
| // for (int n = 0; n < 100; n++) { |
| // A[i,n+1] = m + n * 100; |
| // } |
| // } |
| BufHandle a_buf("A", {20, 200}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle n("n", kInt); |
| auto forJ = For::make( |
| j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); |
| auto forN = For::make( |
| n, |
| 0, |
| 100, |
| Store::make(a_buf, {i, n + 1}, Add::make(i, Mul::make(n, 100)))); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers) |
| auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, fuseLoopsThatViolateDependencies6) { |
| // Input IR: |
| // for (int j = 0; j < 100; j++) { |
| // A[j] = 10 * j; |
| // } |
| // for (int k = 0; k < 100; k++) { |
| // B[k] = 20 * A[99-k]; |
| // } |
| BufHandle a_buf("A", {100}, kInt); |
| BufHandle b_buf("B", {100}, kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
| auto forK = For::make( |
| k, |
| 0, |
| 100, |
| Store::make( |
| b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forJ, forK}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, fuseLoopsThatViolateDependencies7) { |
| // Input IR: |
| // for (int k = 0; k < 100; k++) { |
| // B[k] = 20 * A[99-k]; |
| // } |
| // for (int j = 0; j < 100; j++) { |
| // A[j] = 10 * j; |
| // } |
| BufHandle a_buf("A", {100}, kInt); |
| BufHandle b_buf("B", {100}, kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto forK = For::make( |
| k, |
| 0, |
| 100, |
| Store::make( |
| b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); |
| auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forK, forJ}); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| ForPtr fused_loop; |
| ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused_loop)); |
| } |
| |
| TEST(LoopNest, areLoopsPerfectlyNested) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 30; j++) { |
| // for (int k = 0; k < 40; k++) { |
| // A[i,j,k] = i * j * k; |
| // } |
| // } |
| // } |
| BufHandle a_buf("A", {20, 30, 40}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); |
| auto forK = For::make(k, 0, 40, store); |
| auto forJ = For::make(j, 0, 30, forK); |
| auto forI = For::make(i, 0, 20, forJ); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forI}); |
| ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); |
| |
| // Specifying the loops in any other order fails. |
| ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forJ, forI, forK})); |
| ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forK, forJ})); |
| ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forK, forJ, forI})); |
| |
| // Adding a statement to forK body should be OK. |
| auto init = Store::make(a_buf, {i, j}, 0); |
| forK->body()->insert_stmt_before(init, store); |
| ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); |
| |
| // Adding a statement in forJ body should fail this test. |
| forK->body()->remove_stmt(init); |
| forJ->body()->insert_stmt_before(init, forK); |
| ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); |
| |
| // Similarly, adding a statement in forI body should fail this test. |
| forJ->body()->remove_stmt(init); |
| forI->body()->insert_stmt_before(init, forJ); |
| ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); |
| } |
| |
| TEST(LoopNest, reorderNestedLoops2D) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 30; j++) { |
| // A[i,j] = i * j; |
| // } |
| // } |
| BufHandle a_buf("A", {20, 30, 40}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto store = Store::make(a_buf, {i, j}, Mul::make(i, j)); |
| auto forJ = For::make(j, 0, 30, store); |
| auto forI = For::make(i, 0, 20, forJ); |
| auto par = Block::make({forI}); |
| |
| auto reordered = LoopNest::reorder({forI, forJ}, {1, 0}); |
| |
| ASSERT_EQ(reordered[0], forJ); |
| ASSERT_EQ(reordered[1], forI); |
| ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forJ, forI})); |
| ASSERT_EQ(forJ->get_parent(), par); |
| ASSERT_EQ(store->get_parent(), forI->body()); |
| } |
| |
| TEST(LoopNest, reorderNestedLoops3D) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 30; j++) { |
| // for (int k = 0; k < 40; k++) { |
| // A[i,j,k] = i * j * k; |
| // } |
| // } |
| // } |
| BufHandle a_buf("A", {20, 30, 40}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); |
| auto forK = For::make(k, 0, 40, store); |
| auto forJ = For::make(j, 0, 30, forK); |
| auto forI = For::make(i, 0, 20, forJ); |
| auto par = Block::make({forI}); |
| |
| auto reordered = LoopNest::reorder({forI, forJ, forK}, {2, 0, 1}); |
| |
| ASSERT_EQ(reordered[0], forK); |
| ASSERT_EQ(reordered[1], forI); |
| ASSERT_EQ(reordered[2], forJ); |
| ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forJ})); |
| ASSERT_EQ(forK->get_parent(), par); |
| ASSERT_EQ(store->get_parent(), forJ->body()); |
| } |
| |
| TEST(LoopNest, reorderNestedLoops4D) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 30; j++) { |
| // for (int k = 0; k < 40; k++) { |
| // for (int l = 0; l < 50; l++) { |
| // A[i,j,k,l] = i * j * k * l * 500; |
| // } |
| // } |
| // } |
| // } |
| BufHandle a_buf("A", {20, 30, 40, 50}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| VarHandle l("l", kInt); |
| auto store = Store::make( |
| a_buf, |
| {i, j, k, l}, |
| Mul::make(Mul::make(Mul::make(Mul::make(i, j), k), l), 500)); |
| auto forL = For::make(l, 0, 50, store); |
| auto forK = For::make(k, 0, 40, forL); |
| auto forJ = For::make(j, 0, 30, forK); |
| auto forI = For::make(i, 0, 20, forJ); |
| auto par = Block::make({forI}); |
| |
| auto reordered = LoopNest::reorder({forI, forJ, forK, forL}, {2, 0, 3, 1}); |
| |
| ASSERT_EQ(reordered[0], forK); |
| ASSERT_EQ(reordered[1], forI); |
| ASSERT_EQ(reordered[2], forL); |
| ASSERT_EQ(reordered[3], forJ); |
| ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forL, forJ})); |
| ASSERT_EQ(forK->get_parent(), par); |
| ASSERT_EQ(store->get_parent(), forJ->body()); |
| } |
| |
| TEST(LoopNest, reorderTrivialPermutation) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 30; j++) { |
| // for (int k = 0; k < 40; k++) { |
| // A[i,j,k] = i * j * k; |
| // } |
| // } |
| // } |
| BufHandle a_buf("A", {20, 30, 40}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); |
| auto forK = For::make(k, 0, 40, store); |
| auto forJ = For::make(j, 0, 30, forK); |
| auto forI = For::make(i, 0, 20, forJ); |
| auto par = Block::make({forI}); |
| |
| auto reordered = LoopNest::reorder({forI, forJ, forK}, {0, 1, 2}); |
| |
| ASSERT_EQ(reordered[0], forI); |
| ASSERT_EQ(reordered[1], forJ); |
| ASSERT_EQ(reordered[2], forK); |
| ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); |
| ASSERT_EQ(forI->get_parent(), par); |
| ASSERT_EQ(store->get_parent(), forK->body()); |
| } |
| |
| TEST(LoopNest, reorderInvalidPermutations) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 30; j++) { |
| // for (int k = 0; k < 40; k++) { |
| // A[i,j,k] = i * j * k; |
| // } |
| // } |
| // } |
| BufHandle a_buf("A", {20, 30, 40}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); |
| auto forK = For::make(k, 0, 40, store); |
| auto forJ = For::make(j, 0, 30, forK); |
| auto forI = For::make(i, 0, 20, forJ); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forI}); |
| |
| ASSERT_THROWS_WITH( |
| LoopNest::reorder({forI, forJ, forK}, {0, 1, 2, 3}), |
| "invalid permutation size"); |
| ASSERT_THROWS_WITH( |
| LoopNest::reorder({forI, forJ, forK}, {1, 2}), |
| "invalid permutation size"); |
| ASSERT_THROWS_WITH( |
| LoopNest::reorder({forI, forJ, forK}, {2, 1, 3}), |
| "invalid permutation for reorder"); |
| ASSERT_THROWS_WITH( |
| LoopNest::reorder({forI, forJ, forK}, {1, 1, 0}), |
| "invalid permutation for reorder"); |
| ASSERT_THROWS_WITH( |
| LoopNest::reorder({forI, forJ, forK}, {0, 0, 0}), |
| "invalid permutation for reorder"); |
| } |
| |
| TEST(LoopNest, reorderInvalidLoopNest) { |
| // Input IR: |
| // for (int i = 0; i < 20; i++) { |
| // for (int j = 0; j < 30; j++) { |
| // A[i,j] = 0 |
| // for (int k = 0; k < 40; k++) { |
| // A[i,j,k] = i * j * k; |
| // } |
| // } |
| // } |
| BufHandle a_buf("A", {20, 30, 40}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); |
| auto forK = For::make(k, 0, 40, store); |
| auto forJ = For::make(j, 0, 30, forK); |
| auto forI = For::make(i, 0, 20, forJ); |
| // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) |
| auto par = Block::make({forI}); |
| |
| // Specifying the loops in incorrect order fails. |
| ASSERT_THROWS_WITH( |
| LoopNest::reorder({forK, forI, forJ}, {1, 0, 2}), |
| "reorder is only allowed on perfectly nested loops"); |
| |
| // Adding a statement to forJ loop fails. |
| auto init = Store::make(a_buf, {i}, 0); |
| forJ->body()->insert_stmt_before(init, forK); |
| ASSERT_THROWS_WITH( |
| LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), |
| "reorder is only allowed on perfectly nested loops"); |
| |
| // Moving that statement to forI loop also fails. |
| forJ->body()->remove_stmt(init); |
| forI->body()->insert_stmt_before(init, forJ); |
| ASSERT_THROWS_WITH( |
| LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), |
| "reorder is only allowed on perfectly nested loops"); |
| } |
| |
| TEST(LoopNest, compressBufferSimple) { |
| // Input IR: |
| // for (int i = 0; i < 100; ++i) { |
| // for (int j = 0; j < 200; ++j) { |
| // A[i,j] = sin(i*j) |
| // } |
| // for (int j = 0; j < 199; ++j) { |
| // B[i,j] = A[i,j] + A[i, j+1] |
| // } |
| // } |
| BufHandle aBuf("A", {100, 200}, kInt); |
| BufHandle bBuf("B", {100, 200}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); |
| auto forJ2 = For::make( |
| j, |
| 0, |
| 199, |
| Store::make( |
| bBuf, |
| {i, j}, |
| Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); |
| auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); |
| auto par = Block::make({forI}); |
| LoopNest::compressBuffer(aBuf.node(), par); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[0, j] = |
| # CHECK: for (int j |
| # CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| ASSERT_EQ(aBuf.node()->ndim(), 2); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); |
| } |
| |
| TEST(LoopNest, compressBufferMultipleDims) { |
| // Input IR: |
| // for (int i = 0; i < 100; ++i) { |
| // for (int j = 0; j < 200; ++j) { |
| // A[i,j] = sin(i*j) |
| // B[i,j] = A[i,j] + A[i,j] |
| // } |
| // } |
| BufHandle aBuf("A", {100, 200}, kInt); |
| BufHandle bBuf("B", {100, 200}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto store1 = Store::make(aBuf, {i, j}, sin(i * j)); |
| auto store2 = Store::make( |
| bBuf, |
| {i, j}, |
| Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j}))); |
| auto forJ = For::make(j, 0, 200, Block::make({store1, store2})); |
| auto forI = For::make(i, 0, 100, forJ); |
| auto par = Block::make({forI}); |
| LoopNest::compressBuffer(aBuf.node(), par); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[0, 0] = |
| # CHECK-NEXT: B[i, j] = (A[0, 0]) + (A[0, 0]) |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| ASSERT_EQ(aBuf.node()->ndim(), 2); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); |
| } |
| |
| TEST(LoopNest, compressBufferMultipleDims2) { |
| // Input IR: |
| // for (int i = 0; i < 100; ++i) { |
| // for (int j = 0; j < 200; ++j) { |
| // for (int k = 0; k < 300; ++k) { |
| // A[i,j,k] = sin(i*j*k) |
| // } |
| // for (int k = 0; k < 299; ++j) { |
| // B[i,j,k] = A[i,j,k] + A[i,j,k+1] |
| // } |
| // } |
| // } |
| BufHandle aBuf("A", {100, 200, 300}, kInt); |
| BufHandle bBuf("B", {100, 200, 300}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| auto store1 = Store::make(aBuf, {i, j, k}, sin(i * j * k)); |
| auto forK1 = For::make(k, 0, 300, store1); |
| auto store2 = Store::make( |
| bBuf, |
| {i, j, k}, |
| Add::make(Load::make(aBuf, {i, j, k}), Load::make(aBuf, {i, j, k + 1}))); |
| auto forK2 = For::make(k, 0, 299, store2); |
| auto forJ = For::make(j, 0, 200, Block::make({forK1, forK2})); |
| auto forI = For::make(i, 0, 100, forJ); |
| auto par = Block::make({forI}); |
| LoopNest::compressBuffer(aBuf.node(), par); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: for (int k |
| # CHECK-NEXT: A[0, 0, k] = |
| # CHECK: for (int k |
| # CHECK-NEXT: B[i, j, k] = (A[0, 0, k]) + (A[0, 0, k + 1]) |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| ASSERT_EQ(aBuf.node()->ndim(), 3); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(2), 300); |
| } |
| |
| TEST(LoopNest, compressBufferDifferentOrderIndices) { |
| // Input IR: |
| // for (int i = 0; i < 100; ++i) { |
| // for (int j = 0; j < 200; ++j) { |
| // A[j, i] = sin(i*j) |
| // } |
| // for (int j = 0; j < 99; ++j) { |
| // B[i, j] = A[j, i] + A[j+1, 0] |
| // } |
| // } |
| BufHandle aBuf("A", {100, 200}, kInt); |
| BufHandle bBuf("B", {100, 200}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {j, i}, sin(i * j))); |
| auto forJ2 = For::make( |
| j, |
| 0, |
| 99, |
| Store::make( |
| bBuf, |
| {i, j}, |
| Add::make(Load::make(aBuf, {j, i}), Load::make(aBuf, {j + 1, i})))); |
| auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); |
| auto par = Block::make({forI}); |
| LoopNest::compressBuffer(aBuf.node(), par); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[j, 0] = |
| # CHECK: for (int j |
| # CHECK-NEXT: B[i, j] = (A[j, 0]) + (A[j + 1, 0]) |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| ASSERT_EQ(aBuf.node()->ndim(), 2); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); |
| } |
| |
| TEST(LoopNest, compressBufferVariableBounds) { |
| // Input IR: |
| // for (int i = 0; i < M; ++i) { |
| // for (int j = 0; j < N; ++j) { |
| // A[i,j] = sin(i*j) |
| // } |
| // for (int j = 0; j < N-1; ++j) { |
| // B[i,j] = A[i,j] + A[i, j+1] |
| // } |
| // } |
| BufHandle aBuf("A", {100, 200}, kInt); |
| BufHandle bBuf("B", {100, 200}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle M("M", kInt); |
| VarHandle N("N", kInt); |
| auto forJ1 = For::make(j, 0, N, Store::make(aBuf, {i, j}, sin(i * j))); |
| auto forJ2 = For::make( |
| j, |
| 0, |
| N - 1, |
| Store::make( |
| bBuf, |
| {i, j}, |
| Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); |
| // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
| auto forI = For::make(i, 0, M, Block::make({forJ1, forJ2})); |
| auto par = Block::make({forI}); |
| LoopNest::compressBuffer(aBuf.node(), par); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[0, j] = |
| # CHECK: for (int j |
| # CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| ASSERT_EQ(aBuf.node()->ndim(), 2); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); |
| } |
| |
| TEST(LoopNest, compressBufferNoCommonParentLoops) { |
| // Input IR: |
| // for (int i = 0; i < 100; ++i) { |
| // for (int j = 0; j < 200; ++j) { |
| // A[i,j] = sin(i*j) |
| // } |
| // } |
| // for (int i = 0; i < 100; ++i) { |
| // for (int j = 0; j < 199; ++j) { |
| // B[i,j] = A[i,j] + A[i, j+1] |
| // } |
| // } |
| BufHandle aBuf("A", {100, 200}, kInt); |
| BufHandle bBuf("B", {100, 200}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); |
| auto forJ2 = For::make( |
| j, |
| 0, |
| 199, |
| Store::make( |
| bBuf, |
| {i, j}, |
| Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); |
| auto forI1 = For::make(i, 0, 100, forJ1); |
| auto forI2 = For::make(i, 0, 100, forJ2); |
| auto par = Block::make({forI1, forI2}); |
| LoopNest::compressBuffer(aBuf.node(), par); |
| |
| // There should be no change in the buffer or code. |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[i, j] = |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: B[i, j] = (A[i, j]) + (A[i, j + 1]) |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| ASSERT_EQ(aBuf.node()->ndim(), 2); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); |
| } |
| |
| TEST(LoopNest, compressBufferIndicesMixed) { |
| // Input IR: |
| // for (int i = 0; i < 100; ++i) { |
| // for (int j = 0; j < 200; ++j) { |
| // A[i + j, j] = sin(i*j) |
| // } |
| // for (int j = 0; j < 199; ++j) { |
| // B[i,j] = A[i + j, j] + A[i + j, j+1] |
| // } |
| // } |
| BufHandle aBuf("A", {300, 200}, kInt); |
| BufHandle bBuf("B", {100, 200}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i + j, j}, sin(i * j))); |
| auto forJ2 = For::make( |
| j, |
| 0, |
| 199, |
| Store::make( |
| bBuf, |
| {i, j}, |
| Add::make( |
| Load::make(aBuf, {i + j, j}), Load::make(aBuf, {i + j, j + 1})))); |
| auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); |
| auto par = Block::make({forI}); |
| LoopNest::compressBuffer(aBuf.node(), par); |
| |
| // There should be no change in the buffer or code. |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[i + j, j] = |
| # CHECK: for (int j |
| # CHECK-NEXT: B[i, j] = (A[i + j, j]) + (A[i + j, j + 1]) |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| ASSERT_EQ(aBuf.node()->ndim(), 2); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 300); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); |
| } |
| |
| TEST(LoopNest, compressMultipleBuffers) { |
| // Input IR: |
| // for (int i = 0; i < 100; ++i) { |
| // for (int j = 0; j < 200; ++j) { |
| // A[i,j] = sin(i*j) |
| // } |
| // for (int k = 0; k < 199; ++k) { |
| // B[i,k] = A[i,k] + A[i, k+1] |
| // } |
| // for (int m = 0; m < 50; ++m) { |
| // C[i,m] = B[i,m] |
| // } |
| // } |
| BufHandle aBuf("A", {100, 200}, kInt); |
| BufHandle bBuf("B", {100, 200}, kInt); |
| BufHandle cBuf("C", {100, 200}, kInt); |
| VarHandle i("i", kInt); |
| VarHandle j("j", kInt); |
| VarHandle k("k", kInt); |
| VarHandle m("m", kInt); |
| auto forJ = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); |
| auto forK = For::make( |
| k, |
| 0, |
| 199, |
| Store::make( |
| bBuf, |
| {i, k}, |
| Add::make(Load::make(aBuf, {i, k}), Load::make(aBuf, {i, k + 1})))); |
| auto forM = |
| For::make(m, 0, 50, Store::make(cBuf, {i, m}, Load::make(bBuf, {i, m}))); |
| auto forI = For::make(i, 0, 100, Block::make({forJ, forK, forM})); |
| auto par = Block::make({forI}); |
| |
| // This should compress all buffers A, B, and C as follows: |
| // A[100, 200] -> A[1, 200] |
| // B[100, 200] -> B[1, 200] |
| // C[100, 200] -> C[1, 1] |
| LoopNest::compressAllBuffers(par); |
| |
| std::ostringstream oss; |
| oss << *par; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i |
| # CHECK-NEXT: for (int j |
| # CHECK-NEXT: A[0, j] = |
| # CHECK: for (int k |
| # CHECK-NEXT: B[0, k] = (A[0, k]) + (A[0, k + 1]) |
| # CHECK: for (int m |
| # CHECK-NEXT: C[0, 0] = B[0, m] |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| |
| ASSERT_EQ(aBuf.node()->ndim(), 2); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); |
| IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); |
| ASSERT_EQ(bBuf.node()->ndim(), 2); |
| IS_IMM_WITH_VAL(Int, bBuf.node()->dim(0), 1); |
| IS_IMM_WITH_VAL(Int, bBuf.node()->dim(1), 200); |
| ASSERT_EQ(cBuf.node()->ndim(), 2); |
| IS_IMM_WITH_VAL(Int, cBuf.node()->dim(0), 1); |
| IS_IMM_WITH_VAL(Int, cBuf.node()->dim(1), 1); |
| } |
| |
| TEST(LoopNest, sanitizeNames) { |
| std::vector<ExprHandle> dim_args; |
| // Let's pick names that would overlap with default index names if not |
| // sanitized properly: |
| dim_args.emplace_back(ExprHandle(alloc<Var>("i", kInt))); |
| dim_args.emplace_back(ExprHandle(alloc<Var>("N:2", kInt))); |
| // Now let's create a many dimensions so that we had to use the same letter |
| // for different loops |
| for (int i = 0; i < 10; i++) { |
| dim_args.emplace_back(ExprHandle(alloc<Var>("N", kInt))); |
| } |
| |
| // Now create two Computes with conflicting after sanitization names: |
| Tensor X = Compute("$X:!", dim_args, [&](const std::vector<VarHandle>& v) { |
| return v[0] + v[1] + v[9] + 1; |
| }); |
| Tensor Y = Reduce( |
| "%X\"+", |
| {}, |
| Sum(), |
| [&](const std::vector<VarHandle>& v) { return X.load(v); }, |
| dim_args); |
| |
| // Finally, let's verify what we got after sanitization: |
| LoopNest l({X, Y}); |
| StmtPtr s = l.root_stmt(); |
| LoopNest::sanitizeNames(s); |
| |
| std::ostringstream oss; |
| oss << *s; |
| const std::string& verification_pattern = |
| R"IR( |
| # CHECK: for (int i = 0; i < i_1; i++) { |
| # CHECK-NEXT: for (int j = 0; j < N_2_1; j++) { |
| # CHECK-NEXT: for (int k = 0; k < N_9; k++) { |
| # CHECK-NEXT: for (int l = 0; l < N_8; l++) { |
| # CHECK-NEXT: for (int m = 0; m < N_7; m++) { |
| # CHECK-NEXT: for (int n = 0; n < N_6; n++) { |
| # CHECK-NEXT: for (int o = 0; o < N_5; o++) { |
| # CHECK-NEXT: for (int p = 0; p < N_4; p++) { |
| # CHECK-NEXT: for (int i1 = 0; i1 < N_3; i1++) { |
| # CHECK-NEXT: for (int j1 = 0; j1 < N_2; j1++) { |
| # CHECK-NEXT: for (int k1 = 0; k1 < N_1; k1++) { |
| # CHECK-NEXT: for (int l1 = 0; l1 < N; l1++) { |
| # CHECK-NEXT: v_X__[i, j, k, l, m, n, o, p, i1, j1, k1, l1] = ((i + j) + j1) + 1; |
| # CHECK: v_X___1 = int(0); |
| # CHECK-NEXT: for (int i_2 = 0; i_2 < i_1; i_2++) { |
| # CHECK-NEXT: for (int j_1 = 0; j_1 < N_2_1; j_1++) { |
| # CHECK-NEXT: for (int k_1 = 0; k_1 < N_9; k_1++) { |
| # CHECK-NEXT: for (int l_1 = 0; l_1 < N_8; l_1++) { |
| # CHECK-NEXT: for (int m_1 = 0; m_1 < N_7; m_1++) { |
| # CHECK-NEXT: for (int n_1 = 0; n_1 < N_6; n_1++) { |
| # CHECK-NEXT: for (int o_1 = 0; o_1 < N_5; o_1++) { |
| # CHECK-NEXT: for (int p_1 = 0; p_1 < N_4; p_1++) { |
| # CHECK-NEXT: for (int i1_1 = 0; i1_1 < N_3; i1_1++) { |
| # CHECK-NEXT: for (int j1_1 = 0; j1_1 < N_2; j1_1++) { |
| # CHECK-NEXT: for (int k1_1 = 0; k1_1 < N_1; k1_1++) { |
| # CHECK-NEXT: for (int l1_1 = 0; l1_1 < N; l1_1++) { |
| # CHECK-NEXT: v_X___1 = ReduceOp((v_X___1) + (v_X__[i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1]), reduce_args={i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1}); |
| )IR"; |
| torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); |
| } |
| |
| } // namespace jit |
| } // namespace torch |