| #include <gtest/gtest.h> |
| #include <test/cpp/tensorexpr/test_base.h> |
| |
| #include <torch/csrc/jit/tensorexpr/bounds_overlap.h> |
| #include <torch/csrc/jit/tensorexpr/ir.h> |
| #include <torch/csrc/jit/tensorexpr/ir_printer.h> |
| #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> |
| #include <torch/csrc/jit/tensorexpr/loopnest.h> |
| #include <torch/csrc/jit/tensorexpr/mem_dependency_checker.h> |
| #include <torch/csrc/jit/tensorexpr/tensor.h> |
| |
| namespace torch { |
| namespace jit { |
| |
| using namespace torch::jit::tensorexpr; |
| |
// Tests the helper function used to determine if two regions of a buffer have
// an overlap. No Overlap & partial overlap are obvious. Contains means A is
// larger and fully encloses B, while ContainedOrEqual is the reverse. Equal
// ranges are ContainedOrEqual.
| TEST(MemDependency, BoundOverlap) { |
| using namespace analysis; |
| |
| auto CB = [](int s, int e) { |
| return Bound(alloc<IntImm>(s), alloc<IntImm>(e)); |
| }; |
| |
| // Sanity check 3 overlap cases. |
| ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0))); |
| ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 3), CB(2, 5))); |
| ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 0), CB(1, 1))); |
| |
| // Partial overlap works in either order. |
| ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 10), CB(7, 14))); |
| ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(7, 14), CB(0, 10))); |
| |
| // Total Overlap works when one bound encloses the other, and returns which. |
| ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(7, 9))); |
| ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(0, 16))); |
| |
| // Total overlap works when the bounds are an identical range, returns |
| // ContainedOrEqual. |
| ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(2, 15))); |
| |
| // Total overlap when only one end of the bound matches. |
| ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 10))); |
| ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(3, 15))); |
| ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(0, 10), CB(0, 9))); |
| ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 10), CB(2, 15))); |
| ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(3, 15), CB(2, 15))); |
| |
| // No overlap when a < b. |
| ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 2), CB(5, 10))); |
| ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 2), CB(3, 3))); |
| ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(100, 120), CB(130, 130))); |
| |
| // No overlap when a > b. |
| ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(5, 10), CB(0, 2))); |
| ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(3, 3), CB(2, 2))); |
| ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(130, 130), CB(100, 120))); |
| |
| // No overlap when adjacent. |
| ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 100), CB(101, 120))); |
| ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 3), CB(0, 1))); |
| |
| // Partial overlap when middle bounds match. |
| ASSERT_EQ( |
| OverlapKind::PartialOverlap, boundOverlap(CB(0, 100), CB(100, 120))); |
| ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 2), CB(2, 4))); |
| ASSERT_EQ( |
| OverlapKind::PartialOverlap, boundOverlap(CB(100, 120), CB(0, 100))); |
| ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(2, 3), CB(1, 2))); |
| |
// Total overlap when one bound is a single element at one end of the other.
| ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(15, 15))); |
| ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 2))); |
| ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 2), CB(2, 15))); |
| ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(15, 15), CB(2, 15))); |
| } |
| |
| TEST(MemDependency, BoundComparison) { |
| using namespace analysis; |
| |
| auto CB = [](int s, int e) { |
| return Bound(alloc<IntImm>(s), alloc<IntImm>(e)); |
| }; |
| |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kEQ)); |
| ASSERT_EQ( |
| CmpEvalResult::True, |
| compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kEQ)); |
| ASSERT_EQ( |
| CmpEvalResult::False, |
| compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kEQ)); |
| ASSERT_EQ( |
| CmpEvalResult::False, |
| compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kEQ)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kEQ)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kEQ)); |
| |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kNE)); |
| ASSERT_EQ( |
| CmpEvalResult::False, |
| compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kNE)); |
| ASSERT_EQ( |
| CmpEvalResult::True, |
| compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kNE)); |
| ASSERT_EQ( |
| CmpEvalResult::True, |
| compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kNE)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kNE)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kNE));
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kNE)); |
| |
| ASSERT_EQ( |
| CmpEvalResult::True, |
| compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLT)); |
| ASSERT_EQ( |
| CmpEvalResult::False, |
| compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLT)); |
| ASSERT_EQ( |
| CmpEvalResult::False, |
| compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLT)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLT)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLT)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLT)); |
| |
| ASSERT_EQ( |
| CmpEvalResult::False, |
| compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGE)); |
| ASSERT_EQ( |
| CmpEvalResult::True, |
| compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGE)); |
| ASSERT_EQ( |
| CmpEvalResult::True, |
| compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGE)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGE)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGE)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGE)); |
| |
| ASSERT_EQ( |
| CmpEvalResult::False, |
| compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGT)); |
| ASSERT_EQ( |
| CmpEvalResult::False, |
| compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGT)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGT)); |
| ASSERT_EQ( |
| CmpEvalResult::True, |
| compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGT)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGT)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGT)); |
| |
| ASSERT_EQ( |
| CmpEvalResult::True, |
| compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLE)); |
| ASSERT_EQ( |
| CmpEvalResult::True, |
| compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLE)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLE)); |
| ASSERT_EQ( |
| CmpEvalResult::False, |
| compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLE)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLE)); |
| ASSERT_EQ( |
| CmpEvalResult::NotDetermined, |
| compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLE)); |
| } |
| |
| TEST(MemDependency, BoundOverlapSymbolic) { |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| VarHandle z("z", kInt); |
| VarHandle w("w", kInt); |
| |
| using namespace analysis; |
| |
| auto CB = [](ExprHandle s, ExprHandle e) { |
| return Bound(s.node(), e.node()); |
| }; |
| |
// Sanity check cases where the start and end are symbolic but the difference
// is constant.
| // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
| ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(x, x), CB(x, x))); |
| ASSERT_EQ( |
| OverlapKind::PartialOverlap, |
| boundOverlap(CB(x, x + 3), CB(x + 2, x + 5))); |
| ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(x, x), CB(x + 1, x + 1))); |
| |
// We can't infer the sign of y, so we cannot tell whether x + y is larger or
// smaller than x + y / 2.
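// For example, if y == 2 then x + y > x + y / 2, but if y == -2 then
// x + y < x + y / 2.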
| ASSERT_EQ( |
| OverlapKind::PartialOverlap, |
| boundOverlap(CB(x, x + y), CB(x, x + y / 2))); |
| |
// With no information about these bounds, we have to take the most
// conservative option: there may be an overlap.
| ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(x, y), CB(z, w))); |
| |
| // Math on opaque terms works. |
| ASSERT_EQ( |
| OverlapKind::ContainedOrEqual, |
| boundOverlap(CB(x + w, y - z), CB(x + w, y - z))); |
| // Even requiring simplification. |
| ASSERT_EQ( |
| OverlapKind::ContainedOrEqual, |
| boundOverlap(CB(x - w - w, y), CB(x - w * 2, y))); |
| } |
| |
// Tests the helper function for overlap of multi-dimensional index bounds.
// This uses boundOverlap on each dimension and returns the "lowest" kind of
// overlap.
| TEST(MemDependency, BoundOverlapMultiDim) { |
| using namespace analysis; |
| |
| auto CB = [](int s, int e) { |
| return Bound(alloc<IntImm>(s), alloc<IntImm>(e)); |
| }; |
| |
| // Sanity check one dimensional cases. |
| ASSERT_EQ(OverlapKind::ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)})); |
| ASSERT_EQ(OverlapKind::NoOverlap, overlaps({CB(0, 2)}, {CB(5, 10)})); |
| ASSERT_EQ( |
| OverlapKind::PartialOverlap, overlaps({CB(0, 100)}, {CB(100, 120)})); |
| |
| // Total overlap in 3 dims. |
| ASSERT_EQ( |
| OverlapKind::ContainedOrEqual, |
| overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 4)})); |
| ASSERT_EQ( |
| OverlapKind::ContainedOrEqual, |
| overlaps( |
| {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 10)})); |
| |
| // Total overlap in 2 dims, no overlap in another. |
| ASSERT_EQ( |
| OverlapKind::NoOverlap, |
| overlaps( |
| {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); |
| |
| // Total overlap in 2 dims, partial overlap in another. |
| ASSERT_EQ( |
| OverlapKind::PartialOverlap, |
| overlaps( |
| {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); |
| // This case is most important, so verify the overlap in any dim. (dim 2) |
| ASSERT_EQ( |
| OverlapKind::PartialOverlap, |
| overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(2, 6), CB(0, 5)})); |
| // Dim 1. |
| ASSERT_EQ( |
| OverlapKind::PartialOverlap, |
| overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(1, 3), CB(0, 5), CB(0, 5)})); |
| // Total overlap in 1 dim, partial in 2. |
| ASSERT_EQ( |
| OverlapKind::PartialOverlap, |
| overlaps( |
| {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(0, 5), CB(5, 10)})); |
| // Total overlap, partial overlap, no overlap. |
| ASSERT_EQ( |
| OverlapKind::NoOverlap, |
| overlaps( |
| {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(11, 15), CB(0, 5)})); |
| |
| // Total overlap (B) in 2 dims, total overlap (A) in another. |
| ASSERT_EQ( |
| OverlapKind::Contains, |
| overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 4)})); |
| |
| // Total overlap (A) in 2 dims, total overlap (B) in another. |
| ASSERT_EQ( |
| OverlapKind::Contains, |
| overlaps( |
| {CB(0, 12), CB(0, 15), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 14)})); |
| |
| // Total (B), No Overlap, Total (A). |
| ASSERT_EQ( |
| OverlapKind::NoOverlap, |
| overlaps( |
| {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 6), CB(11, 15), CB(1, 2)})); |
| } |
| |
// Tests the helper we use to subtract bounds: returns the region(s) of A which
// remain after removing the region of B.
| TEST(MemDependency, BoundSubtract) { |
| using namespace analysis; |
| |
| auto CB = [](int s, int e) { |
| return Bound(alloc<IntImm>(s), alloc<IntImm>(e)); |
| }; |
| auto EQ = [](const IndexBounds& x, const IndexBounds& y) { |
| return indexBoundsEquals(x, y); |
| }; |
| |
| // One element subtract. |
| ASSERT_EQ(subtractBound(CB(0, 0), CB(0, 0)).size(), 0); |
| ASSERT_EQ(subtractBound(CB(5, 5), CB(5, 5)).size(), 0); |
| |
| // No Overlap. |
| ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(2, 2)), {CB(5, 5)})); |
| ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(0, 4)), {CB(5, 5)})); |
| |
| // one side overlap. |
| ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(4, 7)), {CB(1, 3)})); |
| ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(5, 7)), {CB(0, 4)})); |
| ASSERT_TRUE(EQ(subtractBound(CB(4, 5), CB(1, 4)), {CB(5, 5)})); |
| ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 4)), {CB(5, 5)})); |
| |
| // both sides overlap. |
| ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 7)), {})); |
| ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(5, 7)), {})); |
| |
| // internal overlap. |
| ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(2, 3)), {CB(1, 1), CB(4, 5)})); |
| ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(2, 4)), {CB(0, 1), CB(5, 5)})); |
| } |
| |
| TEST(MemDependency, BoundSubtractSymbolic) { |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| VarHandle z("z", kInt); |
| VarHandle w("w", kInt); |
| |
| using namespace analysis; |
| |
| auto CB = [](ExprHandle s, ExprHandle e) { |
| return Bound(s.node(), e.node()); |
| }; |
| auto EQ = [](const IndexBounds& x, const IndexBounds& y) { |
| return indexBoundsEquals(x, y); |
| }; |
| |
| // One element subtract. |
| // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
| ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(x, x)), {})); |
| ASSERT_TRUE(EQ(subtractBound(CB(x + 1, x + 1), CB(x + 1, x + 1)), {})); |
| ASSERT_TRUE(EQ(subtractBound(CB(x * 2, x * 2), CB(x * 2, x * 2)), {})); |
| |
| // Subtract constant range low. |
| ASSERT_TRUE( |
| EQ(subtractBound(CB(x, x + 10), CB(x, x + 4)), {CB(x + 5, x + 10)})); |
| // Subtract constant range high. |
| ASSERT_TRUE( |
| EQ(subtractBound(CB(x, x + 10), CB(x + 6, x + 12)), {CB(x, x + 5)})); |
| // Subtract constant range total overlap. |
| ASSERT_TRUE(EQ(subtractBound(CB(x, x + 10), CB(x, x + 10)), {})); |
| ASSERT_TRUE(EQ(subtractBound(CB(x + 2, x + 10), CB(x, x + 12)), {})); |
| // Subtract constant range internal. |
| ASSERT_TRUE( |
| EQ(subtractBound(CB(x, x + 10), CB(x + 3, x + 7)), |
| {CB(x, x + 2), CB(x + 8, x + 10)})); |
| |
| // Size is inferable but not constant, only works with a single var. |
| ASSERT_TRUE(EQ(subtractBound(CB(0, x), CB(0, x * 2)), {})); |
| ASSERT_TRUE(EQ(subtractBound(CB(0, x * 2), CB(0, x - 1)), {CB(x, x * 2)})); |
| |
| // Size is not inferable. |
| ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(z, w)), {CB(x, y)})); |
| ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(x, z)), {CB(x, y)})); |
| ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(0, x)), {CB(x, y)})); |
| ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(0, 0)), {CB(x, x)})); |
| } |
| |
// Tests the helper function that does subtraction, but for multi-dimensional
// index bounds.
| TEST(MemDependency, BoundSubtractMultiDim) { |
| using namespace analysis; |
| |
| auto CB = [](int s, int e) { |
| return Bound(alloc<IntImm>(s), alloc<IntImm>(e)); |
| }; |
| auto EQ = [](std::vector<IndexBounds> x, std::vector<IndexBounds> y) { |
| if (x.size() != y.size()) { |
| return false; |
| } |
| for (auto i = 0U; i < x.size(); ++i) { |
| if (!indexBoundsEquals(x[i], y[i])) { |
| return false; |
| } |
| } |
| return true; |
| }; |
| |
| // sanity check one dimension. |
| ASSERT_TRUE(EQ(subtractIndicesBounds({CB(0, 9)}, {CB(0, 9)}), {})); |
| ASSERT_TRUE(EQ(subtractIndicesBounds({CB(3, 9)}, {CB(0, 12)}), {})); |
| ASSERT_TRUE( |
| EQ(subtractIndicesBounds({CB(0, 12)}, {CB(0, 9)}), {{CB(10, 12)}})); |
| ASSERT_TRUE( |
| EQ(subtractIndicesBounds({CB(0, 12)}, {CB(3, 12)}), {{CB(0, 2)}})); |
| ASSERT_TRUE(EQ( |
| subtractIndicesBounds({CB(0, 9)}, {CB(1, 8)}), {{CB(0, 0)}, {CB(9, 9)}})); |
| |
| // Multi dim total overlap. |
| ASSERT_TRUE(EQ( |
| subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 9), CB(0, 2)}), {})); |
| ASSERT_TRUE(EQ( |
| subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {})); |
| |
// Multi dim one way partial in dim 1.
| ASSERT_TRUE( |
| EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}), |
| {{CB(4, 9), CB(0, 2)}})); |
| |
// Multi dim one way partial in dim 2.
| ASSERT_TRUE( |
| EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}), |
| {{CB(0, 9), CB(11, 20)}})); |
| |
| // Partial overlap in 2 dims. |
| ASSERT_TRUE( |
| EQ(subtractIndicesBounds({CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8)}), |
| {{CB(0, 1), CB(0, 5)}, {CB(2, 5), CB(0, 1)}})); |
| |
| // Partial overlap in 3 dims. |
| ASSERT_TRUE( |
| EQ(subtractIndicesBounds( |
| {CB(0, 5), CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8), CB(2, 8)}), |
| {{CB(0, 1), CB(0, 5), CB(0, 5)}, |
| {CB(2, 5), CB(0, 1), CB(0, 5)}, |
| {CB(2, 5), CB(2, 5), CB(0, 1)}})); |
| } |
| |
// Tests the multi-dimensional subtraction code for bounds that cannot be
// fully materialized.
| TEST(MemDependency, BoundSubtractMultiDimSymbolic) { |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| |
| using namespace analysis; |
| |
| auto CB = [](ExprHandle s, ExprHandle e) { |
| return Bound(s.node(), e.node()); |
| }; |
| |
| auto EQ = [](std::vector<IndexBounds> x, std::vector<IndexBounds> y) { |
| if (x.size() != y.size()) { |
| return false; |
| } |
| for (auto i = 0U; i < x.size(); ++i) { |
| if (!indexBoundsEquals(x[i], y[i])) { |
| return false; |
| } |
| } |
| return true; |
| }; |
| |
| // Cannot determine overlaps. |
| // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
| ASSERT_TRUE(EQ(subtractIndicesBounds({CB(x, x)}, {CB(0, 0)}), {{CB(x, x)}})); |
| |
| // Various total Overlaps. |
| ASSERT_TRUE(EQ( |
| subtractIndicesBounds({CB(x, x), CB(x, x)}, {CB(x, x), CB(x, x)}), {})); |
| ASSERT_TRUE(EQ( |
| subtractIndicesBounds({CB(x, y), CB(x, y)}, {CB(x, y), CB(x, y)}), {})); |
| ASSERT_TRUE(EQ( |
| subtractIndicesBounds({CB(x, x), CB(y, y)}, {CB(x, x), CB(y, y)}), {})); |
| ASSERT_TRUE(EQ( |
| subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(0, y)}), {})); |
| |
| // one-way overlap in first dim. |
| ASSERT_TRUE( |
| EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x - 5), CB(0, y)}), |
| {{CB(x - 4, x), CB(0, y)}})); |
| // second dim. |
| ASSERT_TRUE( |
| EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(5, y)}), |
| {{CB(0, x), CB(0, 4)}})); |
| |
| // Internal overlap in first dim. |
| ASSERT_TRUE( |
| EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(2, x - 5), CB(0, y)}), |
| {{CB(0, 1), CB(0, y)}, {CB(x - 4, x), CB(0, y)}})); |
| // second dim. |
| ASSERT_TRUE(EQ( |
| subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(10, y - 10)}), |
| {{CB(0, x), CB(0, 9)}, {CB(0, x), CB(y - 9, y)}})); |
| |
| // Overlap in both dimensions. |
| ASSERT_TRUE( |
| EQ(subtractIndicesBounds( |
| {CB(0, x), CB(0, y)}, {CB(5, x - 5), CB(10, y - 10)}), |
| { |
| {CB(0, 4), CB(0, y)}, |
| {CB(x - 4, x), CB(0, y)}, |
| {CB(0, x), CB(0, 9)}, |
| {CB(0, x), CB(y - 9, y)}, |
| })); |
| } |
| |
| // Simple check that the analyzer does anything at all... |
| TEST(MemDependency, MemDependencyCheckerSimple) { |
| BufHandle a("A", {1}, kInt); |
| BufHandle b("B", {1}, kInt); |
| |
| analysis::MemDependencyChecker analyzer; |
| |
| /* |
| * A[0] = 3; |
| * B[0] = A[0] + 1; |
| */ |
| |
| StorePtr aStore = Store::make(a, {0}, 3); |
| StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); |
| |
| StmtPtr stmt = Block::make({aStore, bStore}); |
| |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); |
| ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); |
// Sanity check: anything that depends directly must also depend indirectly.
| ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aStore)); |
| } |
| |
| // Check that there is a difference between direct and indirect dependence. |
| TEST(MemDependency, MemDependencyCheckerMultiStmt) { |
| BufHandle a("A", {1}, kInt); |
| BufHandle b("B", {1}, kInt); |
| BufHandle c("C", {1}, kInt); |
| |
| analysis::MemDependencyChecker analyzer; |
| |
| /* |
| * A[0] = 3; |
| * B[0] = A[0]; |
| * C[0] = B[0] + 1; |
| */ |
| |
| StorePtr aStore = Store::make(a, {0}, 3); |
| StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); |
| StorePtr cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)); |
| |
| StmtPtr stmt = Block::make({aStore, bStore, cStore}); |
| |
| stmt->accept(&analyzer); |
| |
| // C depends on A indirectly. |
| ASSERT_FALSE(analyzer.dependsDirectly(cStore, aStore)); |
| ASSERT_TRUE(analyzer.dependsIndirectly(cStore, aStore)); |
| |
| // C depends on B directly, which depends on A directly. |
| ASSERT_TRUE(analyzer.dependsDirectly(cStore, bStore)); |
| ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); |
| |
| // Dependency goes top to bottom only. |
| ASSERT_FALSE(analyzer.dependsIndirectly(bStore, cStore)); |
| ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); |
| ASSERT_FALSE(analyzer.dependsIndirectly(aStore, cStore)); |
| } |
| |
| // Verify that we do filter writes that are totally overlapped by later writes. |
| TEST(MemDependency, MemDependencyCheckerOverlap) { |
| BufHandle a("A", {1}, kInt); |
| BufHandle b("B", {1}, kInt); |
| |
| analysis::MemDependencyChecker analyzer; |
| |
| /* |
| * A[0] = 3; |
| * A[0] = 6; |
| * B[0] = A[0] + 1; |
| */ |
| |
| StorePtr aStore = Store::make(a, {0}, 3); |
| StorePtr a2Store = Store::make(a, {0}, 6); |
| StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); |
| |
| StmtPtr stmt = Block::make({aStore, a2Store, bStore}); |
| |
| stmt->accept(&analyzer); |
| |
| // B store depends on second A store but not first since it is completely |
| // overlapped. |
| ASSERT_TRUE(analyzer.dependsIndirectly(bStore, a2Store)); |
| ASSERT_FALSE(analyzer.dependsIndirectly(bStore, aStore)); |
| |
| // No dependency between either A store. |
| ASSERT_FALSE(analyzer.dependsIndirectly(aStore, a2Store)); |
| ASSERT_FALSE(analyzer.dependsIndirectly(a2Store, aStore)); |
| } |
| |
| // Verify that bounds match loop iterations, and that dependencies progress |
| // across loop scopes. |
| TEST(MemDependency, MemDependencyCheckerLoop) { |
| BufHandle a("A", {1}, kInt); |
| BufHandle b("B", {1}, kInt); |
| VarHandle x("x", kInt); |
| |
| using namespace analysis; |
| |
| MemDependencyChecker analyzer; |
| |
| /* |
| * for (int x = 0; x < 10; ++x) { |
| * A[x] = x; |
| * } |
| * B[0] = A[0] + 1; |
| */ |
| |
| StorePtr aStore = Store::make(a, {x}, x); |
| StmtPtr loop = For::make(x, 0, 10, aStore); |
| StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1)); |
| |
| StmtPtr stmt = Block::make({loop, bStore}); |
| |
| stmt->accept(&analyzer); |
| |
| // Same A->B dependency. |
| ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); |
| |
| // B depends on the loop. |
| ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); |
| // A is in the loop but does not depend on any loop iteration. |
| ASSERT_FALSE(analyzer.dependsIndirectly(aStore, loop)); |
| |
| auto aStoreAccess = analyzer.accessFor(aStore); |
| ASSERT_NE(aStoreAccess, nullptr); |
| |
| // It should have bounds covering the range of x: 0 <= x < 10. |
| ASSERT_TRUE(indexBoundsEquals( |
| aStoreAccess->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))})); |
| } |
| |
| // Reductions should promote dependencies as well. |
| TEST(MemDependency, MemDependencyCheckerLoopReduce) { |
| BufHandle a("A", {10}, kInt); |
| BufHandle b("B", {10}, kInt); |
| VarHandle x("x", kInt); |
| |
| using namespace analysis; |
| |
| MemDependencyChecker analyzer; |
| |
| /* |
| * A[0] = 0; |
| * for (int x = 0; x < 10; ++x) { |
| * A[0] = A[x] + 1; |
| * } |
| * B[0] = A[0]; |
| */ |
| |
| StorePtr aInit = Store::make(a, {0}, 0); |
| ExprHandle reduce = Sum()(a, 1, {x}, {x}); |
| StorePtr aReduce = Store::make(a, {0}, reduce); |
| StmtPtr loop = For::make(x, 0, 10, aReduce); |
| StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); |
| |
| StmtPtr stmt = Block::make({aInit, loop, bStore}); |
| |
| stmt->accept(&analyzer); |
| |
| // B -> A. |
| ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); |
| |
| // B depends indirectly on the initializer of A, since the reduction depends |
| // on it. |
| ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); |
| ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); |
| |
| ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); |
| |
| // B depends on the loop. |
| ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); |
| // A is in the loop and depends on other iterations. |
| ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); |
| |
| // The loop contents depend on the initializer too. |
| ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); |
| |
| // Find loads within the reduction: |
| auto reduceLoads = NodeFinder<Load>::find(reduce.node()); |
| // Pull out the access for the load inside the loop. |
| for (auto load : reduceLoads) { |
| auto loopLoad = analyzer.accessFor(load); |
| // It should have 10 element long bounds. |
| ASSERT_TRUE(indexBoundsEquals( |
| loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))})); |
| } |
| } |
| |
| // Lowering a reduction doesn't affect dependency analysis. |
| TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) { |
| BufHandle a("A", {10}, kInt); |
| BufHandle b("B", {10}, kInt); |
| VarHandle x("x", kInt); |
| |
| using namespace analysis; |
| |
| MemDependencyChecker analyzer; |
| |
| /* |
| * A[0] = 0; |
| * for (int x = 0; x < 10; ++x) { |
| * A[0] = A[x] + 1; |
| * } |
| * B[0] = A[0]; |
| */ |
| |
| StorePtr aInit = Store::make(a, {0}, 0); |
| ExprHandle aLoad = Load::make(a, {x}); |
| StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1)); |
| StmtPtr loop = For::make(x, 0, 10, aReduce); |
| StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); |
| |
| StmtPtr stmt = Block::make({aInit, loop, bStore}); |
| |
| stmt->accept(&analyzer); |
| |
| // B -> A. |
| ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); |
| |
| // B depends indirectly on the initializer of A, since the reduction depends |
| // on it. |
| ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); |
| ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); |
| |
| ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); |
| |
| // B depends on the loop. |
| ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); |
| // A is in the loop and depends on other iterations. |
| ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); |
| |
| // The loop contents depend on the initializer too. |
| ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); |
| |
// Pull out the access for the load inside the loop.
| auto loopLoad = analyzer.accessFor(aLoad.node()); |
| // It should have 10 element long bounds. |
| ASSERT_TRUE(indexBoundsEquals( |
| loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))})); |
| } |
| |
| // Can determine dependencies of outputs, through to inputs. |
| TEST(MemDependency, MemDependencyCheckerInputsOutputs) { |
| BufHandle a("A", {10}, kInt); |
| BufHandle b("B", {10}, kInt); |
| VarHandle x("x", kInt); |
| |
| // initialize analyzer with inputs and outputs. |
| analysis::MemDependencyChecker analyzer({a}, {b}); |
| |
| // Here's a Relu. |
| /* |
| * for (int x = 0; x < 10; ++x) { |
| * B[x] = Max(A[x], 0); |
| * } |
| */ |
| |
| ExprHandle aLoad = Load::make(a, {x}); |
| StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true)); |
| StmtPtr loop = For::make(x, 0, 10, bStore); |
| |
| StmtPtr stmt = Block::make({loop}); |
| |
| stmt->accept(&analyzer); |
| |
| // Output depends indirectly on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); |
| // aLoad depends directly on the input A. |
| ASSERT_TRUE(analyzer.dependsDirectly(aLoad.node(), a.node())); |
| // bStore therefore depends directly on the input A. |
| ASSERT_TRUE(analyzer.dependsDirectly(bStore, a.node())); |
| // The output depends directly on the store. |
| ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); |
| |
| // Check AccessInfo based overloads. |
| auto input = analyzer.input(a.node()); |
| auto output = analyzer.output(b.node()); |
| |
| // Output depends indirectly on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); |
| // Not directly. |
| ASSERT_FALSE(analyzer.dependsDirectly(output, input)); |
| // Not in reverse order. |
| ASSERT_FALSE(analyzer.dependsIndirectly(input, output)); |
| |
// output -> bStore -> aLoad -> input.
| auto storeAccess = analyzer.accessFor(bStore); |
| auto loadAccess = analyzer.accessFor(aLoad.node()); |
| |
| ASSERT_TRUE(analyzer.dependsDirectly(output, storeAccess)); |
| ASSERT_TRUE(analyzer.dependsDirectly(loadAccess, input)); |
| } |
| |
| // Can tell if an output does not depend on an input. |
| TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) { |
| BufHandle a("A", {10}, kInt); |
| BufHandle b("B", {10}, kInt); |
| VarHandle x("x", kInt); |
| |
| // initialize analyzer with inputs and outputs. |
| analysis::MemDependencyChecker analyzer({a}, {b}); |
| |
| // Here's a dumb Relu. |
| /* |
| * for (int x = 0; x < 10; ++x) { |
| * B[x] = Max(x, 0); |
| * } |
| */ |
| |
| StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true)); |
| StmtPtr loop = For::make(x, 0, 10, bStore); |
| |
| StmtPtr stmt = Block::make({loop}); |
| |
| stmt->accept(&analyzer); |
| |
| // Output does not depend indirectly on input. |
| ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), a.node())); |
| |
| // The output still depends directly on the store. |
| ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); |
| |
| // Check AccessInfo based overloads. |
| auto input = analyzer.input(a.node()); |
| auto output = analyzer.output(b.node()); |
| |
| // Output does not depend indirectly on input. |
| ASSERT_FALSE(analyzer.dependsIndirectly(output, input)); |
| } |
| |
| // Verify different loop extents produce accesses with different bounds, and |
| // that later accesses find dependencies that overlap their entire bound range. |
| TEST(MemDependency, MemDependencyCheckerLoopBounds) { |
| BufHandle a("A", {10}, kInt); |
| BufHandle b("B", {10}, kInt); |
| BufHandle c("C", {10}, kInt); |
| VarHandle x("x", kInt); |
| using namespace analysis; |
| |
| MemDependencyChecker analyzer({a}, {c}); |
| |
| // This enables using the execution order of the loops to determine if some |
| // loops are self dependent or not. |
| analyzer.allowLoopExecutionOrderAnalysis(); |
| |
| /* |
| * for (int x = 1; x < 10; ++x) { |
| * B[x] = A[x]; |
| * } |
| * for (int x = 1; x < 9; ++x) { |
| * B[x] = B[x] * 2; |
| * } |
| * for (int x = 3; x < 4; ++x) { |
| * C[x] = A[x]; |
| * } |
| * for (int x = 0; x < 10; ++x) { |
| * C[x] = B[x]; |
| * } |
| */ |
| |
| std::vector<StmtPtr> stmts( |
| {For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))), |
| For::make( |
| x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))), |
| For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))), |
| For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))}); |
| |
| StmtPtr stmt = Block::make(stmts); |
| |
| stmt->accept(&analyzer); |
| |
| auto input = analyzer.input(a.node()); |
| auto output = analyzer.output(c.node()); |
| |
| // sanity check Output -> Input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); |
| |
| // Check the For loop dependencies: |
| |
| // Last write to C depends on both writes to B since they contain the last |
| // write to at least one element. |
| ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[1])); |
| ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[0])); |
| |
| // The last write to C does not depend on the other write to C. |
| ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2])); |
| |
| auto CB = [](int s, int e) { |
| return Bound(alloc<IntImm>(s), alloc<IntImm>(e)); |
| }; |
| auto EQ = [](const IndexBounds& x, const IndexBounds& y) { |
| return indexBoundsEquals(x, y); |
| }; |
| |
| /* 0. Input: A[(0, 9)] - dependents: 1 5 |
| * 1. Load: A[(1, 9)] - depends on: 0 - dependents: 2 |
| * 2. Store: B[(1, 9)] - depends on: 1 - dependents: 3 7 |
| * 3. Load: B[(1, 8)] - depends on: 2 - dependents: 4 |
| * 4. Store: B[(1, 8)] - depends on: 3 - dependents: 7 |
| * 5. Load: A[(3, 3)] - depends on: 0 - dependents: 6 |
| * 6. Store: C[(3, 3)] - depends on: 5 |
| * 7. Load: B[(0, 9)] - depends on: 2 4 - dependents: 8 |
| * 8. Store: C[(0, 9)] - depends on: 7 - dependents: 9 |
| * 9. Output: C[(0, 9)] - depends on: 8 |
| */ |
| |
// Now let's look at the bounds of each access.
// There are 10 accesses in this Stmt, so this is exhaustive; we won't usually
// go into this much detail.
| auto history = analyzer.getHistory(); |
| ASSERT_EQ(history.size(), 10); |
| VarPtr aVar = a.node()->base_handle(); |
| VarPtr bVar = b.node()->base_handle(); |
| VarPtr cVar = c.node()->base_handle(); |
| |
| // The first access is the input A. |
| ASSERT_EQ(history[0]->type(), AccessType::Input); |
| ASSERT_EQ(history[0]->var(), aVar); |
| // It has the bounds of the producing Input. |
| ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); |
| // sanity check the input we retrieved earlier matches. |
| ASSERT_EQ(history[0], input); |
| |
| // The second access is the load of A in the first loop. |
| ASSERT_EQ(history[1]->type(), AccessType::Load); |
| ASSERT_EQ(history[1]->var(), aVar); |
| // It has the bounds of the loop, i.e. start == 1. |
| ASSERT_TRUE(EQ(history[1]->bounds(), {CB(1, 9)})); |
// It reads from A, so it should have a dependency on the last write to this
// range - which is the input.
| ASSERT_EQ(history[1]->dependencies().size(), 1); |
| ASSERT_TRUE(history[1]->hasDependency(history[0])); |
| |
| // The third access is the store into B in the first loop. |
| ASSERT_EQ(history[2]->type(), AccessType::Store); |
| ASSERT_EQ(history[2]->var(), bVar); |
| // It also has the bounds of the loop, i.e. start == 1. |
| ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); |
| // The previous load is in its RHS, so it depends on it. |
| ASSERT_EQ(history[2]->dependencies().size(), 1); |
| ASSERT_TRUE(history[2]->hasDependency(history[1])); |
| |
// The fourth access is the load from B in the second loop.
| ASSERT_EQ(history[3]->type(), AccessType::Load); |
| ASSERT_EQ(history[3]->var(), bVar); |
// It has the bounds of the second loop, i.e. 1 <= x < 9.
| ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 8)})); |
| // It reads from B in a smaller range, so should depend on the previous |
| // store. |
| ASSERT_EQ(history[3]->dependencies().size(), 1); |
| ASSERT_TRUE(history[3]->hasDependency(history[2])); |
| |
// The fifth access: the store to B in the second loop.
| ASSERT_EQ(history[4]->type(), AccessType::Store); |
| ASSERT_EQ(history[4]->var(), bVar); |
| // It also has the bounds of the second loop. |
| ASSERT_TRUE(EQ(history[4]->bounds(), {CB(1, 8)})); |
| // The previous load is in its RHS, so it depends on it as before. |
| ASSERT_EQ(history[4]->dependencies().size(), 1); |
| ASSERT_TRUE(history[4]->hasDependency(history[3])); |
| |
// The sixth access is the load of A in the third loop, which skips the
// previous B accesses.
| ASSERT_EQ(history[5]->type(), AccessType::Load); |
| ASSERT_EQ(history[5]->var(), aVar); |
// It has the bounds of the third loop: 3 <= x < 4.
| ASSERT_TRUE(EQ(history[5]->bounds(), {CB(3, 3)})); |
| // It depends on the last thing to write to A, which is the A input. |
| ASSERT_EQ(history[5]->dependencies().size(), 1); |
| ASSERT_TRUE(history[5]->hasDependency(history[0])); |
| |
// The seventh access: the store into the output C.
| ASSERT_EQ(history[6]->type(), AccessType::Store); |
| ASSERT_EQ(history[6]->var(), cVar); |
| // It also has the bounds of the third loop. |
| ASSERT_TRUE(EQ(history[6]->bounds(), {CB(3, 3)})); |
| // The previous load is in its RHS, so it depends on it as always. |
| ASSERT_EQ(history[6]->dependencies().size(), 1); |
| ASSERT_TRUE(history[6]->hasDependency(history[5])); |
| |
// The eighth access is the load of B in the fourth loop.
| ASSERT_EQ(history[7]->type(), AccessType::Load); |
| ASSERT_EQ(history[7]->var(), bVar); |
// It has the bounds of the final loop, 0 <= x < 10.
| ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); |
| // The bounds of this read are larger than the bounds of the previous write, |
| // so it depends on both previous Stores to B. |
| ASSERT_EQ(history[7]->dependencies().size(), 2); |
| ASSERT_TRUE(history[7]->hasDependency(history[2])); |
| ASSERT_TRUE(history[7]->hasDependency(history[4])); |
| |
// The ninth access: the final store into the output C.
| ASSERT_EQ(history[8]->type(), AccessType::Store); |
| ASSERT_EQ(history[8]->var(), cVar); |
| // It also has the bounds of the final loop. |
| ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); |
| // The previous load is in its RHS, so it depends on it as always. |
| ASSERT_EQ(history[8]->dependencies().size(), 1); |
| ASSERT_TRUE(history[8]->hasDependency(history[7])); |
| |
| // The last access represents the output Buf. |
| ASSERT_EQ(history[9]->type(), AccessType::Output); |
| ASSERT_EQ(history[9]->var(), cVar); |
| // It has the bounds of the output Buf. |
| ASSERT_TRUE(EQ(history[9]->bounds(), {CB(0, 9)})); |
// Sanity check the output we retrieved earlier matches.
| ASSERT_EQ(history[9], output); |
| // It depends on the last write to C only. |
| ASSERT_EQ(history[9]->dependencies().size(), 1); |
| ASSERT_TRUE(history[9]->hasDependency(history[8])); |
| } |
| |
| // Verify that we can still infer bounds when the loop var is offset. |
| TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { |
| BufHandle a("A", {10}, kInt); |
| BufHandle b("B", {10}, kInt); |
| VarHandle x("x", kInt); |
| |
| using namespace analysis; |
| |
| MemDependencyChecker analyzer({a}, {b}); |
| |
| // This enables using the execution order of the loops to determine if some |
| // loops are self dependent or not. |
| analyzer.allowLoopExecutionOrderAnalysis(); |
| |
| /* |
| * for (int x = 1; x < 10; x++) { |
| * A[x] = A[x - 1]; |
| * } |
| * for (int x = 0; x < 9; x++) { |
| * A[x] = A[x + 1]; |
| * } |
| * for (int x = 0; x < 9; x++) { |
| * A[9 - x] = A[8 - x]; |
| * } |
| * for (int x = 0; x < 10; x++) { |
| * A[x] = A[9 - x]; |
| * } |
| * for (int x = 0; x < 10; x++) { |
| * B[x] = A[x]; |
| * } |
| */ |
| |
| StmtPtr stmt = Block::make( |
| {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), |
| For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))), |
| For::make( |
| x, |
| 0, |
| 9, |
| Store::make( |
| a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))), |
| For::make( |
| x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))), |
| For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})))}); |
| |
| stmt->accept(&analyzer); |
| |
| // Sanity check output depends on Input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); |
| |
| auto CB = [](int s, int e) { |
| return Bound(alloc<IntImm>(s), alloc<IntImm>(e)); |
| }; |
| auto EQ = [](const IndexBounds& x, const IndexBounds& y) { |
| return indexBoundsEquals(x, y); |
| }; |
| |
| /* 0. Input: A[(0, 9)] - dependents: 1 |
| * 1. Load: A[(0, 8)] - depends on: 0 2 - dependents: 2 |
| * 2. Store: A[(1, 9)] - depends on: 1 - dependents: 1 3 |
| * 3. Load: A[(1, 9)] - depends on: 2 - dependents: 4 |
| * 4. Store: A[(0, 8)] - depends on: 3 - dependents: 5 7 |
| * 5. Load: A[(0, 8)] - depends on: 4 - dependents: 6 |
| * 6. Store: A[(1, 9)] - depends on: 5 - dependents: 7 |
| * 7. Load: A[(0, 9)] - depends on: 4 6 8 - dependents: 8 |
| * 8. Store: A[(0, 9)] - depends on: 7 - dependents: 7 9 |
| * 9. Load: A[(0, 9)] - depends on: 8 - dependents: 10 |
| * 10. Store: B[(0, 9)] - depends on: 9 - dependents: 11 |
| * 11. Output: B[(0, 9)] - depends on: 10 |
| */ |
| |
| // Now let's look at the bounds of each access. |
| auto history = analyzer.getHistory(); |
| ASSERT_EQ(history.size(), 12); |
| VarPtr aVar = a.node()->base_handle(); |
| VarPtr bVar = b.node()->base_handle(); |
| |
| // The first access is the input A. |
| ASSERT_EQ(history[0]->type(), AccessType::Input); |
| ASSERT_EQ(history[0]->var(), aVar); |
| // It has the bounds of the producing Input. |
| ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); |
| |
| // The second access is the load A[x-1]. |
| ASSERT_EQ(history[1]->type(), AccessType::Load); |
| ASSERT_EQ(history[1]->var(), aVar); |
| // It has the bounds of the loop modified by the offset of each index, in |
| // this case -1. |
| ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)})); |
// It depends on the input, but also the store in the same loop, since
// different iterations of the loop depend on each other.
| ASSERT_EQ(history[1]->dependencies().size(), 2); |
| ASSERT_TRUE(history[1]->hasDependency(history[0])); |
| ASSERT_TRUE(history[1]->hasDependency(history[2])); |
| |
| // The third access is the Store to A[x] in the first loop. |
| ASSERT_EQ(history[2]->type(), AccessType::Store); |
| ASSERT_EQ(history[2]->var(), aVar); |
| // It has no offset on x, so should have the same bounds as the loop. |
| ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); |
| |
| // The fourth access is the load A[x+1] in the second loop. |
| ASSERT_EQ(history[3]->type(), AccessType::Load); |
| ASSERT_EQ(history[3]->var(), aVar); |
| // It has the bounds of the loop (0 <= x < 9) modified by the offset of each |
| // index, in this case 1. |
| ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 9)})); |
| // This load totally overlaps the previous write to A, so it depends only on |
| // it and not the input. |
| ASSERT_EQ(history[3]->dependencies().size(), 1); |
| ASSERT_TRUE(history[3]->hasDependency(history[2])); |
| |
| // The fifth access is the store to A[x] in the second loop. |
| ASSERT_EQ(history[4]->type(), AccessType::Store); |
| ASSERT_EQ(history[4]->var(), aVar); |
| // It has no offset on x, so should have the same bounds as the loop. |
| ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, 8)})); |
| |
| // The sixth access is the load to A[8 - x] in the third loop. |
| ASSERT_EQ(history[5]->type(), AccessType::Load); |
| ASSERT_EQ(history[5]->var(), aVar); |
| // It has the bounds of the loop (0 <= x < 9) modified by the offset of each |
| // index, in this case 8 - x. |
| // This access has a negative stride, which will be normalized. |
| ASSERT_TRUE(EQ(history[5]->bounds(), {CB(0, 8)})); |
| // This load totally overlaps the most recent write to A, so it depends only |
| // on it and not the input or the first write to A. |
| ASSERT_EQ(history[5]->dependencies().size(), 1); |
| ASSERT_TRUE(history[5]->hasDependency(history[4])); |
| |
| // The seventh access is the store to A[9 - x] in the third loop. |
| ASSERT_EQ(history[6]->type(), AccessType::Store); |
| ASSERT_EQ(history[6]->var(), aVar); |
// This store has a negative stride on its indices, but it is normalized
// internally.
| ASSERT_TRUE(EQ(history[6]->bounds(), {CB(1, 9)})); |
| |
// The eighth access is the load A[9 - x] in the fourth loop.
| ASSERT_EQ(history[7]->type(), AccessType::Load); |
| ASSERT_EQ(history[7]->var(), aVar); |
// It has the bounds of the loop (0 <= x < 10), modified by the offset 9 - x,
// which essentially traverses the loop backwards.
| ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); |
| // This Load has three write dependencies: |
| ASSERT_EQ(history[7]->dependencies().size(), 3); |
| // * The previous store (#6) for elements 1-9 |
| ASSERT_TRUE(history[7]->hasDependency(history[6])); |
| // * An earlier store (#4) covering element 0 |
| ASSERT_TRUE(history[7]->hasDependency(history[4])); |
// * A future store inside this loop, since this loop modifies the buffer
// in a non-distinct way (due to the load and store having different access
// strides).
| ASSERT_TRUE(history[7]->hasDependency(history[8])); |
| |
| // The ninth access is the store to A[x] in the fourth loop. |
| ASSERT_EQ(history[8]->type(), AccessType::Store); |
| ASSERT_EQ(history[8]->var(), aVar); |
// It has no offset on x, so it has the same bounds as the loop.
| ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); |
| |
// The tenth and eleventh accesses are the copy from A[x] to B[x].
| ASSERT_EQ(history[9]->type(), AccessType::Load); |
| ASSERT_EQ(history[9]->var(), aVar); |
| ASSERT_EQ(history[10]->type(), AccessType::Store); |
| ASSERT_EQ(history[10]->var(), bVar); |
| |
| // The last access represents the output Buf. |
| ASSERT_EQ(history[11]->type(), AccessType::Output); |
| ASSERT_EQ(history[11]->var(), bVar); |
| // It has the bounds of the output Buf. |
| ASSERT_TRUE(EQ(history[11]->bounds(), {CB(0, 9)})); |
| // It depends on the last write to B only. |
| ASSERT_EQ(history[11]->dependencies().size(), 1); |
| ASSERT_TRUE(history[11]->hasDependency(history[10])); |
| |
| // ok that's enough of that. |
| } |
| |
| // Check many different cases of loop self dependency - when a load within a |
| // loop is dependent on a Store later in the same loop but in different |
| // iteration. This is affected by whether or not we can trust the execution |
| // order of the loop. |
| TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { |
| BufHandle a("A", {5}, kInt); |
| BufHandle b("B", {5}, kInt); |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| VarHandle z("z", kInt); |
| |
| using namespace analysis; |
| |
| // This check assumes that the Stmt has a single Store with a single Load on |
| // the RHS. |
| auto isSelfDependent = |
| [](const std::vector<std::shared_ptr<AccessInfo>>& history) -> bool { |
| return history.front()->hasDependency(history.back()); |
| }; |
| |
| { |
| /* for (int y = 0; y < 10; y++) { |
| * A[y] = (A[y]) + 1; |
| * } */ |
| |
| // Not self dependent since all loop iterations use a different y. |
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| y, |
| 0, |
| 10, |
| Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), 1))})); |
| |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int y = 0; y < 10; y++) { |
| * A[y + 1] = (A[y + 1]) + 1; |
| * } |
| */ |
| |
| // Not self dependent due to different y (with offset). |
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| y, |
| 0, |
| 10, |
| Block::make( |
| {Store::make(a, {y + 1}, Add::make(Load::make(a, {y + 1}), 1))})); |
| |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[0] = (A[0]) + x; |
| * } |
| */ |
| |
| // Is self dependent since all loops use a common constant element of A. |
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, |
| 0, |
| 10, |
| Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[0] = (B[0]) + x; |
| * } |
| */ |
| |
| // Is not self dependent because there is no store to the buffer that is |
| // read. |
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, |
| 0, |
| 10, |
| Block::make({Store::make(a, {0}, Add::make(Load::make(b, {0}), x))})); |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[y] = (A[y]) + x; |
| * } |
| */ |
| |
| // Is self dependent since all loops use a common symbolic element of A. |
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, |
| 0, |
| 10, |
| Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x] = A[x + 1]; |
| * } |
| */ |
| |
| // In this case it depends if we are considering execution order. |
| |
| MemDependencyChecker analyzer; |
| |
| StmtPtr stmt = |
| For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); |
| stmt->accept(&analyzer); |
| |
// With analysis of order disabled, this is self dependent since the read of
// element x + 1 and the next iteration's write to element x + 1 could occur
// in either order.
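// For example, element 1 is read at x = 0 but written at x = 1; without
// trusting the loop's execution order the write could come first.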
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x] = A[x + 1]; |
| * } |
| */ |
| |
| MemDependencyChecker analyzer; |
| analyzer.allowLoopExecutionOrderAnalysis(); |
| |
| StmtPtr stmt = |
| For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); |
| stmt->accept(&analyzer); |
| |
| // If order analysis is enabled, this is not dependent since the read for |
| // each element occurs before the write to that element. |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 1; x < 10; x++) { |
| * A[x] = A[x - 1]; |
| * } |
| */ |
| |
| MemDependencyChecker analyzer; |
| |
| StmtPtr stmt = |
| For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 1; x < 10; x++) { |
| * A[x] = A[x - 1]; |
| * } |
| */ |
| |
| MemDependencyChecker analyzer; |
| analyzer.allowLoopExecutionOrderAnalysis(); |
| |
| StmtPtr stmt = |
| For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); |
| stmt->accept(&analyzer); |
| |
| // In this case, even with order analysis the Load is dependent on the |
| // Store, since the write to X occurs before the read from X. |
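// For example, element 1 is written at x = 1 and then read at x = 2.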
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
/* for (int x = 3; x < 10; x++) {
| * A[9 - x] = A[8 - x]; |
| * } |
| */ |
| |
| // Still works if the execution order is reversed, so long as the read |
| // comes before the write. |
| |
| MemDependencyChecker analyzer; |
| analyzer.allowLoopExecutionOrderAnalysis(); |
| |
| StmtPtr stmt = For::make( |
| x, |
| 3, |
| 10, |
| Store::make( |
| a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))); |
| stmt->accept(&analyzer); |
| |
// However, here we can determine that the load of each element occurs before
// the store to that element, so there is no dependency.
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
/* for (int x = 3; x < 10; x++) {
| * A[8 - x] = A[9 - x]; |
| * } |
| */ |
| |
| // But not if it doesn't. |
| |
| MemDependencyChecker analyzer; |
| analyzer.allowLoopExecutionOrderAnalysis(); |
| |
| StmtPtr stmt = For::make( |
| x, |
| 3, |
| 10, |
| Store::make( |
| a, {ExprHandle(8) - x}, Load::make(a, {ExprHandle(9) - x}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
/* for (int x = 3; x < 10; x++) {
| * A[9 - x] = A[8 - x]; |
| * } |
| */ |
| |
| // And not if we're not relying on execution order. |
| |
| MemDependencyChecker analyzer; |
| |
| StmtPtr stmt = For::make( |
| x, |
| 3, |
| 10, |
| Store::make( |
| a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 3; x < 10; x++) { |
| * A[x - 2] = A[x - 1]; |
| * } |
| */ |
| |
// Forward order but negative offsets.
| |
| MemDependencyChecker analyzer; |
| analyzer.allowLoopExecutionOrderAnalysis(); |
| |
| StmtPtr stmt = |
| For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1}))); |
| stmt->accept(&analyzer); |
| |
// However, here we can determine that the load of each element occurs before
// the store to that element, so there is no dependency.
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 2] = A[x * 2]; |
| * } |
| */ |
| |
| // With an access stride. |
| |
| MemDependencyChecker analyzer; |
| // Execution order doesn't matter since the read and the write are totally |
| // distinct. |
| |
| StmtPtr stmt = |
| For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 2] = A[x * 2 + 1]; |
| * } |
| */ |
| |
// Here we can use the common stride of the accesses to determine they are
// distinct.
// Note: loop self dependency is the only place we use the stride to avoid an
// unnecessary dependence.
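// Here the stores touch even elements {0, 2, ..., 18} while the loads touch
// odd elements {1, 3, ..., 19}, so they can never alias.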
| |
| MemDependencyChecker analyzer; |
| // Execution order doesn't matter since the read and the write are totally |
| // distinct. |
| |
| StmtPtr stmt = For::make( |
| x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
/* for (int x = 1; x < 10; x++) {
| * A[x * 2] = A[x * 2 - 1]; |
| * } |
| */ |
| |
// Same if the read is behind the write, so long as they are distinct.
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 2] = A[x * 2 + 2]; |
| * } |
| */ |
| |
| // But not if the offset is in the stride. |
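// For example, iteration x reads element x * 2 + 2, which iteration x + 1
// writes.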
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
/* for (int x = 1; x < 10; x++) {
| * A[x * 2] = A[x * 2 - 2]; |
| * } |
| */ |
| |
| // Works with negative offsets too. |
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 2] = A[x * 2 + 7]; |
| * } |
| */ |
| |
| // Detects accesses are distinct when offset is large but not a multiple |
| // of stride. |
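// The stores touch even elements while the loads (x * 2 + 7) touch odd
// elements, so they can never alias.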
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 2] = A[x * 2 + 4]; |
| * } |
| */ |
| |
| // Works with offsets which are multiples of the stride. |
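// For example, iteration x reads element x * 2 + 4, which iteration x + 2
// writes.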
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 6] = A[x * 6 + 5]; |
| * } |
| */ |
| |
// Detects that accesses are distinct with large strides when the offset is
// within the stride.
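// The stores touch elements that are 0 mod 6, while the loads (x * 6 + 5)
// touch elements that are 5 mod 6.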
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 2] = A[x * 6]; |
| * } |
| */ |
| |
// Detects that accesses are overlapping when the strides are different but
// one is a multiple of the other.
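// For example, element 6 is read at x = 1 and written at x = 3.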
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = |
| For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 4] = A[x * 2]; |
| * } |
| */ |
| |
// Still works when the read side has the smaller stride.
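// For example, element 4 is written at x = 1 and read at x = 2.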
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = |
| For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 2] = A[x * 6 + 1]; |
| * } |
| */ |
| |
| // Detects that accesses are distinct when one stride is a multiple of the |
| // other and there is an offset. |
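| // The reads (x * 6 + 1) are all odd and the writes (x * 2) are all even, |
| // so no aliasing is possible. |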
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 2] = A[x * 6 + 4]; |
| * } |
| */ |
| |
| // The smaller stride determines whether there is overlap. |
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 2 + 3] = A[x * 6]; |
| * } |
| */ |
| |
| // The smaller stride determines whether there is overlap, not the larger. |
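| // The writes (x * 2 + 3) are all odd and the reads (x * 6) are all even, |
| // so the accesses stay distinct despite the larger read stride. |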
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 2] = A[x * 3 + 1]; |
| * } |
| */ |
| |
| // If the strides are coprime (no common factor > 1), they overlap. |
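| // With coprime strides the index sets interleave: reads at A[4], A[10] and |
| // A[16] land on locations the x * 2 store writes. |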
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x] = A[x + 10]; |
| * } |
| */ |
| |
| // If the offset is at least as large as the loop extent, they can't overlap. |
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = |
| For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x] = A[9 - x]; |
| * } |
| */ |
| |
| // If they have different execution orders they may overlap. |
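| // The store walks forward while the load walks backward, so e.g. iteration |
| // 6 reads A[3], which iteration 3 wrote. |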
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x * 2] = A[19 - x * 2]; |
| * } |
| */ |
| |
| // Or they may not, depending on their start offset and strides. |
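| // Here the writes (x * 2) are even and the reads (19 - x * 2) are odd, so |
| // the opposing traversals never meet. |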
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, |
| 0, |
| 10, |
| Store::make(a, {x * 2}, Load::make(a, {ExprHandle(19) - x * 2}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x / 2] = A[x / 2]; |
| * } |
| */ |
| |
| // If the access is not monotonic (x / 2 revisits each index), they overlap. |
| |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = |
| For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x / 2] = A[x / 2] + 1; |
| * } |
| */ |
| |
| // Non-monotonic accesses overlap even with an offset. |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * A[x % 2] = A[x % 2]; |
| * } |
| */ |
| |
| // Mod too... |
| |
| analysis::MemDependencyChecker analyzer; |
| StmtPtr stmt = For::make( |
| x, |
| 0, |
| 10, |
| Store::make(a, {Mod::make(x, 2)}, Load::make(a, {Mod::make(x, 2)}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| /* for (int x = y; x < z; x++) { |
| * A[x] = A[x + 1]; |
| * } |
| */ |
| |
| // Still works with symbolic loop extents. |
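| // Without execution order analysis the checker must assume the read of |
| // A[x + 1] can alias the write from a later iteration; enabling it lets |
| // the checker prove the read happens before that write, so the self |
| // dependency disappears. |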
| |
| { |
| MemDependencyChecker analyzer; |
| StmtPtr stmt = |
| For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); |
| } |
| |
| { |
| MemDependencyChecker analyzer; |
| analyzer.allowLoopExecutionOrderAnalysis(); |
| StmtPtr stmt = |
| For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); |
| stmt->accept(&analyzer); |
| |
| ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); |
| } |
| } |
| } |
| |
| // Verify that a strided access still works. |
| // TODO: actually this only works because of the size of the ranges, revisit |
| // this test after strided overlap is implemented. |
| TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { |
| BufHandle a("A", {20}, kInt); |
| BufHandle b("B", {20}, kInt); |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| |
| using namespace analysis; |
| MemDependencyChecker analyzer({a.node()}, {b.node()}); |
| StmtPtr stmt = Block::make( |
| {For::make( |
| x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), |
| For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))) |
| |
| }); |
| stmt->accept(&analyzer); |
| |
| // Sanity check output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); |
| |
| // Output has 2 dependencies... the store in each loop. |
| auto outputAccess = analyzer.output(b.node()); |
| ASSERT_EQ(outputAccess->dependencies().size(), 2); |
| } |
| |
| /* TODO(nickg) - this test will fail due to the lack of stride math in Bound |
| TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { |
| BufHandle a("A", {20}, kInt); |
| BufHandle b("B", {20}, kInt); |
| BufHandle c("C", {10}, kInt); |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| |
| { |
| analysis::MemDependencyChecker analyzer({a.node()}, {c.node()}); |
| StmtPtr stmt = Block::make( |
| {For::make( |
| x, |
| 0, |
| 10, |
| Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), |
| For::make( |
| x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))), |
| For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))) |
| |
| }); |
| stmt->accept(&analyzer); |
| |
| std::cout << *stmt << "\n"; |
| for (auto& wi : analyzer.getHistory()) { |
| wi->print(); |
| } |
| } |
| }*/ |
| |
| // Analysis of Stmts using Cond. |
| TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { |
| BufHandle a("A", {10}, kInt); |
| BufHandle b("B", {10}, kInt); |
| BufHandle c("C", {10}, kInt); |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| |
| using namespace analysis; |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * C[x] = A[x]; |
| * } |
| * if (y<5 ? 1 : 0) { |
| * C[0] = (B[0]) + 1; |
| * } else { |
| * C[0] = (B[1]) + 1; |
| * } |
| */ |
| |
| // Future usages may depend on accesses in both branches of a condition. |
| |
| MemDependencyChecker analyzer({a, b}, {c}); |
| StmtPtr stmt = Block::make( |
| {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), |
| Cond::make( |
| CompareSelect::make(y, 5, CompareSelectOperation::kLT), |
| Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)), |
| Store::make(c, {0}, Add::make(Load::make(b, {1}), 1)))}); |
| |
| stmt->accept(&analyzer); |
| |
| // Output C should have 3 dependencies, one for each of the three stores. |
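| // The stores inside the Cond are conditional, so neither branch kills the |
| // loop's write to C[0]; all three stores can reach the output. |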
| auto outputAccess = analyzer.output(c.node()); |
| ASSERT_NE(outputAccess, nullptr); |
| ASSERT_EQ(outputAccess->dependencies().size(), 3); |
| |
| // C depends indirectly on A and B. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * C[x] = A[x]; |
| * } |
| * if (y<5 ? 1 : 0) { |
| * for (int x = 0; x < 10; x++) { |
| * C[x] = B[x]; |
| * } |
| * } else { |
| * for (int x = 0; x < 10; x++) { |
| * C[x] = (B[x]) + 1; |
| * } |
| * } |
| */ |
| |
| // Future usages may depend on accesses in both branches of a condition. |
| |
| MemDependencyChecker analyzer({a, b}, {c}); |
| StmtPtr stmt = Block::make( |
| {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), |
| Cond::make( |
| CompareSelect::make(y, 5, CompareSelectOperation::kLT), |
| For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))), |
| For::make( |
| x, |
| 0, |
| 10, |
| Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))}); |
| |
| stmt->accept(&analyzer); |
| |
| // Output C should have 3 dependencies, one for each of the three stores. |
| auto outputAccess = analyzer.output(c.node()); |
| ASSERT_NE(outputAccess, nullptr); |
| ASSERT_EQ(outputAccess->dependencies().size(), 3); |
| |
| // TODO(nickg): actually since the true and false branch cover the total |
| // range of the first store this should have 2 dependencies, but we don't |
| // do that yet. |
| |
| // C depends indirectly on A and B. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * C[x] = A[x]; |
| * } |
| * if (y<5 ? 1 : 0) { |
| * for (int x = 0; x < 10; x++) { |
| * C[x] = (B[x]) + 1; |
| * } |
| * } |
| */ |
| |
| // Only has true branch. |
| |
| MemDependencyChecker analyzer({a, b}, {c}); |
| StmtPtr stmt = Block::make( |
| {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), |
| Cond::make( |
| CompareSelect::make(y, 5, CompareSelectOperation::kLT), |
| For::make( |
| x, |
| 0, |
| 10, |
| Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))), |
| nullptr)}); |
| |
| stmt->accept(&analyzer); |
| |
| // Output C should have 2 dependencies: the initial loop store and the |
| // store in the true branch. |
| auto outputAccess = analyzer.output(c.node()); |
| ASSERT_NE(outputAccess, nullptr); |
| ASSERT_EQ(outputAccess->dependencies().size(), 2); |
| |
| // C depends indirectly on A and B. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * C[x] = A[x]; |
| * } |
| * if (y<5 ? 1 : 0) { |
| * } else { |
| * for (int x = 0; x < 10; x++) { |
| * C[x] = (B[x]) + 1; |
| * } |
| * } |
| */ |
| |
| // Only has false branch. |
| |
| MemDependencyChecker analyzer({a, b}, {c}); |
| StmtPtr stmt = Block::make( |
| {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), |
| Cond::make( |
| CompareSelect::make(y, 5, CompareSelectOperation::kLT), |
| nullptr, |
| For::make( |
| x, |
| 0, |
| 10, |
| Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))}); |
| |
| stmt->accept(&analyzer); |
| |
| // Output C should have 2 dependencies: the initial loop store and the |
| // store in the false branch. |
| auto outputAccess = analyzer.output(c.node()); |
| ASSERT_NE(outputAccess, nullptr); |
| ASSERT_EQ(outputAccess->dependencies().size(), 2); |
| |
| // C depends indirectly on A and B. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * C[x] = A[x]; |
| * } |
| * if (C[0]<5 ? 1 : 0) { |
| * C[0] = 5; |
| * } |
| */ |
| |
| // Cond's Condition depends on a previous access. |
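| // The load of C[0] in the condition is itself an access: it reads the |
| // value written by the initializing loop, so it depends directly on that |
| // store and only indirectly on the input A. |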
| |
| MemDependencyChecker analyzer({a}, {c}); |
| StorePtr initStore = Store::make(c, {x}, Load::make(a, {x})); |
| ExprHandle conditionalLoad = Load::make(c, {0}); |
| StmtPtr stmt = Block::make( |
| {For::make(x, 0, 10, initStore), |
| Cond::make( |
| CompareSelect::make( |
| conditionalLoad, 5, CompareSelectOperation::kLT), |
| Store::make(c, {0}, 5), |
| nullptr)}); |
| |
| stmt->accept(&analyzer); |
| |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| |
| ASSERT_TRUE(analyzer.dependsDirectly(conditionalLoad.node(), initStore)); |
| ASSERT_FALSE(analyzer.dependsDirectly(conditionalLoad.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(conditionalLoad.node(), a.node())); |
| } |
| } |
| |
| // Stmts using IfThenElse. |
| TEST(MemDependency, MemDependencyCheckerIfThenElse) { |
| BufHandle a("A", {10}, kInt); |
| BufHandle b("B", {10}, kInt); |
| BufHandle c("C", {10}, kInt); |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| |
| using namespace analysis; |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * C[x] = A[x]; |
| * } |
| * C[0] = (y < 5 ? (B[0]) + 1 : (B[1]) + 1); |
| */ |
| |
| // Future usages may depend on accesses in both branches of a condition. |
| |
| MemDependencyChecker analyzer({a, b}, {c}); |
| StorePtr ifStore = Store::make( |
| c, |
| {0}, |
| IfThenElse::make( |
| CompareSelect::make(y, 5, CompareSelectOperation::kLT), |
| Add::make(Load::make(b, {0}), 1), |
| Add::make(Load::make(b, {1}), 1))); |
| StmtPtr stmt = Block::make( |
| {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), |
| ifStore}); |
| |
| stmt->accept(&analyzer); |
| |
| // Output C should have 2 dependencies, each of the two stores. |
| auto outputAccess = analyzer.output(c.node()); |
| ASSERT_NE(outputAccess, nullptr); |
| ASSERT_EQ(outputAccess->dependencies().size(), 2); |
| |
| // Now we need to check the Store containing the IfThenElse. |
| auto ifStoreAccess = analyzer.accessFor(ifStore); |
| |
| // It should have 2 dependencies. |
| ASSERT_EQ(ifStoreAccess->dependencies().size(), 2); |
| |
| // C depends indirectly on A and B. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * C[x] = A[x]; |
| * } |
| * C[0] = (y < 5 ? (B[0]) + 1 : 42); |
| */ |
| |
| // If the load appears in only one side of an IfThenElse the output may be |
| // dependent on it. |
| |
| MemDependencyChecker analyzer({a, b}, {c}); |
| StorePtr ifStore = Store::make( |
| c, |
| {0}, |
| IfThenElse::make( |
| CompareSelect::make(y, 5, CompareSelectOperation::kLT), |
| Add::make(Load::make(b, {0}), 1), |
| 42)); |
| StmtPtr stmt = Block::make( |
| {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), |
| ifStore}); |
| |
| stmt->accept(&analyzer); |
| |
| // C depends indirectly on A and B. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * C[x] = (x < 5 ? B[x] : A[x]); |
| * } |
| */ |
| |
| // In this case C is dependent on both A and B. |
| |
| // TODO: in cases like this it would be possible to split the range of C |
| // into two bounds, one dependent on A and one dependent on B. We'd need to |
| // examine conditions relative to previously encountered loop variables. I'm |
| // uncertain if this would be helpful. |
| |
| MemDependencyChecker analyzer({a, b}, {c}); |
| StorePtr ifStore = Store::make( |
| c, |
| {x}, |
| IfThenElse::make( |
| CompareSelect::make(x, 5, CompareSelectOperation::kLT), |
| Load::make(b, {x}), |
| Load::make(a, {x}))); |
| StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)}); |
| |
| stmt->accept(&analyzer); |
| |
| // C depends indirectly on A and B. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); |
| } |
| } |
| |
| // Cutting a loop with single element writes. |
| TEST(MemDependency, MemDependencyCheckerCutLoop) { |
| BufHandle a("A", {10}, kInt); |
| BufHandle b("B", {10}, kInt); |
| VarHandle x("x", kInt); |
| |
| using namespace analysis; |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * B[x] = A[x]; |
| * } |
| * B[5] = 100; |
| */ |
| |
| // Cutting a loop with single element writes. |
| |
| MemDependencyChecker analyzer({a}, {b}); |
| StmtPtr stmt = Block::make( |
| {For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))), |
| Store::make(b, {5}, 100)}); |
| |
| stmt->accept(&analyzer); |
| |
| // Output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); |
| |
| // Output has 2 dependencies. |
| auto outputAccess = analyzer.output(b.node()); |
| ASSERT_NE(outputAccess, nullptr); |
| ASSERT_EQ(outputAccess->dependencies().size(), 2); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * B[x] = A[x]; |
| * } |
| * for (int x = 4; x < 7; x++) { |
| * B[x] = (B[x]) + 1; |
| * } |
| * B[4] = 100; |
| * B[5] = 101; |
| * B[6] = 102; |
| */ |
| |
| // Cut a loop with a smaller loop, then totally overwrite that second loop |
| // with single element writes. |
| |
| MemDependencyChecker analyzer({a}, {b}); |
| ForPtr firstLoop = |
| For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))); |
| StorePtr secondStore = |
| Store::make(b, {x}, Add::make(Load::make(b, {x}), 1)); |
| ForPtr secondLoop = For::make(x, 4, 7, secondStore); |
| |
| StmtPtr stmt = Block::make( |
| {firstLoop, |
| secondLoop, |
| Store::make(b, {4}, 100), |
| Store::make(b, {5}, 101), |
| Store::make(b, {6}, 102)}); |
| |
| stmt->accept(&analyzer); |
| |
| // Output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); |
| |
| // Output has 4 dependencies. |
| auto outputAccess = analyzer.output(b.node()); |
| ASSERT_NE(outputAccess, nullptr); |
| ASSERT_EQ(outputAccess->dependencies().size(), 4); |
| |
| // Second loop depends on first loop. |
| ASSERT_TRUE(analyzer.dependsDirectly(secondLoop, firstLoop)); |
| |
| // Output does not depend on second loop or store. |
| ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondLoop)); |
| ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondStore)); |
| } |
| } |
| |
| // Dynamic shapes (load in indices). |
| TEST(MemDependency, MemDependencyCheckerDynamicShapes) { |
| BufHandle a("A", {100}, kInt); |
| BufHandle b("B", {100}, kInt); |
| BufHandle c("C", {100}, kInt); |
| VarHandle x("x", kInt); |
| |
| using namespace analysis; |
| |
| auto CB = [](ExprHandle s, ExprHandle e) { |
| return Bound(s.node(), e.node()); |
| }; |
| |
| auto EQ = [](const IndexBounds& x, const IndexBounds& y) { |
| return indexBoundsEquals(x, y); |
| }; |
| |
| { |
| /* for (int x = 0; x < B[0]; x++) { |
| * C[x] = A[x]; |
| * } |
| */ |
| MemDependencyChecker analyzer({a, b}, {c}); |
| StmtPtr stmt = Block::make({For::make( |
| x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))}); |
| |
| stmt->accept(&analyzer); |
| |
| /* 0. Input: B[(0, 99)] - dependents: 2 |
| * 1. Input: A[(0, 99)] - dependents: 3 |
| * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 3 4 |
| * 3. Load: A[(0, (B[0]) - 1)] - depends on: 1 2 - dependents: 4 |
| * 4. Store: C[(0, (B[0]) - 1)] - depends on: 2 3 - dependents: 5 |
| * 5. Output: C[(0, 99)] - depends on: 4 |
| */ |
| |
| // Output dependent on A input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| // Also dependent on B input to determine the size of the region written. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); |
| |
| auto history = analyzer.getHistory(); |
| ASSERT_EQ(history.size(), 6); |
| |
| // The accesses in the loop depend on the load in the stop condition. |
| ASSERT_TRUE(history[4]->hasDependency(history[2])); |
| ASSERT_TRUE(history[3]->hasDependency(history[2])); |
| |
| // Make a load from B to compare against. |
| ExprHandle loadFromB = Load::make(b, {0}); |
| |
| ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, loadFromB - 1)})); |
| ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, loadFromB - 1)})); |
| } |
| |
| { |
| /* for (int x = B[0]; x < B[1]; x++) { |
| * C[x] = A[x]; |
| * } |
| */ |
| MemDependencyChecker analyzer({a, b}, {c}); |
| StmtPtr stmt = Block::make({For::make( |
| x, |
| Load::make(b, {0}), |
| Load::make(b, {1}), |
| Store::make(c, {x}, Load::make(a, {x})))}); |
| |
| stmt->accept(&analyzer); |
| |
| /* 0. Input: B[(0, 99)] - dependents: 2 3 |
| * 1. Input: A[(0, 99)] - dependents: 4 |
| * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 4 5 |
| * 3. Load: B[(1, 1)] - depends on: 0 - dependents: 4 5 |
| * 4. Load: A[(B[0], (B[1]) - 1)] - depends on: 1 2 3 - dependents: 5 |
| * 5. Store: C[(B[0], (B[1]) - 1)] - depends on: 2 3 4 - dependents: 6 |
| * 6. Output: C[(0, 99)] - depends on: 5 |
| */ |
| |
| // Sanity check output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); |
| |
| auto history = analyzer.getHistory(); |
| ASSERT_EQ(history.size(), 7); |
| |
| // The accesses in the loop depend on the load in the start condition. |
| ASSERT_TRUE(history[5]->hasDependency(history[2])); |
| ASSERT_TRUE(history[4]->hasDependency(history[2])); |
| |
| // also the stop condition. |
| ASSERT_TRUE(history[5]->hasDependency(history[3])); |
| ASSERT_TRUE(history[4]->hasDependency(history[3])); |
| |
| // Make loads from B to compare against. |
| ExprHandle loadFromB0 = Load::make(b, {0}); |
| ExprHandle loadFromB1 = Load::make(b, {1}); |
| ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); |
| ASSERT_TRUE(EQ(history[5]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * C[x] = A[B[x]]; |
| * } |
| */ |
| MemDependencyChecker analyzer({a, b}, {c}); |
| StmtPtr stmt = Block::make({For::make( |
| x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))}); |
| |
| stmt->accept(&analyzer); |
| |
| /* 0. Input: B[(0, 99)] - dependents: 2 |
| * 1. Input: A[(0, 99)] - dependents: 3 |
| * 2. Load: B[(0, 9)] - depends on: 0 - dependents: 3 4 |
| * 3. Load: A[(B[0], B[9])] - depends on: 1 2 - dependents: 4 |
| * 4. Store: C[(0, 9)] - depends on: 2 3 - dependents: 5 |
| * 5. Output: C[(0, 99)] - depends on: 4 |
| */ |
| |
| // Sanity check output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); |
| |
| auto history = analyzer.getHistory(); |
| ASSERT_EQ(history.size(), 6); |
| |
| // The store depends on both loads, the load of A depends on the load of B. |
| ASSERT_TRUE(history[4]->hasDependency(history[2])); |
| ASSERT_TRUE(history[4]->hasDependency(history[3])); |
| |
| ASSERT_TRUE(history[3]->hasDependency(history[2])); |
| |
| // The loads in the indices depend on the relevant input buffer. |
| ASSERT_TRUE(history[3]->hasDependency(history[1])); |
| ASSERT_TRUE(history[2]->hasDependency(history[0])); |
| |
| // The load from B has the loop bounds. |
| ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); |
| |
| // The load from A has bounds B[0] to B[9]. |
| ExprHandle loadFromB0 = Load::make(b, {0}); |
| ExprHandle loadFromB9 = Load::make(b, {9}); |
| ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromB0, loadFromB9)})); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * C[B[x]] = A[x]; |
| * } |
| */ |
| MemDependencyChecker analyzer({a, b}, {c}); |
| StmtPtr stmt = Block::make({For::make( |
| x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))}); |
| |
| stmt->accept(&analyzer); |
| |
| /* 0. Input: B[(0, 99)] - dependents: 3 |
| * 1. Input: A[(0, 99)] - dependents: 2 |
| * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 4 |
| * 3. Load: B[(0, 9)] - depends on: 0 - dependents: 4 |
| * 4. Store: C[(B[0], B[9])] - depends on: 2 3 - dependents: 5 |
| * 5. Output: C[(0, 99)] - depends on: 4 |
| */ |
| // Sanity check output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); |
| |
| auto history = analyzer.getHistory(); |
| ASSERT_EQ(history.size(), 6); |
| |
| // The store depends on both loads, neither load is dependent. |
| ASSERT_TRUE(history[4]->hasDependency(history[2])); |
| ASSERT_TRUE(history[4]->hasDependency(history[3])); |
| |
| ASSERT_FALSE(history[3]->hasDependency(history[2])); |
| ASSERT_FALSE(history[2]->hasDependency(history[3])); |
| |
| // The loads each depend on their relevant input (but the accesses appear |
| // in a different order than in the last case). |
| ASSERT_TRUE(history[3]->hasDependency(history[0])); |
| ASSERT_TRUE(history[2]->hasDependency(history[1])); |
| |
| // The load from B has the loop bounds. |
| ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, 9)})); |
| |
| // And so does the load from A. |
| ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * C[B[A[x]]] = x; |
| * } |
| */ |
| MemDependencyChecker analyzer({a, b}, {c}); |
| StmtPtr stmt = Block::make({For::make( |
| x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))}); |
| |
| stmt->accept(&analyzer); |
| |
| /* 0. Input: B[(0, 99)] - dependents: 3 |
| * 1. Input: A[(0, 99)] - dependents: 2 |
| * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 3 4 |
| * 3. Load: B[(A[0], A[9])] - depends on: 0 2 - dependents: 4 |
| * 4. Store: C[(B[A[0]], B[A[9]])] - depends on: 2 3 - dependents: 5 |
| * 5. Output: C[(0, 99)] - depends on: 4 |
| */ |
| |
| // Sanity check output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); |
| |
| auto history = analyzer.getHistory(); |
| ASSERT_EQ(history.size(), 6); |
| |
| // The store depends on both loads. |
| ASSERT_TRUE(history[4]->hasDependency(history[2])); |
| ASSERT_TRUE(history[4]->hasDependency(history[3])); |
| |
| // The outer load depends on the inner. |
| ASSERT_TRUE(history[3]->hasDependency(history[2])); |
| |
| // The loads each depend on their relevant input (but the accesses appear |
| // in a different order than in the last case). |
| ASSERT_TRUE(history[3]->hasDependency(history[0])); |
| ASSERT_TRUE(history[2]->hasDependency(history[1])); |
| |
| // The load from A has the loop bounds. |
| ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); |
| // The load from B has bounds A[0] to A[9]. |
| ExprHandle loadFromA0 = Load::make(a, {0}); |
| ExprHandle loadFromA9 = Load::make(a, {9}); |
| ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromA0, loadFromA9)})); |
| |
| // The store has bounds of B[A[0]] to B[A[9]]. |
| ExprHandle loadFromBA0 = Load::make(b, {loadFromA0}); |
| ExprHandle loadFromBA9 = Load::make(b, {loadFromA9}); |
| ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromBA0, loadFromBA9)})); |
| } |
| } |
| |
| // Verify multi dimensional bounds work. |
| TEST(MemDependency, MemDependencyCheckerMultiDim) { |
| int M = 10, N = 9, K = 12; |
| BufHandle a("A", {M, N, K}, kInt); |
| BufHandle b("B", {M, N, K}, kInt); |
| BufHandle c("C", {M, K}, kInt); |
| VarHandle x("x", kInt); |
| VarHandle y("y", kInt); |
| VarHandle z("z", kInt); |
| |
| using namespace analysis; |
| |
| auto CB = [](ExprHandle s, ExprHandle e) { |
| return Bound(s.node(), e.node()); |
| }; |
| |
| auto EQ = [](const IndexBounds& x, const IndexBounds& y) { |
| return indexBoundsEquals(x, y); |
| }; |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * for (int y = 0; y < 9; y++) { |
| * for (int z = 0; z < 12; z++) { |
| * B[x, y, z] = A[x, y, z]; |
| * } |
| * } |
| * } |
| */ |
| // Full range. |
| |
| MemDependencyChecker analyzer({a}, {b}); |
| StmtPtr stmt = Block::make({For::make( |
| x, |
| 0, |
| M, |
| For::make( |
| y, |
| 0, |
| N, |
| For::make( |
| z, |
| 0, |
| K, |
| Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))}); |
| |
| stmt->accept(&analyzer); |
| |
| // Sanity test: Output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); |
| |
| // 4 accesses: input, load, store, output. |
| auto history = analyzer.getHistory(); |
| ASSERT_EQ(history.size(), 4); |
| |
| // Simple chain from input to output. |
| ASSERT_TRUE(history[3]->hasDependency(history[2])); |
| ASSERT_TRUE(history[2]->hasDependency(history[1])); |
| ASSERT_TRUE(history[1]->hasDependency(history[0])); |
| |
| ASSERT_TRUE( |
| EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); |
| ASSERT_TRUE( |
| EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); |
| } |
| |
| { |
| /* for (int x = 0; x < 5; x++) { |
| * for (int y = 0; y < 5; y++) { |
| * for (int z = 0; z < 5; z++) { |
| * B[x, y, z] = A[x, y, z]; |
| * } |
| * } |
| * } |
| */ |
| // Partial range. |
| |
| MemDependencyChecker analyzer({a}, {b}); |
| StmtPtr stmt = Block::make({For::make( |
| x, |
| 0, |
| 5, |
| For::make( |
| y, |
| 0, |
| 5, |
| For::make( |
| z, |
| 0, |
| 5, |
| Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))}); |
| |
| stmt->accept(&analyzer); |
| |
| // Sanity test: Output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); |
| |
| // 4 accesses: input, load, store, output. |
| auto history = analyzer.getHistory(); |
| ASSERT_EQ(history.size(), 4); |
| |
| // Simple chain from input to output. |
| ASSERT_TRUE(history[3]->hasDependency(history[2])); |
| ASSERT_TRUE(history[2]->hasDependency(history[1])); |
| ASSERT_TRUE(history[1]->hasDependency(history[0])); |
| |
| ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); |
| ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); |
| } |
| |
| { |
| /* for (int x = 0; x < 9; x++) { |
| * for (int y = 0; y < 12; y++) { |
| * B[x, 0, y] = A[x, 0, y]; |
| * } |
| * } |
| */ |
| |
| // Partial loops. |
| |
| MemDependencyChecker analyzer({a}, {b}); |
| StmtPtr stmt = Block::make({For::make( |
| x, |
| 0, |
| N, |
| For::make( |
| y, 0, K, Store::make(b, {x, 0, y}, Load::make(a, {x, 0, y}))))}); |
| |
| stmt->accept(&analyzer); |
| |
| // Sanity test: Output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); |
| |
| // 4 accesses: input, load, store, output. |
| auto history = analyzer.getHistory(); |
| ASSERT_EQ(history.size(), 4); |
| |
| // Simple chain from input to output. |
| ASSERT_TRUE(history[3]->hasDependency(history[2])); |
| ASSERT_TRUE(history[2]->hasDependency(history[1])); |
| ASSERT_TRUE(history[1]->hasDependency(history[0])); |
| |
| ASSERT_TRUE( |
| EQ(history[1]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); |
| ASSERT_TRUE( |
| EQ(history[2]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * for (int y = 0; y < 100; y++) { |
| * for (int z = 0; z < 12; z++) { |
| * B[x, 0, z] = (A[x, 0, z]) + (C[x, z]); |
| * } |
| * } |
| * } |
| */ |
| |
| // Loops that don't correspond to an index, bufs with different |
| // dimensionality. |
| |
| MemDependencyChecker analyzer({a, c}, {b}); |
| StmtPtr stmt = Block::make({For::make( |
| x, |
| 0, |
| M, |
| For::make( |
| y, |
| 0, |
| 100, |
| For::make( |
| z, |
| 0, |
| K, |
| Store::make( |
| b, |
| {x, 0, z}, |
| Add::make( |
| Load::make(a, {x, 0, z}), Load::make(c, {x, z}))))))}); |
| |
| stmt->accept(&analyzer); |
| |
| // Sanity test: Output depends on both inputs. |
| ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), c.node())); |
| |
| // 6 accesses: 2 inputs, 2 loads, store, output. |
| auto history = analyzer.getHistory(); |
| ASSERT_EQ(history.size(), 6); |
| |
| // Simple chain from input to output over the A buf. |
| // history[0] is the C input, history[3] is the load from C. |
| ASSERT_TRUE(history[5]->hasDependency(history[4])); |
| ASSERT_TRUE(history[4]->hasDependency(history[2])); |
| ASSERT_TRUE(history[2]->hasDependency(history[1])); |
| // The store also depends on the load from the C input. |
| ASSERT_TRUE(history[4]->hasDependency(history[3])); |
| ASSERT_TRUE(history[3]->hasDependency(history[0])); |
| |
| // A Buf accesses. |
| ASSERT_TRUE( |
| EQ(history[4]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)})); |
| ASSERT_TRUE( |
| EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)})); |
| |
| // C buf access. |
| ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, K - 1)})); |
| } |
| |
| { |
| /* for (int x = 0; x < 10; x++) { |
| * for (int y = 0; y < 9; y++) { |
| * for (int z = 0; z < 12; z++) { |
| * B[x, 0, 0] = (B[x, y, z]) + (A[x, y, z]); |
| * } |
| * } |
| * } |
| */ |
| // Multi-dim reductions. |
| |
| MemDependencyChecker analyzer({a}, {b}); |
| StmtPtr stmt = Block::make({For::make( |
| x, |
| 0, |
| M, |
| For::make( |
| y, |
| 0, |
| N, |
| For::make( |
| z, |
| 0, |
| K, |
| Store::make( |
| b, |
| {x, 0, 0}, |
| Add::make( |
| Load::make(b, {x, y, z}), |
| Load::make(a, {x, y, z}))))))}); |
| |
| stmt->accept(&analyzer); |
| |
| // Sanity test: Output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); |
| |
| // 5 accesses: input, 2 loads, store, output. |
| auto history = analyzer.getHistory(); |
| ASSERT_EQ(history.size(), 5); |
| |
| // Simple chain from input to output. |
| ASSERT_TRUE(history[4]->hasDependency(history[3])); |
| ASSERT_TRUE(history[3]->hasDependency(history[2])); |
| ASSERT_TRUE(history[3]->hasDependency(history[1])); |
| ASSERT_TRUE(history[2]->hasDependency(history[0])); |
| |
| // The load from B depends on the store to B. |
| ASSERT_TRUE(history[1]->hasDependency(history[3])); |
| |
| ASSERT_TRUE( |
| EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); |
| ASSERT_TRUE( |
| EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); |
| ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, 0)})); |
| } |
| } |
| |
| // Various tests using the external Compute/Reduce API. |
| TEST(MemDependency, MemDependencyCheckerComputeAPI) { |
| using namespace analysis; |
| |
| /* for (int m = 0; m < 4; m++) { |
| * for (int n = 0; n < 5; n++) { |
| * for (int k = 0; k < 6; k++) { |
| * broadcast_add[m, n, k] = (a[m, n]) + (b[n, k]); |
| * } |
| * } |
| * } |
| * for (int m_1 = 0; m_1 < 4; m_1++) { |
| * for (int n_1 = 0; n_1 < 5; n_1++) { |
| * for (int k_1 = 0; k_1 < 6; k_1++) { |
| * d[m_1, n_1, k_1] = (broadcast_add(m_1, n_1, k_1)) + float(1); |
| * } |
| * } |
| * } |
| */ |
| |
| // Can determine if 2 loops created by Compute are dependent. |
| BufHandle a_buf("a", {4, 5}, kFloat); |
| BufHandle b_buf("b", {5, 6}, kFloat); |
| Tensor c = Compute( |
| "broadcast_add", |
| {4, 5, 6}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return a_buf.load(m, n) + b_buf.load(n, k); |
| }); |
| Tensor d = Compute( |
| "d", |
| {4, 5, 6}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return c.load(m, n, k) + 1; |
| }); |
| |
| LoopNest l({d}, {c, d}); |
| |
| MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()}); |
| |
| l.root_stmt()->accept(&analyzer); |
| |
| // Sanity test: Output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node())); |
| |
| // Second loop depends on first loop. |
| auto c_loop = l.getLoopStmtsFor(c)[0]; |
| auto d_loop = l.getLoopStmtsFor(d)[0]; |
| ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); |
| } |
| |
| TEST(MemDependency, MemDependencyCheckerComputeInline) { |
| using namespace analysis; |
| |
| /* for (int m = 0; m < 4; m++) { |
| * for (int n = 0; n < 5; n++) { |
| * for (int k = 0; k < 6; k++) { |
| * d[m, n, k] = ((a[m, n]) + (b[n, k])) + float(1); |
| * } |
| * } |
| * } |
| */ |
| |
| // Check that inlining affects the number of accesses returned. |
| |
| BufHandle a_buf("a", {4, 5}, kFloat); |
| BufHandle b_buf("b", {5, 6}, kFloat); |
| Tensor c = Compute( |
| "broadcast_add", |
| {4, 5, 6}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return a_buf.load(m, n) + b_buf.load(n, k); |
| }); |
| Tensor d = Compute( |
| "d", |
| {4, 5, 6}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return c.load(m, n, k) + 1; |
| }); |
| |
| LoopNest l({d}, {c, d}); |
| l.computeInline(c.buf()); |
| |
| MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()}); |
| l.root_stmt()->accept(&analyzer); |
| |
| // Sanity test: Output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node())); |
| |
| // broadcast_add tensor should not appear in trace at all. |
| for (auto& wi : analyzer.getHistory()) { |
| ASSERT_NE(wi->var(), c.buf()->base_handle()); |
| } |
| } |
| |
| TEST(MemDependency, MemDependencyCheckerComputeSplit) { |
| using namespace analysis; |
| // Split an axis, so the number of loops != the number of dimensions. |
| |
| BufHandle a_buf("a", {4, 5}, kFloat); |
| BufHandle b_buf("b", {5, 6}, kFloat); |
| Tensor c = Compute( |
| "broadcast_add", |
| {4, 5, 6}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return a_buf.load(m, n) + b_buf.load(n, k); |
| }); |
| |
| LoopNest l({c}); |
| |
| MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()}); |
| l.root_stmt()->accept(&analyzer_before); |
| |
| l.splitWithTail(l.getLoopStmtsFor(c)[0], 2); |
| |
| MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()}); |
| StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); |
| stmt->accept(&analyzer_after); |
| |
| // Splitting should not change accesses at all. |
| auto history_before = analyzer_before.getHistory(); |
| auto history_after = analyzer_after.getHistory(); |
| |
| ASSERT_EQ(history_before.size(), history_after.size()); |
| |
| for (size_t i = 0; i < history_before.size(); ++i) { |
| ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); |
| ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); |
| ASSERT_EQ( |
| history_before[i]->bounds().size(), history_after[i]->bounds().size()); |
| ASSERT_TRUE(indexBoundsEquals( |
| history_before[i]->bounds(), history_after[i]->bounds())); |
| ASSERT_EQ( |
| history_before[i]->dependencies().size(), |
| history_after[i]->dependencies().size()); |
| ASSERT_EQ( |
| history_before[i]->dependents().size(), |
| history_after[i]->dependents().size()); |
| } |
| } |
| |
| TEST(MemDependency, MemDependencyCheckerComputeReorder) { |
| using namespace analysis; |
| // Reorder an axis, so the loop order doesn't match the indexing order. |
| |
| BufHandle a_buf("a", {4, 5}, kFloat); |
| BufHandle b_buf("b", {5, 6}, kFloat); |
| Tensor c = Compute( |
| "broadcast_add", |
| {4, 5, 6}, |
| [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { |
| return a_buf.load(m, n) + b_buf.load(n, k); |
| }); |
| |
| LoopNest l({c}); |
| |
| MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()}); |
| l.root_stmt()->accept(&analyzer_before); |
| |
| auto loops = l.getLoopStmtsFor(c); |
| l.reorderAxis(loops[0], loops[1]); |
| |
| MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()}); |
| StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); |
| stmt->accept(&analyzer_after); |
| |
| // Reordering should not change accesses at all. |
| auto history_before = analyzer_before.getHistory(); |
| auto history_after = analyzer_after.getHistory(); |
| |
| ASSERT_EQ(history_before.size(), history_after.size()); |
| |
| for (size_t i = 0; i < history_before.size(); ++i) { |
| ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); |
| ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); |
| ASSERT_EQ( |
| history_before[i]->bounds().size(), history_after[i]->bounds().size()); |
| ASSERT_TRUE(indexBoundsEquals( |
| history_before[i]->bounds(), history_after[i]->bounds())); |
| ASSERT_EQ( |
| history_before[i]->dependencies().size(), |
| history_after[i]->dependencies().size()); |
| ASSERT_EQ( |
| history_before[i]->dependents().size(), |
| history_after[i]->dependents().size()); |
| } |
| } |
| |
| TEST(MemDependency, MemDependencyCheckerComputeReduce) { |
| using namespace analysis; |
| /* for (int l2 = 0; l2 < 2; l2++) { |
| * for (int n1 = 0; n1 < 3; n1++) { |
| * for (int m1 = 0; m1 < 6; m1++) { |
| * scale[l2, n1, m1] = (b[l2, n1, m1]) * (a[l2, n1, m1]); |
| * } |
| * } |
| * } |
| * for (int l1 = 0; l1 < 2; l1++) { |
| * sum[l1] = float(0); |
| * for (int n1_1 = 0; n1_1 < 3; n1_1++) { |
| * for (int m1_1 = 0; m1_1 < 6; m1_1++) { |
| * sum[l1] = ReduceOp(sum, (sum[l1]) + (scale(l1, n1_1, m1_1)), |
| * out_args={l1}, reduce_args={n1, m1}); |
| * } |
| * } |
| * } |
| */ |
| |
| // Can determine dependencies of a Reduction. |
| |
| BufHandle a("a", {2, 3, 6}, kFloat); |
| BufHandle b("b", {2, 3, 6}, kFloat); |
| |
| Tensor c = Compute( |
| "scale", |
| {2, 3, 6}, |
| [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { |
| return b.load(l, n, m) * a.load(l, n, m); |
| }); |
| Tensor d = Reduce("sum", {2}, Sum(), c, {3, 6}); |
| LoopNest l({d}, {c, d}); |
| |
| MemDependencyChecker analyzer({a.node(), b.node()}, {d.buf()}); |
| |
| l.root_stmt()->accept(&analyzer); |
| |
| // Sanity test: Output depends on input. |
| ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b.node())); |
| |
| // Second loop depends on first loop. |
| auto c_loop = l.getLoopStmtsFor(c)[0]; |
| auto d_loop = l.getLoopStmtsFor(d)[0]; |
| ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); |
| |
| // Reduction depends on both inputs. |
| auto reduces = NodeFinder<ReduceOp>::find(l.root_stmt()); |
| ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], a.node())); |
| ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], b.node())); |
| } |
| |
| TEST(MemDependency, MemDependencyCheckerComputeGEMM) { |
| int M = 1024; |
| int N = 1024; |
| int K = 2048; |
| using namespace analysis; |
| |
| BufHandle AP("A", {M, K}, kFloat); |
| BufHandle BP("B", {K, N}, kFloat); |
| Tensor CT = Reduce( |
| "gemm", |
| {M, N}, |
| Sum(), |
| [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { |
| return AP.load(m, k) * BP.load(k, n); |
| }, |
| {K}); |
| LoopNest loop({CT}); |
| |
| { |
| auto const& loops = loop.getLoopStmtsFor(CT); |
| ForPtr m = loops[0]; |
| loop.splitWithMask(m, 4); |
| } |
| { |
| auto const& loops = loop.getLoopStmtsFor(CT); |
| ForPtr n = loops[2]; |
| loop.splitWithMask(n, 16); |
| } |
| // mo, mi, no, ni, k -> |
| // mo, no, mi, ni, k |
| { |
| auto const& loops = loop.getLoopStmtsFor(CT); |
| ForPtr mi = loops[1]; |
| ForPtr no = loops[2]; |
| loop.reorderAxis(mi, no); |
| } |
| // mo, no, mi, ni, k -> |
| // mo, no, mi, k, ni |
| { |
| auto const& loops = loop.getLoopStmtsFor(CT); |
| ForPtr ni = loops[3]; |
| ForPtr k = loops[4]; |
| loop.reorderAxis(ni, k); |
| } |
| // mo, no, mi, k, ni -> |
| // mo, no, k, mi, ni |
| { |
| auto const& loops = loop.getLoopStmtsFor(CT); |
| ForPtr mi = loops[2]; |
| ForPtr k = loops[3]; |
| loop.reorderAxis(mi, k); |
| } |
| { |
| auto const& loops = loop.getLoopStmtsFor(CT); |
| loop.cacheAccesses(CT.buf(), "C_regs", loops[2]); |
| } |
| |
| MemDependencyChecker analyzer_unlowered( |
| loop.getInputBufs(), loop.getOutputBufs()); |
| |
| MemDependencyChecker analyzer_lowered( |
| loop.getInputBufs(), loop.getOutputBufs()); |
| |
| // Test both unlowered and lowered form. |
| { |
| StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt()); |
| stmt->accept(&analyzer_unlowered); |
| |
| // Outputs depend on inputs. |
| ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), AP.node())); |
| ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), BP.node())); |
| |
| // The last write to gemm should cover the total bound of the output. |
| std::shared_ptr<AccessInfo> outputAccess = |
| analyzer_unlowered.output(CT.buf()); |
| // A single dependency. |
| ASSERT_EQ(outputAccess->dependencies().size(), 1); |
| |
| // dependencies is a set with 1 element, so can just deref begin(). |
| std::shared_ptr<AccessInfo> gemmStore = |
| outputAccess->dependencies().begin()->second; |
| // Check it's a store. |
| ASSERT_EQ(gemmStore->type(), AccessType::Store); |
| |
| ASSERT_TRUE(indexBoundsEquals(outputAccess->bounds(), gemmStore->bounds())); |
| |
| // Likewise the first read from each input covers the entire range of that |
| // input. |
| auto aInput = analyzer_unlowered.input(AP.node()); |
| auto bInput = analyzer_unlowered.input(BP.node()); |
| |
| // A single dependent each. |
| ASSERT_EQ(aInput->dependents().size(), 1); |
| ASSERT_EQ(bInput->dependents().size(), 1); |
| |
| // They're both loads. |
| std::shared_ptr<AccessInfo> aLoad = aInput->dependents().begin()->second; |
| std::shared_ptr<AccessInfo> bLoad = bInput->dependents().begin()->second; |
| ASSERT_EQ(aLoad->type(), AccessType::Load); |
| ASSERT_EQ(bLoad->type(), AccessType::Load); |
| |
| ASSERT_TRUE(indexBoundsEquals(aInput->bounds(), aLoad->bounds())); |
| ASSERT_TRUE(indexBoundsEquals(bInput->bounds(), bLoad->bounds())); |
| } |
| |
| loop.prepareForCodegen(); |
| SimpleIREvaluator cg(loop.root_stmt(), {AP, BP, CT}); |
| |
| // Now check the lowered dependency graph. |
| { |
| StmtPtr stmt = IRSimplifier::simplify(cg.stmt()); |
| stmt->accept(&analyzer_lowered); |
| |
| // Lowering will change the dimensionality of all bounds due to index |
| // flattening and will insert Allocates and Frees. |
| |
| auto history_before = analyzer_unlowered.getHistory(); |
| auto history_after = analyzer_lowered.getHistory(); |
| |
| ASSERT_EQ(history_before.size() + 2, history_after.size()); |
| |
| // Filter out the alloc/free accesses. |
| auto isAllocFree = [](const auto& info) { |
| return info->type() == AccessType::Alloc || |
| info->type() == AccessType::Free; |
| }; |
| history_after.erase( |
| std::remove_if(history_after.begin(), history_after.end(), isAllocFree), |
| history_after.end()); |
| |
| ASSERT_EQ(history_before.size(), history_after.size()); |
| |
| for (size_t i = 0; i < history_before.size(); ++i) { |
| ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); |
| ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); |
| |
| if (history_before[i]->dependencies().size() != |
| history_after[i]->dependencies().size()) { |
| // Must depend on an Alloc. |
| ASSERT_TRUE(std::any_of( |
| history_after[i]->dependencies().begin(), |
| history_after[i]->dependencies().end(), |
| [](const auto& pair) { |
| return pair.second->type() == AccessType::Alloc; |
| })); |
| |
| ASSERT_EQ( |
| history_before[i]->dependencies().size() + 1, |
| history_after[i]->dependencies().size()); |
| } |
| |
| if (history_before[i]->dependents().size() != |
| history_after[i]->dependents().size()) { |
| // Must depend on a Free. |
| ASSERT_TRUE(std::any_of( |
| history_after[i]->dependents().begin(), |
| history_after[i]->dependents().end(), |
| [](const auto& pair) { |
| return pair.second->type() == AccessType::Free; |
| })); |
| |
| ASSERT_EQ( |
| history_before[i]->dependents().size() + 1, |
| history_after[i]->dependents().size()); |
| } |
| |
| // Inputs and outputs are not flattened, only accesses. |
| if (history_before[i]->type() == AccessType::Input || |
| history_before[i]->type() == AccessType::Output) { |
| ASSERT_EQ( |
| history_before[i]->bounds().size(), |
| history_after[i]->bounds().size()); |
| ASSERT_TRUE(indexBoundsEquals( |
| history_before[i]->bounds(), history_after[i]->bounds())); |
| } else { |
| ASSERT_EQ(history_after[i]->bounds().size(), 1); |
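| // Reconstruct the expected flattened extent: the product of (end + 1) over |
| // every original dimension should equal the flattened bound's end + 1. |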
| ExprPtr flat_bounds = alloc<IntImm>(1); |
| |
| for (auto& b : history_before[i]->bounds()) { |
| flat_bounds = |
| alloc<Mul>(flat_bounds, alloc<Add>(b.end, alloc<IntImm>(1))); |
| |
| // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) |
| ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start)); |
| } |
| |
| flat_bounds = IRSimplifier::simplify(flat_bounds); |
| ExprPtr after_bounds = IRSimplifier::simplify( |
| alloc<Add>(history_after[i]->bounds()[0].end, alloc<IntImm>(1))); |
| ASSERT_TRUE(exprEquals(flat_bounds, after_bounds)); |
| } |
| } |
| } |
| } |
| |
| } // namespace jit |
| } // namespace torch |