# Owner(s): ["module: primTorch"]

from functools import partial
from itertools import product
import unittest

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
    skipCUDAIfRocm,
    dtypes,
)
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input
import torch._prims as prims
from torch._prims.executor import make_traced
import torch._refs as refs


if TEST_SCIPY:
    import scipy.special


class TestPrims(TestCase):
    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_broadcast_in_dim(self, device, dtype):
        # nvfuser is not currently capable of realizing a broadcasted tensor
        # when the broadcast is the only operation. Another op is needed.
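        # broadcast_in_dim places each dimension of `a` at the output position
        # given by broadcast_dimensions and broadcasts the remaining dimensions
        # (analogous to jax.lax.broadcast_in_dim).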
        def _wrapper(a, b, broadcast_dimensions):
            a_bc = prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)
            return prims.add(a_bc, b)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten', 'nvfuser'):
            fn = partial(traced, executor=executor)
            # Same shape
            shape = (5, 5)
            a = make_arg(shape)
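            # b is all zeros (low=0.0, high=0.0), so the add leaves the
            # broadcast values of a unchanged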
            b = make_arg(shape, low=0.0, high=0.0)
            result = fn(a, b, (0, 1))

            self.assertEqual(result.shape, a.shape)
            self.assertTrue(result.is_contiguous())
            self.assertEqual(a, result)

            # Error input: reordering dims
            with self.assertRaises(Exception):
                result = fn(a, b, (1, 0))

            # Adding outermost dimensions
            a = make_arg((5, 5))
            b = make_arg((3, 3, 5, 5), low=0.0, high=0.0)
            result = fn(a, b, (2, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.broadcast_to(b.shape), result)

            # Expands
            a = make_arg((1, 5, 1))
            b = make_arg((3, 5, 7), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 2))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.expand_as(result), result)

            # Unsqueezes
            a = make_arg((1, 2, 3))
            b = make_arg((1, 2, 1, 3), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.unsqueeze(2), result)

            # FIXME: This test exposes an issue in nvfuser
            # Adds outermost, expands, and unsqueezes
            """
            a = make_arg((1, 2, 3))
            b = make_arg((4, 1, 7, 2, 3, 3), low=0.0, high=0.0)
            result = fn(a, b, (1, 3, 4))

            self.assertEqual(result.shape, b.shape)
            a.unsqueeze_(3)
            a.unsqueeze_(1)
            a.unsqueeze_(0)
            self.assertEqual(a.expand_as(result), result)
            """

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_broadcast_in_dim_sum(self, device, dtype):
        def _wrapper(a):
            a_sum = prims.sum(a, [0, 1])
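            # a_sum is zero-dimensional after reducing over both dims, so this
            # broadcast_in_dim call is effectively an identity on a scalar tensor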
            a_bc = prims.broadcast_in_dim(a_sum, [], [])
            return a_bc

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten', 'nvfuser'):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float64, torch.long)
    def test_cbrt_prim(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)]
        shapes = [(), (0,), (1,), (5,)]

        # Set the default dtype to NumPy's default dtype of double
        cur_default = torch.get_default_dtype()
        torch.set_default_dtype(torch.double)
        try:
            # Tested here, as this op is not currently exposed or tested in ATen
            for b, s in product(batches, shapes):
                x = make_arg(b + s)
                y = prims.cbrt(x)

                x_np = x.cpu().numpy()
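                # scipy.special.cbrt computes the elementwise real cube root
                # and serves as the reference implementation here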
                y_np = scipy.special.cbrt(x_np)

                self.assertEqual(y, y_np, exact_device=False)
        finally:
            torch.set_default_dtype(cur_default)

    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_impl_is_used(self, device):
        # Ensure that the nvfuser implementation is used whenever one exists.
        # This assumes a one-to-one mapping between prims and nvfuser implementations
        # and does not check the correctness of those implementations.
        from torch._C._nvfuser import FusionDefinition as fd

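        # Prims whose names also appear as ops on nvfuser's FusionDefinition
        # are expected to have an nvfuser implementation registered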
        prim_nvfuser_ops = set(torch._prims.__all__).intersection(dir(fd.Ops))
        ops_without_nvfuser_impl = {
            name
            for name in prim_nvfuser_ops
            if getattr(torch.ops.prims, name).default.impl_nvfuser is None
        }
        assert len(ops_without_nvfuser_impl) == 0, (
            f"The following prims do not have 'impl_nvfuser' defined: {ops_without_nvfuser_impl}, "
            "even though nvfuser implementations exist for them."
        )

    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_executor_cached_noncontiguous(self, device):
        # This test is to ensure that nvfuser computes correct results for noncontiguous tensors
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode
        from torch._prims.executor import execute

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.sigmoid(a)

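        # Under TorchRefsMode, torch.* calls are traced through their torch._refs
        # implementations, so the resulting FX graph is expressed in terms of
        # prims that the nvfuser executor can run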
        with TorchRefsMode.push():
            gm = make_fx(func)(a)

        # First run to create the cache
        execute(gm, a, executor="nvfuser")

        # a.mT is noncontiguous, but it shouldn't affect correctness
        expected = execute(gm, a.mT, executor="aten")
        actual = execute(gm, a.mT, executor="nvfuser")
        self.assertEqual(expected, actual)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    @parametrize("correction", [0, 1])
    def test_var(self, device, dtype, correction):
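        # correction=0 computes the biased (population) variance;
        # correction=1 applies Bessel's correction (sample variance)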
        def _wrapper(a):
            return prims.var(a, [0, 1], correction=correction)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten', 'nvfuser'):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_pytree_input_output(self, device, dtype):
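        # make_traced should handle pytree inputs (a dict) and outputs
        # (a tuple containing a dict) identically on both executors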
        @make_traced
        def fn(a, b_dict):
            b = b_dict["b"]
            d = {}
            d["c"] = torch.add(a, b)
            return (d, torch.add(a, d["c"]))

        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((5, 5))
        b = make_arg((1, 5))
        b_dict = {"b": b}

        result_aten = fn(a, b_dict, executor="aten")
        result_nvfuser = fn(a, b_dict, executor="nvfuser")
        self.assertEqual(result_aten, result_nvfuser)

    @dtypes(torch.float32)
    def test_memory_format_strides(self, device, dtype):
        shapes = (
            (),
            (0,),
            (1,),
            (5,),
            (1, 0),
            (1, 1),
            (3, 7),
            (3, 0, 2),
            (1, 1, 2),
            (4, 1, 1),
            (7, 8, 9),
        )

        channels_last_shapes = (
            (0, 0, 0, 0),
            (1, 0, 3, 0),
            (0, 2, 3, 5),
            (2, 2, 2, 0),
            (5, 4, 3, 2),
            (8, 8, 7, 2),
            (9, 1, 3, 1),
            (4, 5, 8, 7),
        )

        channels_last_3d_shapes = (
            (0, 8, 7, 9, 2),
            (5, 0, 7, 9, 2),
            (5, 0, 7, 9, 0),
            (5, 8, 7, 9, 2),
            (5, 1, 7, 9, 2),
            (5, 1, 7, 9, 1),
        )

        pairs = (
            (shapes, torch.contiguous_format),
            (channels_last_shapes, torch.contiguous_format),
            (channels_last_3d_shapes, torch.contiguous_format),
            (channels_last_shapes, torch.channels_last),
            (channels_last_3d_shapes, torch.channels_last_3d),
        )

        for shapes, memory_format in pairs:
            for shape in shapes:
                # tests empty
                expected = torch.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                actual = refs.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests clone
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype)
                expected = torch.clone(a, memory_format=memory_format)
                actual = refs.clone(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests contiguous
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype, noncontiguous=True)
                expected = a.contiguous(memory_format=memory_format)
                actual = refs.contiguous(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())


class TestPrimsBasic(TestCase):
    def test_torch_ops(self):
        r = make_tensor((2,), device='cpu', dtype=torch.float)
        self.assertEqual(torch.ops.prims.sin(r), torch.sin(r))

        r = LoggingTensor(r)
        with capture_logs() as logs:
            log_input("input", r)
            prims.sin(r)
        self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.prims.sin.default($0)""")

    def test_mul_complex(self):
        prims.mul(torch.randn(2), 1 + 1j)


instantiate_device_type_tests(TestPrims, globals())

if __name__ == "__main__":
    run_tests()