| # Owner(s): ["NNC"] |
| |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| import unittest |
| import itertools |
| |
| from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests, skipIfTorchDynamo |
| |
| from torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions |
| |
| LLVM_ENABLED = torch._C._llvm_enabled() |
| |
| class BaseTestClass(JitTestCase): |
| def setUp(self): |
| super().setUp() |
| self.tensorexpr_options = TensorExprTestOptions() |
| self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] |
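|         # bfloat16 is exercised only when the LLVM backend is available (see LLVM_ENABLED above). |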
| self.dtypes = [torch.float32, torch.bfloat16] if LLVM_ENABLED else [torch.float32] |
| |
| def tearDown(self): |
| self.tensorexpr_options.restore() |
| super().tearDown() |
| |
| def assertLastGraphAllFused(self): |
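|         # Check that the optimized graph from the most recent execution was fully fused |
|         # (see JitTestCase.assertAllFused). |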
| self.assertAllFused(torch.jit.last_executed_optimized_graph()) |
| |
| |
| def warmup_and_run_forward(f, *args): |
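|     """Run f enough times for the JIT profiling executor to optimize it. |
| |
|     The profiling executor only swaps in the optimized (and possibly fused) graph |
|     after its profiling runs are exhausted, so the function is executed |
|     num_profiled_runs + 1 times and the result of the last, optimized run is returned. |
|     """ |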
| for _ in range(torch._C._jit_get_num_profiled_runs() + 1): |
| results = f(*args) |
| return results |
| |
| |
| @skipIfTorchDynamo() |
| class TestTensorExprFuser(BaseTestClass): |
| def test_easy(self): |
| def easy(x, y): |
| aaa = torch.add(x, y) |
| return aaa |
| |
| traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024))) |
| |
| a = torch.rand(1024) |
| b = torch.rand(1024) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy()) |
| |
| def test_three_arg(self): |
| def easy(x, y, z): |
| aaa = torch.add(x, y) |
| bbb = torch.add(aaa, z) |
| return bbb |
| |
| traced = torch.jit.trace( |
| easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024)) |
| ) |
| |
| a = torch.rand(1024) |
| b = torch.rand(1024) |
| c = torch.rand(1024) |
| x = warmup_and_run_forward(traced, a, b, c) |
| self.assertLastGraphAllFused() |
| npr = a.numpy() + b.numpy() + c.numpy() |
| np.testing.assert_allclose(npr, x.numpy()) |
| |
| def test_four_arg(self): |
| def run_addcmul(x, y, z, w): |
| c = torch.addcmul(torch.add(x, y), z, w) |
| return c |
| |
| for dev in self.devices: |
| rand_a = torch.rand(1024, dtype=torch.float, device=dev) |
| rand_b = torch.rand(1024, dtype=torch.float, device=dev) |
| rand_c = torch.rand(1024, dtype=torch.float, device=dev) |
| rand_d = torch.rand(1024, dtype=torch.float, device=dev) |
| |
| traced = torch.jit.trace( |
| run_addcmul, |
| ( |
| torch.zeros(1024, dtype=torch.float, device=dev), |
| torch.zeros(1024, dtype=torch.float, device=dev), |
| torch.zeros(1024, dtype=torch.float, device=dev), |
| torch.zeros(1024, dtype=torch.float, device=dev), |
| ), |
| ) |
| |
| x = warmup_and_run_forward(traced, rand_a, rand_b, rand_c, rand_d) |
| self.assertLastGraphAllFused() |
| y = run_addcmul(rand_a, rand_b, rand_c, rand_d) |
| np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6) |
| |
| def test_three_arg2(self): |
| for device in self.devices: |
| def test(x, y, z): |
| aaa = torch.add(x, y) |
| bbb = torch.add(aaa, z) |
| return bbb |
| |
| M = 32 |
| N = 32 |
| traced = torch.jit.trace( |
| test, |
| ( |
| torch.rand(M, N, device=device), |
| torch.rand(M, N, device=device), |
| torch.rand(M, N, device=device), |
| ), |
| ) |
| |
| a = torch.rand(M, N, device=device) |
| b = torch.rand(M, N, device=device) |
| c = torch.rand(M, N, device=device) |
| x = warmup_and_run_forward(traced, a, b, c) |
| self.assertLastGraphAllFused() |
| npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() |
| np.testing.assert_allclose(npr, x.cpu().numpy()) |
| |
| def test_broadcast3(self): |
| for device in self.devices: |
| def test_body(M, N, L, K): |
| def test(x, y, z): |
| v1 = torch.add(x, y) |
| v2 = torch.add(v1, z) |
| return v2 |
| |
| a_shape = [M, N] |
| b_shape = [L, M, 1] |
| c_shape = [K, L, 1, 1] |
| traced = torch.jit.trace( |
| test, |
| ( |
| torch.rand(*a_shape, device=device), |
| torch.rand(*b_shape, device=device), |
| torch.rand(*c_shape, device=device), |
| ), |
| ) |
| |
| a = torch.rand(*a_shape, device=device) |
| b = torch.rand(*b_shape, device=device) |
| c = torch.rand(*c_shape, device=device) |
| x = warmup_and_run_forward(traced, a, b, c) |
| self.assertLastGraphAllFused() |
| npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() |
| np.testing.assert_allclose(npr, x.cpu().numpy()) |
| |
| test_configs = [[5, 2, 7, 3], [8, 8, 8, 8]] |
| for test_config in test_configs: |
| test_body(*test_config) |
| |
| def test_all_combos(self): |
| def easy(x, y, z): |
| a = torch.add(x, y) |
| b = torch.add(a, z) |
| c = torch.add(x, b) |
| d = torch.add(c, a) |
| return d |
| |
| def np_easy(x, y, z): |
| a = x + y |
| b = a + z |
| c = x + b |
| d = c + a |
| return d |
| |
| traced = torch.jit.trace( |
| easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024)) |
| ) |
| |
| a = torch.rand(1024) |
| b = torch.rand(1024) |
| c = torch.rand(1024) |
| x = warmup_and_run_forward(traced, a, b, c) |
| self.assertLastGraphAllFused() |
| npr = np_easy(a.numpy(), b.numpy(), c.numpy()) |
| np.testing.assert_allclose(npr, x.numpy()) |
| |
| def test_rank_two(self): |
| def easy(x, y, z): |
| a = torch.add(x, y) |
| b = torch.add(a, z) |
| c = torch.add(x, b) |
| d = torch.add(c, a) |
| return d |
| |
| def np_easy(x, y, z): |
| a = x + y |
| b = a + z |
| c = x + b |
| d = c + a |
| return d |
| |
| shape = 32, 32 |
| traced = torch.jit.trace( |
| easy, (torch.rand(shape), torch.rand(shape), torch.rand(shape)) |
| ) |
| |
| a = torch.rand(shape) |
| b = torch.rand(shape) |
| c = torch.rand(shape) |
| x = warmup_and_run_forward(traced, a, b, c) |
| self.assertLastGraphAllFused() |
| npr = np_easy(a.numpy(), b.numpy(), c.numpy()) |
| np.testing.assert_allclose(npr, x.numpy()) |
| |
| def test_broadcast(self): |
| def easy(x, y, z): |
| a = torch.add(x, y) |
| b = torch.add(a, z) |
| return b |
| |
| def np_easy(x, y, z): |
| a = x + y |
| b = a + z |
| return b |
| |
| N = 32 |
| traced = torch.jit.trace(easy, (torch.rand(N, N), torch.rand(N), torch.rand(N, N))) |
| |
| a = torch.rand(N, N) |
| b = torch.rand(N) |
| c = torch.rand(N, N) |
| x = warmup_and_run_forward(traced, a, b, c) |
| self.assertLastGraphAllFused() |
| npr = np_easy(a.numpy(), b.numpy(), c.numpy()) |
| np.testing.assert_allclose(npr, x.numpy()) |
| |
| def test_broadcast_2(self): |
| zero = torch.tensor([0.0], dtype=torch.float) |
| |
| def foo(x, y, z): |
| aaa = torch.add(x, y) |
| bbb = torch.add(zero, aaa) |
| return torch.add(bbb, z) |
| |
| def foo_np(x, y, z): |
| a = x + y |
| b = zero.numpy() + a |
| return b + z |
| |
| x = torch.rand(3, 4) |
| y = torch.ones(3, 1) |
| z = torch.rand(4) |
| traced = torch.jit.trace(foo, (x, y, z)) |
| |
| r = warmup_and_run_forward(traced, x, y, z) |
| self.assertLastGraphAllFused() |
| |
| rnp = foo_np(x.numpy(), y.numpy(), z.numpy()) |
| np.testing.assert_allclose(r, rnp) |
| |
| def test_broadcast_big2(self): |
| zero = torch.tensor([0.0], dtype=torch.float) |
| |
| def foo(x, y, z): |
| aaa = torch.add(x, y) |
| bbb = torch.add(zero, aaa) |
| return torch.add(bbb, z) |
| |
| def foo_np(x, y, z): |
| a = x + y |
| b = zero.numpy() + a |
| return b + z |
| |
| x = torch.rand(32, 1024) |
| y = torch.ones(32, 1) |
| z = torch.rand(1024) |
| traced = torch.jit.trace(foo, (x, y, z)) |
| |
| r = warmup_and_run_forward(traced, x, y, z) |
| self.assertLastGraphAllFused() |
| rnp = foo_np(x.numpy(), y.numpy(), z.numpy()) |
| np.testing.assert_allclose(r, rnp) |
| |
| def test_alpha(self): |
| def alpha(x): |
| aaa = torch.add(x, x, alpha=2.0) |
| return aaa |
| |
| traced = torch.jit.trace(alpha, (torch.tensor([1.0]))) |
| |
| a = torch.tensor([1.0]) |
| x = traced(a) |
| np.testing.assert_allclose(a.numpy() + 2.0 * a.numpy(), x.numpy()) |
| |
| @suppress_warnings |
| def test_constant(self): |
| def constant(x): |
| bbb = torch.tensor([1.0]) |
| aaa = torch.add(x, bbb) |
| return aaa |
| |
| traced = torch.jit.trace(constant, (torch.tensor([1.0]))) |
| |
| a = torch.tensor([1.0]) |
| x = warmup_and_run_forward(traced, a) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose(a.numpy() + 1.0, x.numpy()) |
| |
| def test_add_sub(self): |
| def easy(x, y, z): |
| aaa = torch.add(x, y) |
| bbb = torch.sub(aaa, z) |
| return bbb |
| |
| traced = torch.jit.trace( |
| easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024)) |
| ) |
| |
| a = torch.rand(1024) |
| b = torch.rand(1024) |
| c = torch.rand(1024) |
| x = warmup_and_run_forward(traced, a, b, c) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose(a.numpy() + b.numpy() - c.numpy(), x.numpy()) |
| |
| def test_promotion(self): |
| def easy(x, y): |
| aaa = torch.add(x, y) |
| return aaa |
| |
| traced = torch.jit.trace( |
| easy, |
| (torch.zeros(1024, dtype=torch.int32), torch.rand(1024, dtype=torch.float32)), |
| ) |
| |
| a = torch.zeros(1024, dtype=torch.int32) |
| b = torch.rand(1024, dtype=torch.float32) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy()) |
| |
| def test_double(self): |
| TENSOR_LEN = 8 |
| |
| def easy(x, y): |
| aaa = torch.add(x, y) |
| bbb = torch.mul(aaa, y) |
| return bbb |
| |
| traced = torch.jit.trace( |
| easy, |
| (torch.rand(TENSOR_LEN, dtype=torch.float64), torch.full((TENSOR_LEN,), 0.5, dtype=torch.float64)), |
| ) |
| |
| a = torch.rand(TENSOR_LEN, dtype=torch.double) |
| b = torch.full((TENSOR_LEN,), 0.5, dtype=torch.double) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) |
| |
| def test_short(self): |
| TENSOR_LEN = 8 |
| |
| def easy(x, y): |
| aaa = torch.add(x, y) |
| bbb = torch.mul(aaa, y) |
| return bbb |
| |
| traced = torch.jit.trace( |
| easy, |
| (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16), |
| torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)), |
| ) |
| |
| a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16) |
| b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) |
| |
| def test_char(self): |
| TENSOR_LEN = 8 |
| |
| def easy(x, y): |
| aaa = torch.add(x, y) |
| bbb = torch.mul(aaa, y) |
| return bbb |
| |
| traced = torch.jit.trace( |
| easy, |
| (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8), |
| torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)), |
| ) |
| |
| a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) |
| b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) |
| |
| def test_int64_promotion(self): |
| TENSOR_LEN = 8 |
| |
| def easy(x, y): |
| aaa = torch.add(x, y) |
| bbb = torch.mul(aaa, y) |
| return bbb |
| |
| traced = torch.jit.trace( |
| easy, |
| (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8), |
| torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64)), |
| ) |
| |
| a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) |
| b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) |
| |
| def test_eq(self): |
| def easy(x, y): |
| c = torch.eq(x, y) |
| return c |
| |
| traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) |
| a = torch.zeros(1024, dtype=torch.int32) |
| b = torch.zeros(1024, dtype=torch.int32) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose(np.ones(1024), x.numpy()) |
| |
| def test_ne(self): |
| def easy(x, y): |
| c = torch.ne(x, y) |
| return c |
| |
| traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) |
| a = torch.zeros(1024, dtype=torch.int32) |
| b = torch.ones(1024, dtype=torch.int32) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose(np.ones(1024), x.numpy()) |
| |
| def test_ge(self): |
| def easy(x, y): |
| c = torch.ge(x, y) |
| return c |
| |
| traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) |
| aa = np.empty([1024], dtype=np.int32) |
| aa.fill(5) |
| a = torch.from_numpy(aa) |
| b = torch.zeros(1024, dtype=torch.int32) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose(np.ones(1024), x.numpy()) |
| |
| def test_gt(self): |
| def easy(x, y): |
| c = torch.gt(x, y) |
| return c |
| |
| traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) |
| a = torch.ones(1024, dtype=torch.int32) |
| b = torch.zeros(1024, dtype=torch.int32) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose(np.ones(1024), x.numpy()) |
| |
| def test_le(self): |
| def easy(x, y): |
| c = torch.le(x, y) |
| return c |
| |
| traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) |
| aa = np.empty([1024], dtype=np.int32) |
| aa.fill(5) |
| a = torch.from_numpy(aa) |
| b = torch.zeros(1024, dtype=torch.int32) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose(np.zeros(1024), x.numpy()) |
| |
| def test_lt(self): |
| def easy(x, y): |
| c = torch.lt(x, y) |
| return c |
| |
| for dev in self.devices: |
| traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev))) |
| a = torch.ones(1024, dtype=torch.int32, device=dev) |
| b = torch.zeros(1024, dtype=torch.int32, device=dev) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose(np.zeros(1024), x.cpu().numpy()) |
| |
| @suppress_warnings |
| def test_min_max(self): |
| def test(x, y): |
| return torch.max(torch.min(x, y), torch.tensor([4.0])) |
| |
| traced = torch.jit.trace(test, (torch.zeros(1024), torch.zeros(1024))) |
| a = 8.0 * torch.rand(1024) |
| b = 8.0 * torch.rand(1024) |
| np.testing.assert_allclose( |
| warmup_and_run_forward(traced, a, b), np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0]) |
| ) |
| self.assertLastGraphAllFused() |
| |
| def test_min_max_reduction(self): |
| def test(x): |
| return torch.min(x) + torch.max(x) |
| |
| traced = torch.jit.trace(test, (torch.zeros(1024))) |
| a = 8.0 * torch.rand(1024) |
| np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy())) |
| self.assertLastGraphAllFused() |
| |
| def test_min_max_reduction2(self): |
| def test(x): |
| return x.min() + x.max() |
| |
| traced = torch.jit.trace(test, (torch.zeros(1024))) |
| a = 8.0 * torch.rand(1024) |
| np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy())) |
| self.assertLastGraphAllFused() |
| |
| def test_min_max_reduction_dim1(self): |
| def test(x): |
| return torch.min(x, 1)[0] + torch.max(x, 1)[0] |
| |
| traced = torch.jit.trace(test, (torch.zeros(16, 16))) |
| a = 8.0 * torch.rand(16, 16) |
| np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin( |
| a.numpy(), axis=1) + np.amax(a.numpy(), axis=1)) |
| self.assertLastGraphAllFused() |
| |
| def test_min_max_reduction_dim1_2(self): |
| def test(x): |
| return torch.min(x * x, 1) |
| |
| traced = torch.jit.trace(test, (torch.zeros(16, 16))) |
| a = 8.0 * torch.rand(16, 16) |
| np.testing.assert_allclose(warmup_and_run_forward(traced, a)[0], np.amin((a * a).numpy(), axis=1)) |
| self.assertLastGraphAllFused() |
| |
| def test_clamp(self): |
| def test(x): |
| return torch.clamp(x + 3.0, 0.0, 6.0) |
| |
| for dev in self.devices: |
| traced = torch.jit.trace(test, (torch.zeros(1024, device=dev))) |
| a = 20.0 * torch.rand(1024, device=dev) - 10.0 |
| an = a.cpu().numpy() |
| np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip(an + 3.0, 0.0, 6.0)) |
| self.assertLastGraphAllFused() |
| |
| def test_relu(self): |
| def test(x): |
| return torch.clamp(F.relu(x), 0, 0.5) |
| |
| for dev in self.devices: |
| traced = torch.jit.trace(test, (torch.zeros(1024, device=dev))) |
| a = 20.0 * torch.rand(1024, device=dev) - 10.0 |
| an = a.cpu().numpy() |
| np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip((np.maximum(0, an)), 0, 0.5)) |
| self.assertLastGraphAllFused() |
| |
| def test_reps(self): |
| def easy(x, y): |
| c = torch.add(x, y) |
| return c |
| |
| traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024))) |
| |
| for _ in range(32): |
| a = torch.ones(1024) |
| b = torch.zeros(1024) |
| x = warmup_and_run_forward(traced, a, b) |
| np.testing.assert_allclose(np.ones(1024), x.numpy()) |
| |
| def test_add_const_rhs(self): |
| def test(x): |
| return x + 3.0 |
| |
| traced = torch.jit.trace(test, torch.rand(4)) |
| x = torch.rand(4) |
| y = warmup_and_run_forward(traced, x) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose(x.numpy() + 3.0, y.numpy()) |
| |
| def test_int_output(self): |
| def test(x, y, z): |
| return x * y * z |
| |
| xs = [(torch.rand(4) * 3 + 1).to(torch.int32) for i in range(3)] |
| x, y, z = xs |
| xn, yn, zn = (t.numpy() for t in xs) |
| traced = torch.jit.trace(test, (x, y, z)) |
| res = warmup_and_run_forward(traced, x, y, z) |
| self.assertLastGraphAllFused() |
| np.testing.assert_allclose(xn * yn * zn, res.numpy()) |
| |
| def test_binary_ops(self): |
| def test_atan2(x, y): |
| c = torch.atan2(torch.add(x, y), y) |
| return c |
| |
| def test_gt(x, y): |
| c = torch.gt(torch.add(x, y), y) |
| return c |
| |
| def test_ge(x, y): |
| c = torch.ge(torch.add(x, y), y) |
| return c |
| |
| def test_lt(x, y): |
| c = torch.lt(torch.add(x, y), y) |
| return c |
| |
| def test_le(x, y): |
| c = torch.le(torch.add(x, y), y) |
| return c |
| |
| def test_lerp(x, y): |
| c = torch.lerp(torch.add(x, 1), x, 2.0) |
| return c |
| |
| def test_mul(x, y): |
| c = torch.mul(torch.add(x, y), y) |
| return c |
| |
| def test_ne(x, y): |
| c = torch.ne(torch.add(x, y), y) |
| return c |
| |
| def test_div(x, y): |
| c = torch.div(torch.add(x, y), 2) |
| return c |
| |
| def test_eq(x, y): |
| c = torch.eq(torch.add(x, y), y) |
| return c |
| |
| def test_fmod(x, y): |
| c = torch.fmod(torch.add(x, y), 2) |
| return c |
| |
| def test_sub(x, y): |
| c = torch.sub(torch.add(x, y), x) |
| return c |
| |
| def test_remainder(x, y): |
| c = torch.remainder(torch.add(x, y), 3.0) |
| return c |
| |
| def test_pow(x, y): |
| c = torch.pow(torch.add(x, y), 2.0) |
| return c |
| |
| def test_type_as(x, y): |
| return x.type_as(torch.add(x, y)) |
| |
| cmp_fns = { |
| test_gt, |
| test_ge, |
| test_lt, |
| test_le, |
| test_ne, |
| test_eq |
| } |
| |
| non_cmp_fns = { |
| test_atan2, |
| test_lerp, |
| test_mul, |
| test_div, |
| test_fmod, |
| test_sub, |
| test_remainder, |
| test_pow, |
| test_type_as, |
| } |
| |
| all_test_fns = cmp_fns.union(non_cmp_fns) |
| fn_dev_dtype = itertools.product(all_test_fns, self.devices, self.dtypes) |
| for torch_fn, dev, data_type in fn_dev_dtype: |
| if torch_fn is test_lerp and data_type is torch.bfloat16: |
| continue |
| rand_a = torch.rand(1024, dtype=data_type, device=dev) |
| rand_b = torch.rand(1024, dtype=data_type, device=dev) |
| in1 = 20 * torch.rand(1024, dtype=data_type, device=dev) |
| in2 = 20 * torch.rand(1024, dtype=data_type, device=dev) |
| traced = torch.jit.trace(torch_fn, (in1, in2)) |
| x = warmup_and_run_forward(traced, rand_a, rand_b) |
| self.assertLastGraphAllFused() |
| |
| _atol = 2e-3 |
| _rtol = 1e-5 |
| if data_type is torch.bfloat16: |
|                 # Compared to the aten logic, NNC can save additional BF16/FP32 conversions. |
|                 # Take d = a + b + c as an example; at the operator level aten computes: |
|                 #   tmp = to_bf16(to_fp32(a) + to_fp32(b)) |
|                 #   d = to_bf16(to_fp32(tmp) + to_fp32(c)) |
|                 # NNC fuses the computation and removes the redundant conversions, |
|                 # emitting a single statement: |
|                 #   d = to_bf16(to_fp32(a) + to_fp32(b) + to_fp32(c)) |
|                 # Hence, we simulate the NNC computation by feeding fp32 tensors and |
|                 # converting the result back to bf16. This avoids the numeric deviation |
|                 # and simplifies the result comparison. |
| y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float()) |
| if torch_fn not in cmp_fns: |
| y = y.bfloat16() |
| _atol = 2e-2 |
| else: |
| y = torch_fn(rand_a, rand_b) |
| self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol) |
| |
| def test_unary_ops(self): |
| def test_cast_float(x, y): |
| c = torch.ops.aten._cast_Float(torch.add(x, y)) |
| return c |
| |
| def test_round(x, y): |
| c = torch.round(torch.add(x, y)) |
| return c |
| |
| def test_sin(x, y): |
| c = torch.sin(torch.add(x, y)) |
| return c |
| |
| def test_asin(x, y): |
| c = torch.asin(torch.add(x, y)) |
| return c |
| |
| def test_sinh(x, y): |
| c = torch.sinh(torch.add(x, y)) |
| return c |
| |
| def test_cos(x, y): |
| c = torch.cos(torch.add(x, y)) |
| return c |
| |
| def test_acos(x, y): |
| c = torch.acos(torch.add(x, y)) |
| return c |
| |
| def test_cosh(x, y): |
| c = torch.cosh(torch.add(x, y)) |
| return c |
| |
| def test_tan(x, y): |
| c = torch.tan(torch.add(x, y)) |
| return c |
| |
| def test_atan(x, y): |
| c = torch.atan(torch.add(x, y)) |
| return c |
| |
| def test_tanh(x, y): |
| c = torch.tanh(torch.add(x, y)) |
| return c |
| |
| def test_sqrt(x, y): |
| c = torch.sqrt(torch.add(x, y)) |
| return c |
| |
| def test_rsqrt(x, y): |
| c = torch.rsqrt(torch.add(x, y)) |
| return c |
| |
| def test_floor(x, y): |
| c = torch.floor(torch.add(x, y)) |
| return c |
| |
| def test_ceil(x, y): |
| c = torch.ceil(torch.add(x, y)) |
| return c |
| |
| def test_trunc(x, y): |
| c = torch.trunc(torch.add(x, y)) |
| return c |
| |
| def test_abs(x, y): |
| c = torch.abs(torch.add(x, y)) |
| return c |
| |
| def test_log(x, y): |
| c = torch.log(torch.add(x, y)) |
| return c |
| |
| def test_log2(x, y): |
| c = torch.log2(torch.add(x, y)) |
| return c |
| |
| def test_log10(x, y): |
| c = torch.log10(torch.add(x, y)) |
| return c |
| |
| def test_log1p(x, y): |
| c = torch.log1p(torch.add(x, y)) |
| return c |
| |
| def test_erf(x, y): |
| c = torch.erf(torch.add(x, y)) |
| return c |
| |
| def test_exp(x, y): |
| c = torch.exp(torch.add(x, y)) |
| return c |
| |
| def test_expm1(x, y): |
| c = torch.expm1(torch.add(x, y)) |
| return c |
| |
| def test_erfc(x, y): |
| c = torch.erfc(torch.add(x, y)) |
| return c |
| |
| def test_frac(x, y): |
| c = torch.frac(torch.add(x, y)) |
| return c |
| |
| def test_lgamma(x, y): |
| c = torch.lgamma(torch.add(x, y)) |
| return c |
| |
| def test_sigmoid(x, y): |
| c = torch.sigmoid(torch.add(x, y)) |
| return c |
| |
| def test_reciprocal(x, y): |
| c = torch.reciprocal(torch.add(x, y)) |
| return c |
| |
| def test_neg(x, y): |
| c = torch.neg(torch.add(x, y)) |
| return c |
| |
| def test_relu(x, y): |
| c = torch.relu(torch.add(x, y)) |
| return c |
| |
| def test_hardtanh(x, y): |
| c = F.hardtanh(torch.add(x, y), -1.0, 1.0) |
| return c |
| |
| def test_threshold(x, y): |
| c = F.threshold(torch.add(x, y), 0.5, 10) |
| return c |
| |
| gpu_only_fns = { |
| test_erf, |
| test_erfc |
| } |
| fns = { |
| test_round, |
| test_sin, |
| test_asin, |
| test_sinh, |
| test_cos, |
| test_acos, |
| test_cosh, |
| test_tan, |
| test_atan, |
| test_sqrt, |
| test_floor, |
| test_ceil, |
| test_trunc, |
| test_abs, |
| test_log, |
| test_log2, |
| test_log10, |
| test_log1p, |
| test_rsqrt, |
| test_exp, |
| test_expm1, |
| test_frac, |
| test_lgamma, |
| test_reciprocal, |
| test_neg, |
| test_threshold, |
| test_relu, |
| test_tanh, |
| test_hardtanh, |
| test_sigmoid, |
| } |
| fn_dev_dtype = itertools.product(gpu_only_fns.union(fns), self.devices, self.dtypes) |
| |
| torch.manual_seed(0) |
| for torch_fn, dev, data_type in fn_dev_dtype: |
|             if torch_fn == test_lgamma and dev == "cuda" and data_type is torch.bfloat16: |
| # lgamma_cuda does not support BF16 |
| continue |
| rand_a = torch.rand(1024, dtype=data_type, device=dev) |
| rand_b = torch.rand(1024, dtype=data_type, device=dev) |
| |
| ins = 20 * torch.rand(1024, dtype=data_type, device=dev) |
| cc = np.empty([1024], dtype=np.float32) |
| cc.fill(np.nan) |
| nans = torch.from_numpy(cc).to(dev) |
| traced = torch.jit.trace(torch_fn, (ins, ins)) |
| x = warmup_and_run_forward(traced, rand_a, rand_b) |
| self.assertLastGraphAllFused() |
| |
| _atol = 5e-3 if data_type is torch.bfloat16 else 2e-3 |
| _rtol = 1e-5 |
| if data_type is torch.bfloat16 and torch_fn not in gpu_only_fns: |
| y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float()) |
| y = y.bfloat16() |
| else: |
| y = torch_fn(rand_a, rand_b) |
| |
| self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol) |
| # nans |
| # TODO: reenable. Currently all of the tests fail |
| # traced = torch.jit.trace(torch_fn, (ins, ins)) |
| # x = warmup_and_run_forward(traced, rand_a, rand_b) |
| # y = torch_fn(nans, rand_b) |
| # try: |
| # np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) |
| # print("Succeeded on dev=", dev, "function=", torch_fn) |
| # except AssertionError: |
| # # Print extra info before exiting: |
| # print("Failed on dev=", dev, "function=", torch_fn) |
| # # np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) |
| |
| |
| def test_round_2(self): |
| def round(x): |
| return torch.round(x) |
| |
| for data_type in [torch.float32, torch.double]: |
| a = torch.tensor([0.2, 1.6, 2.5, 3.5]).to(data_type) |
| traced = torch.jit.trace(round, (a)) |
| x = warmup_and_run_forward(traced, a) |
| self.assertLastGraphAllFused() |
|             y = round(a) |
| self.assertEqual(x, y) |
| |
| def test_rand_like(self): |
| N = 1 << 16 |
| |
| def run_rand_like(x, y): |
| return torch.rand_like(torch.add(x, y)) |
| |
| for device in self.devices: |
| x = torch.rand(N, device=device) |
| traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False) |
| |
| for data_type in self.dtypes: |
| _x = x.to(dtype=data_type) |
| x_v = warmup_and_run_forward(traced, _x, _x) |
| self.assertLastGraphAllFused() |
| |
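|                 # The first three moments of U(0, 1) are E[x^k] = 1/(k + 1), i.e. 1/2, 1/3 and 1/4; |
|                 # check that the fused rand_like output matches them within a loose tolerance. |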
|                 x_np = x_v.float().cpu().numpy() |
| x1_mean = np.mean(x_np) |
| x2_mean = np.mean(x_np ** 2) |
| x3_mean = np.mean(x_np ** 3) |
| np.testing.assert_allclose(x1_mean, 1. / 2, rtol=2e-2) |
| np.testing.assert_allclose(x2_mean, 1. / 3, rtol=2e-2) |
| np.testing.assert_allclose(x3_mean, 1. / 4, rtol=2e-2) |
| |
| def test_nans(self): |
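|         # Elementwise min/max propagate NaN in eager mode, so the fused kernels must |
|         # return NaN regardless of which argument holds it. |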
| def test_max(x, y): |
| return torch.max(2 * x, 2 * y) |
| |
| def test_min(x, y): |
| return torch.min(2 * x, 2 * y) |
| |
| tmax = torch.jit.trace(test_max, (torch.rand(1), torch.rand(1))) |
| tmin = torch.jit.trace(test_min, (torch.rand(1), torch.rand(1))) |
| |
| for data_type in self.dtypes: |
| x = torch.tensor([np.nan]).to(dtype=data_type) |
| y = torch.tensor([1.0]).to(dtype=data_type) |
| |
| assert np.isnan(warmup_and_run_forward(tmin, x, y).float().item()) |
| assert np.isnan(warmup_and_run_forward(tmin, y, x).float().item()) |
| self.assertLastGraphAllFused() |
| assert np.isnan(warmup_and_run_forward(tmax, x, y).float().item()) |
| assert np.isnan(warmup_and_run_forward(tmax, y, x).float().item()) |
| self.assertLastGraphAllFused() |
| |
| def test_double_intrinsics(self): |
| def do_pow(x): |
| return torch.pow(x, 7) |
| |
| for device in self.devices: |
| x = torch.rand(10, dtype=torch.double, device=device) |
| traced = torch.jit.trace(do_pow, (x)) |
| x = warmup_and_run_forward(traced, x) |
| self.assertLastGraphAllFused() |
| |
| def test_remainder(self): |
| def run_remainder(x, y): |
| c = torch.remainder(torch.add(x, y), x) |
| return c |
| |
| for data_type in self.dtypes: |
| a = torch.rand(1024, dtype=data_type) |
| b = torch.rand(1024, dtype=data_type) |
| zeros = torch.zeros(1024, dtype=data_type) |
|             cc = np.empty([1024], dtype=float) |
| cc.fill(np.nan) |
| nans = torch.from_numpy(cc).to(dtype=data_type) |
| |
| # random floats |
| zeros1 = torch.zeros(1024, dtype=data_type) |
| zeros2 = torch.zeros(1024, dtype=data_type) |
| |
| traced = torch.jit.trace(run_remainder, (zeros1, zeros2)) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| y = run_remainder(a, b) |
| if data_type is torch.bfloat16: |
| self.assertEqual(x, y, atol=4e-3, rtol=2e-3) |
| else: |
| self.assertEqual(x, y) |
| |
| # div by 0 |
| traced = torch.jit.trace(run_remainder, (zeros1, zeros2)) |
| x = warmup_and_run_forward(traced, zeros, a) |
| self.assertLastGraphAllFused() |
| y = run_remainder(zeros, a) |
| self.assertEqual(x, y) |
| |
|             # numerators and denominators are nan |
| traced = torch.jit.trace(run_remainder, (zeros1, zeros2)) |
| x = warmup_and_run_forward(traced, nans, a) |
| self.assertLastGraphAllFused() |
| y = run_remainder(nans, a) |
| self.assertEqual(x, y) |
| |
| def test_multioutput(self): |
| def easy(x): |
| b = x + 1 |
| c = b + b |
| return (b, c) |
| |
| traced = torch.jit.trace(easy, (torch.zeros(1024))) |
| |
| a = torch.zeros(1024) |
| b, c = warmup_and_run_forward(traced, a) |
| self.assertLastGraphAllFused() |
| bp = a.numpy() + 1 |
| cp = bp + bp |
| np.testing.assert_allclose(b.numpy(), bp) |
| np.testing.assert_allclose(c.numpy(), cp) |
| |
| def test_chunk(self): |
| def easy(x): |
| y = x + 1 |
| aaa, bbb = torch.chunk(y, 2) |
| return aaa + bbb |
| |
| for data_type in self.dtypes: |
| trace_input = torch.zeros(1024, 1024, dtype=data_type) |
| traced = torch.jit.trace(easy, (trace_input)) |
| |
| a = torch.zeros(32, 32, dtype=data_type) |
| x = warmup_and_run_forward(traced, a) |
| self.assertLastGraphAllFused() |
| npr = a.float().numpy() |
| npr2 = npr + 1 |
| npr_a, npr_b = np.array_split(npr2, 2) |
| np.testing.assert_allclose(npr_a + npr_b, x.float().numpy()) |
| |
| def test_cat(self): |
| for device in self.devices: |
| _dim = 1 |
| |
| def foo(*args): |
| args_2 = [v + i for i, v in enumerate(args)] |
| v = torch.cat(args_2, dim=_dim) |
| return v * v |
| |
| for data_type in self.dtypes: |
| M = 16 |
| Ns = [128, 16, 1] |
| values = [torch.zeros(M, N, dtype=data_type, device=device) for N in Ns] |
| traced = torch.jit.trace(foo, values) |
| |
| x = warmup_and_run_forward(traced, *values) |
| self.assertLastGraphAllFused() |
| ref = foo(*values) |
| np.testing.assert_allclose(ref.cpu().float().numpy(), x.cpu().float().numpy()) |
| |
| # Test channels-last |
| for _cur_dim in range(4): |
| _dim = _cur_dim |
| values = [torch.randn((2, 3, 4, 5), device=device).to(memory_format=torch.channels_last) for _ in range(10)] |
| traced = torch.jit.trace(foo, values) |
| |
| x = warmup_and_run_forward(traced, *values) |
| self.assertLastGraphAllFused() |
| ref = foo(*values) |
| self.assertEqual(ref, x) |
| |
| # This test checks that we correctly handle fusion group with just aten::cat in it. |
| # Note that the test only makes sense with min_fusion_group=1, otherwise no |
| # fusion groups would be formed at all. |
| # TODO: Fix and re-enable the test. |
| @unittest.skip("cat is broken with fusion group inlining disabled") |
| def test_cat_only(self): |
| for device in self.devices: |
| def foo(*args): |
| args_2 = [v + i for i, v in enumerate(args)] |
| v = torch.cat(args_2, dim=1) |
| return v |
| |
| M = 16 |
| Ns = [128, 16, 1] |
| values = [torch.zeros(M, N, device=device) for N in Ns] |
| traced = torch.jit.trace(foo, values) |
| |
| x = warmup_and_run_forward(traced, *values) |
| self.assertLastGraphAllFused() |
| ref = foo(*values) |
| np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) |
| |
| def test_cat_negative_dim(self): |
| for device in self.devices: |
| def foo(*args): |
| v = torch.cat(args, dim=-1) |
| return v * v |
| |
| M = 16 |
| Ns = [128, 16, 1] |
| values = [torch.randn(M, N, device=device) for N in Ns] |
| traced = torch.jit.trace(foo, values) |
| |
| x = warmup_and_run_forward(traced, *values) |
| self.assertLastGraphAllFused() |
| ref = foo(*values) |
| np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) |
| |
| def test_cat_promote_inputs(self): |
| for device in self.devices: |
| def foo(*args): |
| v = torch.cat(args, dim=1) |
| return v * v |
| |
| M = 16 |
| Ns = [128, 16, 1] |
| dtypes = [torch.half, torch.float32, torch.double] |
| values = [torch.randn(M, N, device=device, dtype=dt) for N, dt in zip(Ns, dtypes)] |
| traced = torch.jit.trace(foo, values) |
| |
| x = warmup_and_run_forward(traced, *values) |
| self.assertLastGraphAllFused() |
| ref = foo(*values) |
| np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) |
| |
| def test_cat_empty_tensors(self): |
| for device in self.devices: |
| def foo(*args): |
| v = torch.cat(args, dim=1) |
| return v * v |
| |
| M = 16 |
| Ns = [128, 16, 1] |
| empty = torch.tensor([], device=device, dtype=torch.double) |
| values = [empty] + [torch.randn(M, N, device=device) for N in Ns] |
| traced = torch.jit.trace(foo, values) |
| |
| x = warmup_and_run_forward(traced, *values) |
| self.assertLastGraphAllFused() |
| ref = foo(*values) |
| np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) |
| |
| # now test with only empty tensors |
| values = [empty for i in range(3)] |
| traced = torch.jit.trace(foo, values) |
| x = warmup_and_run_forward(traced, *values) |
| self.assertLastGraphAllFused() |
| ref = foo(*values) |
| np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) |
| |
| def test_cat_with_constant_dim(self): |
| for device in self.devices: |
| def foo(*args): |
| v1 = torch.cat(args, dim=1) |
| v2 = torch.cat([v1], dim=1) |
| return v2 * v2 |
| |
| empty = torch.tensor([], device=device, dtype=torch.float32) |
| inputs = [empty] + [torch.randn(1, 64, device=device), torch.randn(1, 64, device=device)] |
| traced = torch.jit.trace(foo, inputs) |
| |
| x = warmup_and_run_forward(traced, *inputs) |
| self.assertLastGraphAllFused() |
| ref = foo(*inputs) |
| np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) |
| |
| def test_scalar(self): |
| @torch.jit.script |
| def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor: |
| return torch.add(torch.add(x, y, alpha=a), z, alpha=b) |
| |
| @torch.jit.script |
| def test_int(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: int, b: int) -> torch.Tensor: |
| return torch.add(torch.add(x, y, alpha=a), z, alpha=b) |
| |
| for test in (test_float, test_int): |
| for data_type in self.dtypes: |
| x, y, z = (torch.rand(4, dtype=data_type) for i in range(3)) |
| a, b = 1, 2 |
| test(x, y, z, a, b) |
| r = test(x, y, z, a, b) |
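|                 # add(add(x, y, alpha=a), z, alpha=b) expands to x + a * y + b * z. |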
| self.assertEqual(r, x + y * a + z * b) |
| |
| def test_loop(self): |
| @torch.jit.script |
| def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor: |
| b = y |
| for i in range(0, z): |
| a = x + y |
| b = b + y |
| return b |
| |
| x, y, z = (torch.zeros(32, 32), torch.ones(32, 32), 4) |
| test(x, y, z) |
| r = test(x, y, z) |
| |
| def test_slice(self): |
| def easy(x, y): |
| a = x[0:512:2] |
| b = y[0:512:2] |
| return a + b |
| |
| traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024))) |
| |
| a = torch.ones(1024, 1024) |
| x = traced(a, a) |
| npr = a[0:512:2] |
| npr = npr + npr |
| np.testing.assert_allclose(npr.numpy(), x.numpy()) |
| |
| def test_unsqueeze(self, N=256): |
| def easy(x, y): |
| a = torch.unsqueeze(x, 0) |
| b = torch.unsqueeze(y, 0) |
| return a + b |
| |
| traced = torch.jit.trace(easy, (torch.ones(N, N), torch.zeros(N, N))) |
| |
| a = torch.rand(N, N) |
| x = traced(a, a) |
| npr = np.expand_dims(a, 0) |
| npr = npr + npr |
| np.testing.assert_allclose(npr, x.numpy()) |
| |
| def _test_softmax(self, device): |
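|         # Softmax involves reductions; TE reduction support is gated behind a flag, |
|         # which is toggled on around each traced run below. |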
| def test_softmax(x, y): |
| a = F.softmax(x, dim=0, dtype=torch.float32) |
| b = F.softmax(y, dim=0, dtype=torch.float32) |
| c = F.softmax(x, dim=1, dtype=torch.float32) |
| d = F.softmax(y, dim=1, dtype=torch.float32) |
| return a + b + c + d |
| |
| def test_softmax_neg_index(x, y): |
| a = F.softmax(x, dim=-2, dtype=torch.float32) |
| b = F.softmax(y, dim=-2, dtype=torch.float32) |
| c = F.softmax(x, dim=-1, dtype=torch.float32) |
| d = F.softmax(y, dim=-1, dtype=torch.float32) |
| return a + b + c + d |
| |
| def test_log_softmax(x, y): |
| a = F.log_softmax(x, dim=0, dtype=torch.float32) |
| b = F.log_softmax(y, dim=0, dtype=torch.float32) |
| c = F.log_softmax(x, dim=1, dtype=torch.float32) |
| d = F.log_softmax(y, dim=1, dtype=torch.float32) |
| return a + b + c + d |
| |
| for test in (test_softmax, test_log_softmax, test_softmax_neg_index): |
| for data_type in self.dtypes: |
| old = torch._C._jit_set_texpr_reductions_enabled(True) |
| traced_input = torch.randn(2, 3, dtype=data_type, device=device) |
| traced = torch.jit.trace(test, (traced_input, traced_input)) |
| inp = torch.randn(2, 3, dtype=data_type, device=device) |
| res = traced(inp, inp) |
| # Use eager mode as reference. |
| ref = test(inp, inp) |
| np.testing.assert_allclose(ref, res.cpu().numpy(), rtol=1e-06, atol=1e-06) |
| torch._C._jit_set_texpr_reductions_enabled(old) |
| |
| def test_softmax_cpu(self): |
| self._test_softmax('cpu') |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") |
| @unittest.skip("global allocs are not supported yet.") |
| def test_softmax_cuda(self): |
| self._test_softmax('cuda') |
| |
| def test_half_gelu(self): |
| devices = ["cuda"] if torch.cuda.is_available() else [] |
| |
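|         # Exact (erf-based) GELU applied to a biased input; 1.41421 approximates sqrt(2). |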
| @torch.jit.script |
| def bias_gelu(bias, y): |
| x = bias + y |
| return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) |
| |
| for device in devices: |
| a = torch.rand(1024, dtype=torch.half, device=device) |
| b = torch.rand(1024, dtype=torch.half, device=device) |
| traced = torch.jit.trace(bias_gelu, (a, b)) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| |
| def test_half_bn_relu(self): |
| devices = ["cuda"] if torch.cuda.is_available() else [] |
| |
| def foo(a, b, c): |
| y = torch.nn.functional.batch_norm(a, b, c) |
| z = y.relu() |
| return z |
| |
| for device in devices: |
| a = torch.rand(16, 16, dtype=torch.half, device=device) |
| b = torch.rand(16, dtype=torch.half, device=device) |
| c = torch.rand(16, dtype=torch.half, device=device) |
| traced = torch.jit.trace(foo, (a, b, c)) |
| print(traced.graph) |
| x = warmup_and_run_forward(traced, a, b, c) |
| self.assertLastGraphAllFused() |
| |
| def test_exp_pow(self): |
| @torch.jit.script |
| def do_exp(x, y, z): |
| return ((x * y) * 2) * torch.pow(z, 2) |
| |
| for device in self.devices: |
| x = torch.rand(10, dtype=torch.double, device=device) |
| y = torch.rand(10, dtype=torch.double, device=device) |
| z = torch.rand(10, dtype=torch.double, device=device) |
| traced = torch.jit.trace(do_exp, (x, y, z)) |
| x = warmup_and_run_forward(traced, x, y, z) |
| self.assertLastGraphAllFused() |
| |
| def test_sin_pow(self): |
| def test(x): |
| return torch.sin(torch.pow(x, 0)) |
| |
| for data_type, shape in itertools.product(self.dtypes, [[3], [5], [10]]): |
| x = torch.rand(shape, dtype=data_type) |
| scripted = torch.jit.script(test) |
| out = warmup_and_run_forward(scripted, x) |
| self.assertLastGraphAllFused() |
| self.assertEqual(out, test(x)) |
| |
| def test_transpose(self): |
| @torch.jit.script |
| def test(x, y, z): |
| return x.transpose(0, 1) + y + z |
| x = torch.rand(4, 5, 2, 3) |
| y = torch.rand(5, 4, 2, 3) |
| z = torch.rand(5, 4, 2, 3) |
| ref = test(x, y, z) |
| res = test(x, y, z) |
| np.testing.assert_allclose(ref.numpy(), res.numpy()) |
| |
| def test_sliced_stride(self): |
| @torch.jit.script |
| def test(x, y, z): |
| return x + y + z |
| x = torch.rand(16, 4, 2, 3)[::2] |
| y = torch.rand(8, 4, 2, 3) |
| z = torch.rand(8, 4, 2, 3) |
| ref = test(x, y, z) |
| res = test(x, y, z) |
| np.testing.assert_allclose(ref.numpy(), res.numpy()) |
| |
| @unittest.skip("dynamic shapes are not quite there yet") |
| @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") |
| def test_dynamic_shape(self): |
| with num_profiled_runs(2): |
| @torch.jit.script |
| def test(x, y, z): |
| return x * y * z |
| x, y, z = (torch.rand(4, 8).cuda() for _ in range(3)) |
| ref = test(x, y, z) |
| _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)]) |
| res = test(x, y, z) |
| np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy()) |
| |
| # A wild broadcast appears. |
| x = torch.rand(4, 8).cuda() |
| y = torch.rand(1, 8).cuda() |
| z = torch.rand(4, 1).cuda() |
| res = test(x, y, z) |
| xn, yn, zn = (t.cpu().numpy() for t in (x, y, z)) |
| np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn) |
| |
| # Mismatched shapes shouldn't reach codegen. |
| x = torch.rand(4, 8).cuda() |
| y = torch.rand(4, 8).cuda() |
| z = torch.rand(5, 8).cuda() |
| try: |
| res = test(x, y, z) |
| except RuntimeError as e: |
| assert "The size of tensor a (4) must match" in e.args[0] |
| |
| # Changing a static dimension fails guards. |
| # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)] |
| # xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)] |
| # res = test(x, y, z) |
| # print(test.graph_for(x, y, z)) |
| # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") |
| def test_guard_fails(self): |
| @torch.jit.script |
| def test(x, y, z): |
| return x * y * z |
| r1 = test(*[torch.rand(4).cuda() for _ in range(3)]) |
| r2 = test(*[torch.rand(4).cuda() for _ in range(3)]) |
| r3 = test(*[torch.rand(4).cuda() for _ in range(3)]) |
| r4 = test(*[torch.rand(7).cuda() for _ in range(3)]) |
| |
| def test_bitwise_ops(self): |
| def run_and(x, y): |
| return x & (x & y) |
| |
| def run_or(x, y): |
| return x & (x | y) |
| |
| def run_xor(x, y): |
| return x ^ (x ^ y) |
| |
| def run_lshift(x, y): |
| return x & (x << y) |
| |
| def run_rshift(x, y): |
| return x & (x >> y) |
| |
| fns = {run_and, run_or, run_xor, run_lshift, run_rshift} |
| |
| for device in self.devices: |
| for fn in fns: |
| a = torch.ones(128, dtype=torch.int32, device=device) |
| b = torch.zeros(128, dtype=torch.int32, device=device) |
| inp = torch.ones(128, dtype=torch.int32, device=device) |
| traced = torch.jit.trace(fn, (inp, inp)) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| y = fn(a, b) |
| np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) |
| |
| def test_where(self): |
| def run_where(x, y): |
| return torch.where(torch.gt(x, y), x, y) |
| |
| for data_type in self.dtypes: |
| a = torch.rand(1024, dtype=data_type) |
| b = torch.rand(1024, dtype=data_type) |
| zeros = torch.zeros(1024, dtype=data_type) |
| traced = torch.jit.trace(run_where, (zeros, zeros)) |
| x = warmup_and_run_forward(traced, a, b) |
| self.assertLastGraphAllFused() |
| y = run_where(a, b) |
| np.testing.assert_allclose(x.float().numpy(), y.float().numpy()) |
| |
| def test_multi_rand(self): |
| for device in self.devices: |
| def test(x): |
| y = torch.rand_like(x) |
| return (x + y) - (y - x) |
| |
| _atol = 2e-3 |
| _rtol = 1e-5 |
| for data_type in self.dtypes: |
| if data_type is torch.bfloat16: |
| _atol = 2e-2 |
| a = torch.rand(4, dtype=data_type, device=device) |
| scripted = torch.jit.script(test) |
| out = warmup_and_run_forward(scripted, a) |
| self.assertLastGraphAllFused() |
| assert torch.allclose(out, 2 * a, atol=_atol, rtol=_rtol) |
| |
| def test_mask(self): |
| def test(x): |
| return x.unsqueeze(1) == 0 |
| |
| for d in self.devices: |
| for data_type in self.dtypes: |
| x = torch.rand(4, dtype=data_type, device=d) > 0.5 |
| scripted = torch.jit.script(test) |
| out = warmup_and_run_forward(scripted, x) |
| self.assertLastGraphAllFused() |
| assert torch.equal(out, test(x)) |
| |
| def test_simple_add(self): |
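|         # Exercise the TE block code generator: save the current flags, enable block |
|         # codegen and fallback for this test, and restore the old values at the end. |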
| val = torch._C._jit_get_te_generate_block_code() |
| torch._C._jit_set_te_generate_block_code(True) |
| fall_bk = torch._C._jit_texpr_fallback_allowed() |
| torch._C._jit_texpr_set_fallback_allowed(True) |
| |
| def simple(a, b): |
| return torch.add(a, b) |
| |
| a = torch.ones(256, 256) |
| b = torch.ones(256, 256) |
| traced = torch.jit.trace(simple, |
| (torch.ones(256, 256), torch.ones(256, 256))) |
| f = traced(a, b) |
| f_test = np.full((256, 256), 2, dtype=float) |
| np.testing.assert_allclose(f.numpy(), f_test) |
| torch._C._jit_set_te_generate_block_code(val) |
| torch._C._jit_texpr_set_fallback_allowed(fall_bk) |
| |
| def test_strided_output_preserved(self): |
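|         # The fuser should produce outputs with the same strides as eager mode, both |
|         # for arbitrary as_strided inputs and for channels-last tensors. |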
| def foo(a, b): |
| return a + b - a |
| |
| # smaller, easier to debug example |
| x = torch.arange(6) |
| x = torch.as_strided(x, (2, 3), (1, 2)) |
| total = 0 |
| for i in range(2): |
| for j in range(3): |
| x[i, j] = total |
| total += 1 |
| foo_script = torch.jit.script(foo) |
| foo_script(x, x) |
| foo_script(x, x) |
| out_s = foo_script(x, x) |
| out_eager = foo(x, x) |
| self.assertEqual(out_s, out_eager) |
| self.assertEqual(out_s.stride(), out_eager.stride()) |
| self.assertLastGraphAllFused() |
| |
| # more dims |
|         N, C, H, W = 2, 3, 4, 5 |
| x = torch.rand(N, C, H, W).to(memory_format=torch.channels_last) |
| foo_script = torch.jit.script(foo) |
| foo_script(x, x) |
| foo_script(x, x) |
| out_s = foo_script(x, x) |
| out_eager = foo(x, x) |
| self.assertEqual(out_s, out_eager) |
| self.assertEqual(out_s.stride(), out_eager.stride()) |
| self.assertLastGraphAllFused() |
| |
| def test_alias_analysis_module(self): |
| class AliasModule(nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| torch.manual_seed(1337) |
| self.a = torch.randn(128, 128) |
| self.b = torch.randn(128, 128) |
| self.c = torch.randn(128, 128) |
| |
| def forward(self, x, y, z): |
| z = z + self.a |
| self.b.add_(y) |
| w = z + self.a |
| z = w + x |
| return z |
| x = torch.randn(128, 128) |
| |
| def getModule(script): |
| am = AliasModule() |
| if script: |
| return torch.jit.script(am) |
| return am |
| |
| am = getModule(False) |
| am_s = getModule(True) |
| ref = am(x, x, x) |
| test = am_s(x, x, x) |
| torch.testing.assert_close(ref, test) |
| |
| # Now do the aliasing |
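|         # After am.a = am.b, the in-place self.b.add_(y) in forward also mutates self.a, |
|         # so the scripted/fused module must respect the aliasing to keep matching eager. |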
| am.a = am.b |
| ref = am(x, x, x) |
| |
| am_s.a = am_s.b |
| test = am_s(x, x, x) |
| |
| torch.testing.assert_close(ref, test) |
| |
| def test_alias_analysis_inputs(self): |
| class AliasModule(nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| torch.manual_seed(1337) |
| self.a = torch.randn(128, 128) |
| self.b = torch.randn(128, 128) |
| self.c = torch.randn(128, 128) |
| |
| def forward(self, x, y, z): |
| x.add_(y) |
| w = z + self.a |
| z = w + x |
| return z |
| |
| def getModule(script): |
| am = AliasModule() |
| if script: |
| return torch.jit.script(am) |
| return am |
| am = getModule(False) |
| am_s = getModule(True) |
| |
| torch.manual_seed(1337) |
| x = torch.randn(128, 128) |
| ref = am(x, x, x) |
| |
| torch.manual_seed(1337) |
| x = torch.randn(128, 128) |
| test = am_s(x, x, x) |
| |
| torch.testing.assert_close(ref, test) |
| |
| def test_alias_analysis_input_and_module(self): |
| class AliasModule(nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| torch.manual_seed(1337) |
| self.a = torch.randn(128, 128) |
| self.b = torch.randn(128, 128) |
| self.c = torch.randn(128, 128) |
| |
| def forward(self, x, y, z): |
| x.add_(y) |
| w = z + self.b |
| z = w + x |
| return z |
| |
| def getModule(script): |
| am = AliasModule() |
| if script: |
| return torch.jit.script(am) |
| return am |
| am = getModule(False) |
| am_s = getModule(True) |
| |
| torch.manual_seed(1337) |
| x = torch.randn(128, 128) |
| am.b = x |
| ref = am(x, x, x) |
| |
| torch.manual_seed(1337) |
| x = torch.randn(128, 128) |
| am_s.b = x |
| test = am_s(x, x, x) |
| |
| torch.testing.assert_close(ref, test) |
| |
| def test_multiple_outputs(self): |
| for device in self.devices: |
| # A bug reported internally similar to the one reported in #48533 |
| def foo(a, b, c): |
| t_next = c + 1 |
| t5 = t_next * b |
| t6 = torch.unsqueeze(t_next, 1) |
| t7 = a * t6 |
| return (t7, t5, t_next) |
| |
| for data_type in self.dtypes: |
| a = torch.rand(20, 20, dtype=data_type, device=device) |
| b = torch.rand(20 * 29, dtype=data_type, device=device).as_strided([20], [29]) |
| c = torch.ones(20, dtype=torch.int64, device=device) |
| traced = torch.jit.trace(foo, (a, b, c)) |
| ref = foo(a, b, c) |
| exp = traced(a, b, c) |
| exp = traced(a, b, c) |
| self.assertEqual(ref, exp) |
| |
| def test_propagated_mem_layout(self): |
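|         # Sweep contiguous/channels-last input combinations (plus permuted variants) under |
|         # both STATIC and DYNAMIC fusion strategies and check that traced results still |
|         # match eager mode. |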
| def foo(a, b, c): |
| t_next = c + 1 |
| t5 = t_next * b |
| t7 = a * t5 |
| return t7 |
| |
| def foo_multi_outputs(a, b, c): |
| t_next = c + 1 |
| t5 = b * t_next |
| t7 = a * t5 |
| return (t7, t5, t_next) |
| |
| def foo_multi_outputs_i_nhwc_o_nchw(a, b, c): |
| t_next = c + 1 |
| t5 = b * t_next |
| t7 = a * t5 |
| t8 = t7.to(memory_format=torch.contiguous_format) |
| return (t8, t7, t5, t_next) |
| |
| def run_foo_case(foo, a, b, c): |
| traced_contiguous = torch.jit.trace(foo, (a, b, c)) |
| ref = foo(a, b, c) |
| exp = traced_contiguous(a, b, c) |
| exp = traced_contiguous(a, b, c) |
| self.assertEqual(ref, exp) |
| |
| mem_layouts = list(itertools.product([torch.contiguous_format, torch.channels_last], repeat=3)) |
| shapes = [(2, 3, 4, 5), (2, 1, 1, 5), (1, 1, 1, 1)] |
| permutes = [(0, 3, 2, 1), (0, 3, 1, 2)] |
| funcs = [foo, foo_multi_outputs, foo_multi_outputs_i_nhwc_o_nchw] |
|         configs = list(itertools.product(funcs, shapes, mem_layouts, permutes)) |
| for strategy in ["STATIC", "DYNAMIC"]: |
| old_strategy = torch.jit.set_fusion_strategy([(strategy, 10)]) |
| for _func, _shape, _mem_layouts, _permute in configs: |
| a = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[0]) |
| b = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[1]) |
| c = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[2]) |
| run_foo_case(_func, a, b, c) |
| |
| a = a.permute(dims=_permute) |
| b = b.permute(dims=_permute) |
| c = c.permute(dims=_permute) |
| run_foo_case(_func, a, b, c) |
| |
| torch.jit.set_fusion_strategy(old_strategy) |
| |
| if __name__ == '__main__': |
| run_tests() |