| # Owner(s): ["module: primTorch"] |
| |
| from functools import partial |
| from itertools import product |
| import warnings |
| from warnings import catch_warnings |
| import unittest |
| |
| import torch |
| from torch.testing import make_tensor |
from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    TestCase,
    TEST_SCIPY,
    skipCUDAMemoryLeakCheckIf,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
    skipCUDAIfRocm,
    dtypes,
    OpDTypes,
    ops,
)
| from torch.testing._internal.common_methods_invocations import ( |
| op_db, |
| ) |
| |
| from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input |
| import torch._prims as prims |
| from torch._prims.executor import make_traced |
| import torch._refs as refs |
| |
| |
| if TEST_SCIPY: |
| import scipy.special |
| |
| NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor" |
| GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition" |
| |
| class TestPrims(TestCase): |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @dtypes(torch.float32) |
| def test_broadcast_in_dim(self, device, dtype): |
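        # For reference: broadcast_in_dim(a, shape, broadcast_dimensions) maps the
        # i-th dimension of `a` to output dimension broadcast_dimensions[i]; output
        # dimensions not listed are newly added, and size-1 input dimensions mapped
        # to larger output dimensions are expanded. A minimal sketch, not executed
        # by this test:
        #   prims.broadcast_in_dim(torch.ones(5), (3, 5), (1,))  # -> shape (3, 5)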
| def _wrapper(a, b, broadcast_dimensions): |
| return prims.broadcast_in_dim(a, b.shape, broadcast_dimensions) |
| |
| traced = make_traced(_wrapper) |
| make_arg = partial(make_tensor, device=device, dtype=dtype) |
| |
| for executor in ('aten', 'strictly_nvfuser'): |
| fn = partial(traced, executor=executor) |
| # Same shape |
| shape = (5, 5) |
| a = make_arg(shape) |
| b = make_arg(shape, low=0.0, high=0.0) |
| result = fn(a, b, (0, 1)) |
| |
| self.assertEqual(result.shape, a.shape) |
            self.assertTrue(result.is_contiguous())
| self.assertEqual(a, result) |
| |
| # Error input: reordering dims |
| with self.assertRaises(Exception): |
| result = fn(a, b, (1, 0)) |
| |
| # Adding outermost dimensions |
| a = make_arg((5, 5)) |
| b = make_arg((3, 3, 5, 5), low=0.0, high=0.0) |
| result = fn(a, b, (2, 3)) |
| |
| self.assertEqual(result.shape, b.shape) |
| self.assertEqual(a.broadcast_to(b.shape), result) |
| |
| # Expands |
| a = make_arg((1, 5, 1)) |
| b = make_arg((3, 5, 7), low=0.0, high=0.0) |
| result = fn(a, b, (0, 1, 2)) |
| |
| self.assertEqual(result.shape, b.shape) |
| self.assertEqual(a.expand_as(result), result) |
| |
| # Unsqueezes |
| a = make_arg((1, 2, 3)) |
| b = make_arg((1, 2, 1, 3), low=0.0, high=0.0) |
| result = fn(a, b, (0, 1, 3)) |
| |
| self.assertEqual(result.shape, b.shape) |
| self.assertEqual(a.unsqueeze(2), result) |
| |
| # FIXME: This test exposes an issue in nvfuser |
| # Adds outermost, expands, and unsqueezes |
| """ |
| a = make_arg((1, 2, 3)) |
| b = make_arg((4, 1, 7, 2, 3, 3), low=0.0, high=0.0) |
| result = fn(a, b, (1, 3, 4)) |
| |
| self.assertEqual(result.shape, b.shape) |
| a.unsqueeze_(3) |
| a.unsqueeze_(1) |
| a.unsqueeze_(0) |
| self.assertEqual(a.expand_as(result), result) |
| """ |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @dtypes(torch.float32) |
| def test_broadcast_in_dim_sum(self, device, dtype): |
| def _wrapper(a): |
| a_sum = prims.sum(a, [0, 1]) |
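            # Reducing over both dims yields a 0-d tensor; broadcasting it to an
            # empty shape exercises the rank-0 path.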
| a_bc = prims.broadcast_in_dim(a_sum, [], []) |
| return a_bc |
| |
| traced = make_traced(_wrapper) |
| make_arg = partial(make_tensor, device=device, dtype=dtype) |
| |
| for executor in ('aten', 'strictly_nvfuser'): |
| fn = partial(traced, executor=executor) |
| shape = (5, 5) |
| a = make_arg(shape) |
| result = fn(a) |
| |
| self.assertEqual(result.shape, ()) |
            self.assertTrue(result.is_contiguous())
| self.assertEqual(_wrapper(a), result) |
| |
| @unittest.skipIf(not TEST_SCIPY, "SciPy not found") |
| @dtypes(torch.float64, torch.long) |
| def test_cbrt_prim(self, device, dtype): |
| make_arg = partial(make_tensor, device=device, dtype=dtype) |
| batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)] |
| shapes = [(), (0,), (1,), (5,)] |
| |
| try: |
| # Sets the default dtype to NumPy's default dtype of double |
| cur_default = torch.get_default_dtype() |
| torch.set_default_dtype(torch.double) |
| |
            # Tested here, as this op is not currently exposed or tested in ATen
| for b, s in product(batches, shapes): |
| x = make_arg(b + s) |
| y = prims.cbrt(x) |
| |
| x_np = x.cpu().numpy() |
| y_np = scipy.special.cbrt(x_np) |
| |
| self.assertEqual(y, y_np, exact_device=False) |
| finally: |
| torch.set_default_dtype(cur_default) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| def test_nvfuser_impl_is_used(self, device): |
        # This test ensures that the nvfuser implementation is used when it exists,
        # assuming a one-to-one mapping between prims and nvfuser implementations.
        # It is not intended to test the correctness of the nvfuser implementations.
| from torch._C._nvfuser import FusionDefinition as fd |
| |
| prim_nvfuser_ops = set(torch._prims.__all__).intersection(dir(fd.ops)) |
| ops_without_nvfuser_impl = { |
| name |
| for name in prim_nvfuser_ops |
| if getattr(torch.ops.nvprims, name, None) is None |
| } |
        assert (
            len(ops_without_nvfuser_impl) == 0
        ), (f"The following prims are not registered under torch.ops.nvprims: {ops_without_nvfuser_impl}, "
            "even though nvfuser implementations exist for them.")
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| def test_nvfuser_empty_fusion(self, device): |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.executor import execute |
| |
| a = torch.randn(3, 3, device=device) |
| |
| def func(a, b, c): |
| return (a, b, c) |
| |
| gm = make_fx(func)(a, a, a) |
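        # func only forwards its inputs, so gm contains placeholder and output
        # nodes but no call_function nodes, which the strict executor rejects.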
| |
| with self.assertRaisesRegex(AssertionError, "Graph must contain at least one call_function node"): |
| execute(gm, a, a, a, executor="strictly_nvfuser") |
| |
| # Should pass with partitioned executor |
| out = execute(gm, a, a, a, executor="nvfuser") |
| self.assertEqual(out, (a, a, a)) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| def test_nvfuser_rand_like_fusion(self, device): |
| from torch._prims.context import TorchRefsNvfuserCapabilityMode |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.executor import execute |
| |
| a = torch.randn(3, 3, device=device) |
| |
| def func(a): |
| return torch.rand_like(a) |
| |
| with TorchRefsNvfuserCapabilityMode(): |
| gm = make_fx(func)(a) |
| |
| out = execute(gm, a, executor="strictly_nvfuser") |
| self.assertEqual(out.size(), a.size()) |
| |
| @skipCUDAMemoryLeakCheckIf(True) # https://github.com/pytorch/pytorch/issues/84529 |
| @onlyCUDA |
| @skipCUDAIfRocm |
| def test_nvfuser_no_args(self, device): |
| from torch._prims.context import TorchRefsNvfuserCapabilityMode |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.executor import execute |
| from torch._prims.nvfuser_executor import make_nvfuser_fusion |
| |
| a = torch.randn(3, 3, device=device) |
| |
| def func(): |
| return torch.sigmoid(a) |
| |
| with TorchRefsNvfuserCapabilityMode(): |
| gm = make_fx(func)() |
| |
| with warnings.catch_warnings(record=True) as caught: |
| execute(gm, executor="strictly_nvfuser") |
        # Fusion execution with no CUDA inputs is handled by the nvprim aten fallback
| self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught)) |
| |
| with self.assertRaisesRegex(AssertionError, "There must be at least one argument"): |
| make_nvfuser_fusion(gm) |
| |
| with self.assertRaisesRegex(AssertionError, "Number of placeholder nodes in the graph must match"): |
| execute(gm, a, executor="strictly_nvfuser") |
| |
| # Should pass with partitioned executor |
| out = execute(gm, executor="nvfuser") |
| self.assertEqual(out, func()) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| def test_nvfuser_constant_tensors(self, device): |
| from torch._prims.context import TorchRefsNvfuserCapabilityMode |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.executor import execute |
| |
| a = torch.randn(3, 3, device=device) |
| b = torch.randn(3, 3, device=device) |
| |
| def func(b): |
| return a + b |
| |
| with TorchRefsNvfuserCapabilityMode(): |
| gm = make_fx(func)(b) |
| |
| with self.assertRaisesRegex(AssertionError, "not supported yet"): |
| execute(gm, b, executor="strictly_nvfuser") |
| |
| # Should pass with partitioned executor |
| out = execute(gm, b, executor="nvfuser") |
| self.assertEqual(out, gm(b)) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| def test_nvfuser_executor_cached_noncontiguous(self, device): |
| # This test is to ensure that nvfuser computes correct results for noncontiguous tensors |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.context import TorchRefsMode |
| from torch._prims.executor import execute |
| |
| a = torch.randn(3, 3, device=device) |
| |
| def func(a): |
| return torch.sigmoid(a) |
| |
| with TorchRefsMode(): |
| gm = make_fx(func)(a) |
| |
| # First run to create the cache |
| execute(gm, a, executor="nvfuser") |
| |
| # a.mT is noncontiguous, but it shouldn't affect correctness |
| expected = execute(gm, a.mT, executor="aten") |
| actual = execute(gm, a.mT, executor="nvfuser") |
| self.assertEqual(expected, actual) |
| |
| def test_nvfuser_capability_context(self, device): |
| # This test is to ensure that the torch calls are replaced with refs |
| # based on the nvfuser+prims capability |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.context import TorchRefsNvfuserCapabilityMode |
| |
| # It's assumed that digamma is not supported by nvfuser |
| # If it's ever supported, this test will need to be updated |
| self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None) |
| |
| a = torch.randn(3, 3, device=device) |
| |
| def func(a): |
| return torch.digamma(a) |
| |
| with TorchRefsNvfuserCapabilityMode(): |
| gm = make_fx(func)(a) |
| |
        # Check that torch.digamma is kept as an aten call and is not replaced
        # with torch.ops.prims.digamma
| call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes)) |
| includes_aten_digamma = any( |
| torch.ops.aten.digamma.default == node.target |
| for node in call_function_nodes |
| ) |
| includes_prims_digamma = any( |
| torch.ops.prims.digamma.default == node.target |
| for node in call_function_nodes |
| ) |
| self.assertTrue(includes_aten_digamma) |
| self.assertFalse(includes_prims_digamma) |
| |
        # Check the mixed case: sigmoid is replaced with refs, but digamma is not
| def func(a): |
| return torch.sigmoid(torch.digamma(a)) |
| |
| with TorchRefsNvfuserCapabilityMode(): |
| gm = make_fx(func)(a) |
| |
| call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes)) |
| includes_aten_sigmoid = any( |
| torch.ops.aten.sigmoid.default == node.target |
| for node in call_function_nodes |
| ) |
| includes_prims_digamma = any( |
| torch.ops.prims.digamma.default == node.target |
| for node in call_function_nodes |
| ) |
| includes_nvprims_exp = any( |
| torch.ops.nvprims.exp.default == node.target |
| for node in call_function_nodes |
| ) |
| self.assertFalse(includes_aten_sigmoid) |
| self.assertFalse(includes_prims_digamma) |
| self.assertTrue(includes_nvprims_exp) |
| |
| |
| def test_aten_overload_to_prims(self, device): |
| # This test is to ensure that the torch.ops.aten calls are replaced with refs |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.context import TorchRefsMode |
| |
| a = torch.randn(3, 3, device=device) |
| |
| def func(a): |
| return torch.ops.aten.sigmoid.default(torch.ops.aten.digamma.default(a)) |
| |
| with TorchRefsMode(): |
| gm = make_fx(func)(a) |
| |
| # Check that all call_function nodes are prims |
| call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes)) |
| all_prims_namespace = all( |
| node.target.name().startswith("prims") for node in call_function_nodes |
| ) |
| self.assertTrue(all_prims_namespace) |
| |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| def test_nvfuser_executor_parameters(self, device): |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.executor import execute |
| |
| a = torch.randn(3, 4, device=device) |
| |
| def func(a): |
| return torch.ops.nvprims.add(a, a) |
| |
| gm = make_fx(func)(a) |
| |
| expected = execute(gm, a, executor="aten") |
        # Shouldn't raise an error because unrecognized parameters are ignored
| params_dicts = [None, {}, {"none": None}] |
| for params in params_dicts: |
| actual = execute(gm, a, executor="nvfuser", executor_parameters=params) |
| self.assertEqual(expected, actual) |
| |
| # Check caching parameter |
| for use_cache in [True, False]: |
| params = {"use_python_fusion_cache": use_cache} |
| actual = execute(gm, a, executor="nvfuser", executor_parameters=params) |
| self.assertEqual(expected, actual) |
| |
| # Check allow_single_op_fusion parameter |
| for allow_single_op_fusion in [True, False]: |
| params = {"allow_single_op_fusion": allow_single_op_fusion} |
| actual = execute(gm, a, executor="nvfuser", executor_parameters=params) |
| self.assertEqual(expected, actual) |
| |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| def test_nvfuser_executor_partitioned(self, device): |
        # This test is to ensure that the nvfuser partitioned executor works correctly
| # It's assumed that digamma is not supported by nvfuser |
| # If it's ever supported, this test will need to be updated |
| self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None) |
| |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.context import TorchRefsMode |
| from torch._prims.executor import execute |
| |
| a = torch.randn(3, 4, device=device) |
| b = torch.rand(3, 1, device=device) |
| c = torch.rand(3, 4, device=device) |
| |
| def func(a, b, c): |
| aa = torch.digamma(a) # not supported by nvfuser |
| d = torch.add(b, c) |
| dd = torch.sqrt(d) |
| return torch.mul(aa, dd.digamma()) |
| |
| with TorchRefsMode(): |
| gm = make_fx(func)(a, b, c) |
| |
| expected = execute(gm, a, b, c, executor="aten") |
| actual = execute(gm, a, b, c, executor="nvfuser") |
| self.assertEqual(expected, actual) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| def test_nvfuser_executor_partitioned_no_partitions_error(self, device): |
        # This test ensures that the partitioned executor warns when the graph
        # contains no nvfuser-supported operations
| # It's assumed that digamma is not supported by nvfuser |
| # If it's ever supported, this test will need to be updated |
| self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None) |
| |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.context import TorchRefsMode |
| from torch._prims.executor import execute |
| |
| a = torch.randn(3, 4, device=device) |
| |
| def func(a): |
| return torch.digamma(a) # not supported by nvfuser |
| |
| with TorchRefsMode(): |
| gm = make_fx(func)(a) |
| |
| with catch_warnings(record=True) as w: |
| # Trigger warning |
| execute(gm, a, executor="nvfuser") |
| # Check warning occurs |
| self.assertEqual(len(w), 1) |
| self.assertTrue("is not supported by nvFuser" in str(w[-1].message)) |
| |
| def test_nvprims(self, device): |
        # This test is to ensure that nvfuser-specific prims are exposed
| # and can be traced with make_fx |
| from torch.fx.experimental.proxy_tensor import make_fx |
| |
| def func(a): |
| return torch.ops.nvprims.add(a, a) |
| |
| a = torch.randn(3, 4, device=device) |
| gm = make_fx(func)(a) |
| |
| for node in gm.graph.nodes: |
| if node.op == "call_function": |
| self.assertTrue(node.name == "add") |
| self.assertTrue(node.target == torch.ops.nvprims.add.default) |
| self.assertFalse(node.target == torch.ops.prims.add.default) |
| self.assertFalse(node.target == torch.ops.aten.add.default) |
| |
| @dtypes(torch.float32, torch.float16) |
| def test_batch_norm_backward_nvprims(self, device, dtype): |
| # This test verifies that the backward pass of batch norm is correctly decomposed into nvprims |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.context import TorchRefsNvfuserCapabilityMode |
| from torch.testing._internal.common_methods_invocations import sample_inputs_batch_norm |
| |
| samples_iter = sample_inputs_batch_norm(None, device, dtype, requires_grad=True) |
| sample = next(samples_iter) |
| grad = torch.randn_like(sample.input) |
| |
| def func(grad, input, weight, rm, rv, eps, train): |
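            # native_batch_norm_backward's schema is (grad_out, input, weight,
            # running_mean, running_var, save_mean, save_invstd, train, eps,
            # output_mask); the running stats are reused here as the saved stats.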
| return torch.ops.aten.native_batch_norm_backward.default( |
| grad, input, weight, rm, rv, rm, rv, train, eps, [True, True, True] |
| ) |
| |
| args = sample.args |
| kwargs = sample.kwargs |
| all_args = [grad, sample.input, args[2], args[0], args[1], kwargs['eps'], kwargs['training']] |
| with TorchRefsNvfuserCapabilityMode(): |
| gm = make_fx(func)(*all_args) |
| |
| call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes)) |
| includes_batch_norm_backward = any( |
| torch.ops.aten.native_batch_norm_backward.default == node.target |
| for node in call_function_nodes |
| ) |
| self.assertFalse(includes_batch_norm_backward) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @dtypes(torch.float32) |
| @parametrize("correction", [0, 1]) |
| def test_var(self, device, dtype, correction): |
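        # `correction` follows torch.var semantics: 0 gives the biased (population)
        # variance, 1 applies Bessel's correction for the sample variance.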
| def _wrapper(a): |
| return prims.var(a, [0, 1], correction=correction) |
| |
| traced = make_traced(_wrapper) |
| make_arg = partial(make_tensor, device=device, dtype=dtype) |
| |
| for executor in ('aten', 'strictly_nvfuser'): |
| fn = partial(traced, executor=executor) |
| shape = (5, 5) |
| a = make_arg(shape) |
| result = fn(a) |
| |
| self.assertEqual(result.shape, ()) |
            self.assertTrue(result.is_contiguous())
| self.assertEqual(_wrapper(a), result) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @dtypes(torch.float16, torch.float32) |
| @parametrize("correction", [0, 1]) |
| @parametrize("keepdim", [True, False]) |
| def test_var_mean(self, device, dtype, correction, keepdim): |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.context import TorchRefsNvfuserCapabilityMode |
| |
| |
| def _wrapper(a): |
| return torch.var_mean(a, [0, 1], correction=correction, keepdim=keepdim) |
| |
| make_arg = partial(make_tensor, device=device, dtype=dtype) |
| |
| with TorchRefsNvfuserCapabilityMode(): |
| gm = make_fx(_wrapper)(make_arg((5, 5))) |
| |
| call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes)) |
| includes_nvprims_var_mean = any( |
| torch.ops.nvprims.var_mean.main == node.target |
| for node in call_function_nodes |
| ) |
| self.assertTrue(includes_nvprims_var_mean) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @dtypes(torch.float32, torch.float16) |
| def test_cpu_tensor(self, device, dtype): |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.context import TorchRefsNvfuserCapabilityMode |
| from torch._prims.executor import execute |
| |
| def _wrapper(t0, t1, cpu_scalar): |
| return t0 + t1 + cpu_scalar |
| |
| make_arg = partial(make_tensor, device=device, dtype=dtype) |
| a = make_arg((12, 1)) |
| b = make_arg((12, 12)) |
| c = torch.tensor(0.5) |
| |
| with TorchRefsNvfuserCapabilityMode(): |
| gm = make_fx(_wrapper)(a, b, c) |
| |
| with warnings.catch_warnings(record=True) as caught: |
| actual = execute(gm, a, b, c, executor="nvfuser") |
        # The CPU scalar tensor is handled by nvFuser codegen, so it shouldn't fall back
| self.assertFalse(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught)) |
| |
| expected = execute(gm, a, b, c, executor="aten") |
| self.assertEqual(expected, actual) |
| |
| call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes)) |
| includes_aten_add = any( |
| torch.ops.aten.add.default == node.target |
| for node in call_function_nodes |
| ) |
| self.assertFalse(includes_aten_add) |
| |
| with warnings.catch_warnings(record=True) as caught: |
| nvprim_aten_fallback = execute(gm, a.cpu(), b.cpu(), c, executor="nvfuser") |
        # CPU tensors are handled by the nvprim aten fallback; assert that the
        # fallback warning was indeed emitted
| self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught)) |
| |
| self.assertEqual(expected, nvprim_aten_fallback) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @dtypes(torch.float32) |
| def test_pytree_input_output(self, device, dtype): |
| @make_traced |
| def fn(a, b_dict): |
| b = b_dict["b"] |
| d = {} |
| d["c"] = torch.add(a, b) |
| return (d, torch.add(a, d["c"])) |
| |
| make_arg = partial(make_tensor, device=device, dtype=dtype) |
| a = make_arg((5, 5)) |
| b = make_arg((1, 5)) |
| b_dict = {"b": b} |
| |
| result_aten = fn(a, b_dict, executor="aten") |
| result_nvfuser = fn(a, b_dict, executor="strictly_nvfuser") |
| self.assertEqual(result_aten, result_nvfuser) |
| |
| @dtypes(torch.float32) |
| def test_memory_format_strides(self, device, dtype): |
| shapes = ( |
| (), |
| (0,), |
| (1,), |
| (5), |
| (1, 0), |
| (1, 1), |
| (3, 7), |
| (3, 0, 2), |
| (1, 1, 2), |
| (4, 1, 1), |
| (7, 8, 9), |
| ) |
| |
| channels_last_shapes = ( |
| (0, 0, 0, 0), |
| (1, 0, 3, 0), |
| (0, 2, 3, 5), |
| (2, 2, 2, 0), |
| (5, 4, 3, 2), |
| (8, 8, 7, 2), |
| (9, 1, 3, 1), |
| (4, 5, 8, 7) |
| ) |
| |
| channels_last_3d_shapes = ( |
| (0, 8, 7, 9, 2), |
| (5, 0, 7, 9, 2), |
| (5, 0, 7, 9, 0), |
| (5, 8, 7, 9, 2), |
| (5, 1, 7, 9, 2), |
| (5, 1, 7, 9, 1), |
| ) |
| |
| pairs = ( |
| (shapes, torch.contiguous_format), |
| (channels_last_shapes, torch.contiguous_format), |
| (channels_last_3d_shapes, torch.contiguous_format), |
| (channels_last_shapes, torch.channels_last), |
| (channels_last_3d_shapes, torch.channels_last_3d), |
| ) |
| |
| for shapes, memory_format in pairs: |
| for shape in shapes: |
| # tests empty |
| expected = torch.empty(shape, device=device, dtype=dtype, memory_format=memory_format) |
| actual = refs.empty(shape, device=device, dtype=dtype, memory_format=memory_format) |
| self.assertEqual(expected.stride(), actual.stride()) |
| |
| # tests clone |
| a = torch.testing.make_tensor(shape, device=device, dtype=dtype) |
                expected = torch.clone(a, memory_format=memory_format)
                actual = refs.clone(a, memory_format=memory_format)
| self.assertEqual(expected.stride(), actual.stride()) |
| |
| # tests contiguous |
| a = torch.testing.make_tensor(shape, device=device, dtype=dtype, noncontiguous=True) |
| expected = a.contiguous(memory_format=memory_format) |
| actual = refs.contiguous(a, memory_format=memory_format) |
| self.assertEqual(expected.stride(), actual.stride()) |
| |
| @dtypes(torch.float32) |
| def test_reshape_view_method(self, device, dtype): |
| make_arg = partial(make_tensor, device=device, dtype=dtype) |
| a = make_arg((5, 5)) |
| new_shape = 1, 5, 1, 5 |
| result_eager = a.reshape(*new_shape) |
| result_refs = refs.reshape(a, *new_shape) |
| self.assertEqual(result_eager, result_refs) |
| |
| result_eager = a.view(*new_shape) |
| result_refs = refs.view(a, *new_shape) |
| self.assertEqual(result_eager, result_refs) |
| |
| |
| class TestPrimsBasic(TestCase): |
| def test_torch_ops(self): |
| r = make_tensor((2,), device='cpu', dtype=torch.float) |
| self.assertEqual(torch.ops.prims.sin(r), torch.sin(r)) |
| |
| r = LoggingTensor(r) |
| with capture_logs() as logs: |
| log_input("input", r) |
| prims.sin(r) |
| self.assertExpectedInline('\n'.join(logs), """\ |
| $0 = input('input') |
| $1 = torch._ops.prims.sin.default($0)""") |
| |
| def test_mul_complex(self): |
| prims.mul(torch.randn(2), 1 + 1j) |
| |
| |
| instantiate_device_type_tests(TestPrims, globals()) |
| |
| |
| class TestRefs(TestCase): |
| @dtypes(torch.float32) |
| def test_constant_pad_nd_memory_format(self, device, dtype): |
| # Test memory format is preserved in unambiguous cases |
| for mf, ndim in ( |
| (torch.channels_last, 4), |
| (torch.contiguous_format, 4), |
| (torch.channels_last_3d, 5), |
| (torch.contiguous_format, 5), |
| ): |
| a = torch.zeros([2] * ndim).to(memory_format=mf) |
| res = refs.constant_pad_nd(a, pad=[1] * (2 * ndim)) |
| self.assertTrue(res.is_contiguous(memory_format=mf)) |
| |
| # Ambiguous cases |
| |
        # Both is_channels_last_ and is_contiguous_ are set, resulting in channels_last output
| a = torch.empty_strided((2, 1, 2, 2), stride=(4, 1, 2, 1)) |
| self.assertTrue(a.is_contiguous(memory_format=torch.channels_last)) |
| self.assertTrue(a.is_contiguous()) |
| actual = refs.constant_pad_nd(a, pad=[1] * 8) |
| expect = torch.constant_pad_nd(a, pad=[1] * 8) |
| self.assertEqual(actual.stride(), expect.stride()) |
| self.assertTrue(actual.is_contiguous(memory_format=torch.channels_last)) |
| |
        # is_channels_last_contiguous_ is set but is_channels_last_ is not,
        # resulting in contiguous output
| a = torch.empty_strided((2, 1, 2, 2), stride=(4, 4, 2, 1)) |
| self.assertTrue(a.is_contiguous(memory_format=torch.channels_last)) |
| self.assertTrue(a.is_contiguous()) |
| actual = refs.constant_pad_nd(a, pad=[1] * 8) |
| expect = torch.constant_pad_nd(a, pad=[1] * 8) |
| self.assertEqual(actual.stride(), expect.stride()) |
| self.assertTrue(actual.is_contiguous()) |
| |
| |
| instantiate_device_type_tests(TestRefs, globals()) |
| |
| |
| class TestDecomp(TestCase): |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @dtypes(torch.float16, torch.float32) |
| def test_decomposition_type_promotion_nvprim_amp(self, device, dtype): |
| x = torch.rand(5, device=device).to(dtype) |
| y = torch.rand(5, device=device).to(dtype) |
| |
| from torch._prims.context import TorchRefsNvfuserCapabilityMode, _is_func_unsupported_nvfuser |
| from torch.fx.experimental.proxy_tensor import make_fx |
| op = torch._decomp.decomposition_table.get(torch.ops.aten.leaky_relu_backward.default) |
| |
| def fn0(*arg): |
| return _is_func_unsupported_nvfuser(TorchRefsNvfuserCapabilityMode(), op, arg, {}) |
| |
| def fn1(x): |
| x = x * 2 |
| x = x @ x |
| x = x * 2 |
| return x |
| |
| self.assertFalse(fn0(x, y, 0.3, False)) |
| with TorchRefsNvfuserCapabilityMode(): |
| |
            # The autocast context makes C++-level ATen calls that are hidden from
            # TorchRefsNvfuserCapabilityMode, which works only at the Python level.
            # The first call to make_fx records the autocast C++ calls directly and
            # doesn't get the chance to translate them to nvprims. After the first
            # call, "gm" contains explicit calls to torch.ops.aten and nothing
            # is hidden, so the second call to make_fx actually translates the
            # recorded autocast dtype conversions to nvprims.
| with torch.autocast("cuda"): |
| gm = make_fx(fn1)(x) |
| gm = make_fx(gm)(x) |
| call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes)) |
| includes_aten_to_copy = any( |
| torch.ops.aten._to_copy.default == node.target |
| for node in call_function_nodes |
| ) |
| self.assertFalse(includes_aten_to_copy) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @dtypes(torch.float16, torch.float32) |
| def test_masked_fill_decomposition_under_nvprim_context(self, device, dtype): |
        # The masked_fill decomposition extracts the CPU scalar tensor's value
        # when filling a CUDA tensor. This triggers data-dependent control flow
        # during TorchRefsNvfuser's speculative lowering.
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.context import TorchRefsNvfuserCapabilityMode |
| |
| x = torch.empty(2, 3, device=device).to(dtype=dtype) |
| mask = torch.ones_like(x).bool() |
| y = torch.tensor(0.3) # cpu scalar tensor |
| |
| def func(x, mask, y): |
| return torch.masked_fill(x, mask, y) |
| |
        # Mimics the real use case for the TorchRefsNvfuserCapabilityMode context
| gm = make_fx(func, decomposition_table={})(x, mask, y) |
| |
| with warnings.catch_warnings(record=True) as caught: |
| with TorchRefsNvfuserCapabilityMode(): |
| gm = make_fx(gm)(x, mask, y) |
| # masked_fill decomposition fails inside `get_isolated_graphmodule` |
| self.assertTrue(any(GET_ISOLATED_GRAPHMODULE_ERROR in str(w.message) for w in caught)) |
| |
| @ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one) |
| def test_decomposition_method_vararg(self, device, dtype, op): |
        # Some ops have vararg variants for their methods; this test covers them.
        # We don't have tests for varargs in OpInfo, so we need to improvise this
        # a bit.
        # The rule for general functions (the special cases being e.g. tensor
        # creation functions taking shapes) is that things can be vararg if the
        # method has only one argument of sequence type.
        # E.g. permute can be called on a 3d tensor t as t.permute(0, 2, 1)
        # as well as t.permute([0, 2, 1]) when the signature in
        # native_functions.yaml shows the arguments "Tensor self, IntList dims".
        # We might need to adjust things for the factory functions or have them
        # do their own test.
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch._prims.context import TorchRefsMode |
| |
        # Filter out samples whose last argument is an empty tuple, as that cannot be the varargs
| sample_inputs = (si for si in op.sample_inputs(device, dtype, requires_grad=False) |
| if (si.args[-1] if si.args else si.input)) |
| |
        # Just run one test; we assume there is a suitable one among the samples
| sample_input = next(sample_inputs) |
| all_args = (sample_input.input,) + sample_input.args |
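        # all_args[-1] is assumed to be the sequence-typed argument; below it is
        # splatted (*all_args[-1]) so the op is called in its vararg form, e.g.
        # t.permute(0, 2, 1) instead of t.permute([0, 2, 1]).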
| |
        # In general, the methods take varargs while the function variants do not
        # (always?); the exception to this rule is the factory functions.
| if op.is_factory_function: |
| fn = op.op |
| else: |
| fn = op.method_variant |
| with TorchRefsMode(): |
| gm = make_fx(fn)(*all_args[:-1], *all_args[-1]) |
| |
| # in case we add random factory functions |
| torch.manual_seed(1) |
| res = gm(*all_args[:-1], *all_args[-1]) |
| torch.manual_seed(1) |
| expected = fn(*all_args[:-1], *all_args[-1]) |
| self.assertEqual(res, expected) |
| |
| |
| instantiate_device_type_tests(TestDecomp, globals()) |
| |
| |
| if __name__ == "__main__": |
| run_tests() |