# Owner(s): ["NNC"]

import torch
import numpy as np
import torch._C._te as te

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase
import unittest

LLVM_ENABLED = torch._C._llvm_enabled()


def construct_adder(n: int, dtype=torch.float32):
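    """Build an "ir_eval" codegen computing C[i] = A[i] + B[i] for 1-D buffers of length n."""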
    A = te.BufHandle("A", [n], dtype)
    B = te.BufHandle("B", [n], dtype)

    def compute(i):
        return A.load([i]) + B.load([i])

    C = te.Compute("C", [n], compute)

    loopnest = te.LoopNest([C])
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())

    return te.construct_codegen("ir_eval", stmt, [A, B, C])


class TestTensorExprPyBind(JitTestCase):
    def test_simple_sum(self):
        n = 32
        cg = construct_adder(n)

        tA = torch.randn(n)
        tB = torch.randn(n)
        tC = torch.empty(n)
        cg.call([tA, tB, tC])
        torch.testing.assert_close(tA + tB, tC)

    def test_call_raw(self):
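        # Same as test_simple_sum, but exercises call_raw(), which takes raw data
        # pointers instead of tensors.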
        n = 16
        cg = construct_adder(n, dtype=torch.float64)

        tA = torch.randn(n, dtype=torch.float64)
        tB = torch.randn(n, dtype=torch.float64)
        tC = torch.empty(n, dtype=torch.float64)
        cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()])
        torch.testing.assert_close(tA + tB, tC)

    def test_external_calls(self):
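        # An ExternalCall statement lowers to a call into the registered NNC external
        # function ("nnc_aten_matmul" here) instead of an explicit loop nest.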
        dtype = torch.float32

        A = te.BufHandle("A", [1, 4], dtype)
        B = te.BufHandle("B", [4, 1], dtype)
        C = te.BufHandle("C", [1, 1], dtype)

        s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

        loopnest = te.LoopNest(s, [C])
        loopnest.prepare_for_codegen()
        codegen = te.construct_codegen("ir_eval", s, [A, B, C])

        tA = torch.ones(1, 4)
        tB = torch.ones(4, 1)
        tC = torch.empty(1, 1)
        codegen.call([tA, tB, tC])
        torch.testing.assert_close(torch.matmul(tA, tB), tC)

    def test_dynamic_shape(self):
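        # dN is a symbolic dimension: it is passed as an extra codegen argument and
        # bound at call time, so a single compiled kernel handles any length.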
        dN = te.VarHandle(torch.int32)
        A = te.BufHandle([dN], torch.float64)
        B = te.BufHandle([dN], torch.float64)

        def compute(i):
            return A.load(i) - B.load(i)

        C = te.Compute("C", [dN], compute)

        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()

        cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN])

        def test_with_shape(n):
            tA = torch.randn(n, dtype=torch.double)
            tB = torch.randn(n, dtype=torch.double)
            tC = torch.empty(n, dtype=torch.double)
            cg.call([tA, tB, tC, n])
            torch.testing.assert_close(tA - tB, tC)

        test_with_shape(8)
        test_with_shape(31)

    def test_dynamic_shape_2d(self):
        dN = te.VarHandle(torch.int32)
        dM = te.VarHandle(torch.int32)
        A = te.BufHandle([dN, dM], torch.float64)
        B = te.BufHandle([dN, dM], torch.float64)

        def compute(i, j):
            return A.load([i, j]) - B.load([i, j])

        C = te.Compute("C", [dN, dM], compute)

        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()

        cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN, dM])

        def test_with_shape(n, m):
            tA = torch.randn(n, m, dtype=torch.double)
            tB = torch.randn(n, m, dtype=torch.double)
            tC = torch.empty(n, m, dtype=torch.double)
            cg.call([tA, tB, tC, n, m])
            torch.testing.assert_close(tA - tB, tC)

        test_with_shape(2, 4)
        test_with_shape(5, 3)

    def test_dtype_error(self):
        te.BufHandle("a", [1], torch.float32)  # ok
        self.assertRaises(TypeError, lambda: te.BufHandle("a", [1], "float55"))

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_tensor_inputs(self):
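        # Both kernel.run() (the compiled kernel) and kernel.fallback() (the
        # uncompiled fallback path) are checked against the eager result.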
        def f(a, b, c):
            return a + b + c

        device, size = "cpu", (4, 4)
        x = torch.rand(size, device=device)
        y = torch.rand(size, device=device)
        z = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %6 : int = prim::Constant[value=1]()
  %7 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %6)
  %3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%7, %c.1, %6)
  return (%3)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x, y, z))
        res2 = kernel.fallback((x, y, z))
        correct = f(x, y, z)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_scalar_inputs(self):
        def f(a, b, c):
            return a + b + c

        x = torch.tensor(0.1, dtype=torch.float, device="cpu")
        y = torch.tensor(0.6, dtype=torch.float, device="cpu")
        z = torch.tensor(0.7, dtype=torch.float, device="cpu")

        graph_str = """
graph(%a.1 : Float(requires_grad=0, device=cpu),
      %b.1 : Float(requires_grad=0, device=cpu),
      %c.1 : Float(requires_grad=0, device=cpu)):
  %3 : int = prim::Constant[value=1]()
  %6 : Float(requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %3)
  %9 : Float(requires_grad=0, device=cpu) = aten::add(%6, %c.1, %3)
  return (%9)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x, y, z))
        res2 = kernel.fallback((x, y, z))
        correct = f(x, y, z)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_shape_prop(self):
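        # TensorExprKernel needs complete shape/dtype info on the graph inputs;
        # without it, construction is expected to raise.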
| device, size = "cpu", (4, 4) |
| x = torch.rand(size, device=device) |
| y = torch.rand(size, device=device) |
| |
| graph_str = """ |
| graph(%a : Tensor, %b : Tensor): |
| %c : Tensor = aten::mul(%a, %b) |
| return (%c) |
| """ |
| graph = torch._C.parse_ir(graph_str) |
| |
| exception_thrown = False |
| try: |
| kernel = te.TensorExprKernel(graph) |
| except RuntimeError: |
| # Graph doesn't have shape info for inputs => compilation should |
| # fail |
| exception_thrown = True |
| assert exception_thrown |
| |
| # Inject shape info and try compiling again |
| example_inputs = [torch.rand(4, 4), torch.rand(4, 4)] |
| torch._C._te.annotate_input_shapes(graph, example_inputs) |
| torch._C._jit_pass_propagate_shapes_on_graph(graph) |
| |
| # Now compilation should pass |
| kernel = te.TensorExprKernel(graph) |
| |
| res = kernel.run((x, y)) |
| correct = torch.mul(x, y) |
| np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5) |
| |
| @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") |
| def test_kernel_shape_prop_module(self): |
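        # Graphs scripted from a module carry an extra 'self' argument, which has to be
        # removed before input shapes can be annotated.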
        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                return x * x + y

        graph = torch.jit.script(TestModule()).graph

        # Compiling the graph as-is should fail: it has no shape info.
        with self.assertRaises(RuntimeError):
            te.TensorExprKernel(graph)

        # Try injecting shape info for graph inputs
        example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]

        # Annotation should fail too: the graph still has a 'self' argument
        # for which we can't set shapes.
        with self.assertRaises(RuntimeError):
            torch._C._te.annotate_input_shapes(graph, example_inputs)

        # Remove the 'self' argument and try annotating shapes one more time
        torch._C._te.remove_unused_self_argument(graph)

        # Inject shape info and try compiling again
        torch._C._te.annotate_input_shapes(graph, example_inputs)
        torch._C._jit_pass_propagate_shapes_on_graph(graph)

        # Now compilation should pass
        kernel = te.TensorExprKernel(graph)

        device, size = "cpu", (4, 4)
        x = torch.rand(size, device=device)
        y = torch.rand(size, device=device)

        res = kernel.run((x, y))
        correct = TestModule().forward(x, y)
        np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_t(self):
        def f(a):
            return a.t()

        device, size = "cpu", (3, 4)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %3 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::t(%a.1)
  return (%3)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_transpose(self):
        def f(a):
            return a.transpose(-1, -2)

        device, size = "cpu", (3, 4)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %2 : int = prim::Constant[value=-1]()
  %3 : int = prim::Constant[value=-2]()
  %4 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::transpose(%a.1, %2, %3)
  return (%4)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_permute(self):
        def f(a):
            return a.permute([2, 1, 0])

        device, size = "cpu", (3, 4, 5)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=2]()
  %2 : int = prim::Constant[value=1]()
  %3 : int = prim::Constant[value=0]()
  %4 : int[] = prim::ListConstruct(%1, %2, %3)
  %5 : Float(5, 4, 3, strides=[12, 3, 1], requires_grad=0, device=cpu) = aten::permute(%a.1, %4)
  return (%5)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_custom_lowering(self):
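        # TensorExprKernel accepts a dict mapping op names to custom lowering functions;
        # here aten::nan_to_num is hand-lowered to "replace NaN with 0.0".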
        def f(a):
            return a.nan_to_num()

        device = "cpu"
        x = torch.ones((2, 2), device=device)
        x[0, 0] = x[1, 1] = torch.nan
        graph_str = """
graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
  %none : NoneType = prim::Constant()
  %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none)
  return (%y)
        """
        graph = torch._C.parse_ir(graph_str)

        def my_custom_lowering(inputs, out_shape, out_stride, out_type, device):
            def compute(idxs):
                load = inputs[0].as_buf().load(idxs)
                return te.ifThenElse(
                    te.ExprHandle.isnan(load), te.ExprHandle.float(0.0), load
                )

            return te.Compute2("custom_nan_to_num", out_shape, compute)

        kernel = te.TensorExprKernel(graph, {"aten::nan_to_num": my_custom_lowering})
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_expand(self):
        def f(a):
            return a.expand((2, 3, 4))

        device = "cpu"
        x = torch.rand((1, 3, 1), device=device)
        graph_str = """
graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=2]()
  %2 : int = prim::Constant[value=3]()
  %3 : int = prim::Constant[value=4]()
  %4 : int[] = prim::ListConstruct(%1, %2, %3)
  %5 : bool = prim::Constant[value=0]()
  %6 : Float(2, 3, 4, strides=[12, 4, 0], requires_grad=0, device=cpu) = aten::expand(%a, %4, %5)
  return (%6)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_alloc_in_loop(self):
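        # Copy a -> tmp -> b inside four nested loops; this exercises LLVM codegen when
        # the intermediate buffer allocation ends up inside the generated loops.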
        a, tmp, b = (
            te.BufHandle(name, [1], torch.float32) for name in ["a", "tmp", "b"]
        )
        body = te.Block([tmp.store([0], a.load([0])), b.store([0], tmp.load([0]))])
        for _ in range(4):
            i = te.VarHandle("i", torch.int32)
            body = te.For.make(i, 0, 100, body)
        nest = te.LoopNest(body, [b])
        nest.prepare_for_codegen()
        f = te.construct_codegen("llvm", nest.simplify(), [a, b])
        ta, tb = (torch.ones(1) for _ in range(2))
        f.call([ta.data_ptr(), tb.data_ptr()])


class TestExprHandlePyBind(JitTestCase):
    def test_unary_ops(self):
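        # Map each torch unary op to its NNC expression counterpart, compile each one
        # through ir_eval, and compare element-wise against the torch result.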
        unary_operators = {
            torch.sin: te.sin,
            torch.cos: te.cos,
            torch.tan: te.tan,
            torch.asin: te.asin,
            torch.acos: te.acos,
            torch.atan: te.atan,
            torch.sinh: te.sinh,
            torch.cosh: te.cosh,
            torch.tanh: te.tanh,
            torch.sigmoid: te.sigmoid,
            torch.exp: te.exp,
            torch.expm1: te.expm1,
            torch.abs: te.abs,
            torch.log: te.log,
            torch.log2: te.log2,
            torch.log10: te.log10,
            torch.log1p: te.log1p,
            torch.erf: te.erf,
            torch.erfc: te.erfc,
            torch.sqrt: te.sqrt,
            torch.rsqrt: te.rsqrt,
            torch.ceil: te.ceil,
            torch.floor: te.floor,
            torch.round: te.round,
            torch.trunc: te.trunc,
            torch.lgamma: te.lgamma,
            torch.frac: te.frac,
        }

        def construct_te_fn(op, n: int, dtype=torch.float32):
            A = te.BufHandle("A", [n], dtype)

            def compute(i):
                return op(A.load([i]))

            C = te.Compute("C", [n], compute)

            loopnest = te.LoopNest([C])
            loopnest.prepare_for_codegen()
            stmt = te.simplify(loopnest.root_stmt())

            return te.construct_codegen("ir_eval", stmt, [A, C])

        n = 10
        a = torch.rand(n)
        for torch_op, te_op in unary_operators.items():
            ref = torch_op(a)

            te_fn = construct_te_fn(te_op, n, torch.float32)
            res = torch.empty(n)
            te_fn.call([a, res])
            torch.testing.assert_close(ref, res, atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
    run_tests()