| # Owner(s): ["module: unknown"] |
| |
| import contextlib |
| import copy |
| import inspect |
| import itertools |
| import os |
| import re |
| import unittest |
| import warnings |
| from collections import defaultdict |
| from collections.abc import Sequence |
| from functools import partial |
| from importlib import import_module |
| from typing import Dict, List |
| |
| import torch |
| import torch._prims as prims |
| import torch.utils._pytree as pytree |
| from torch._prims.context import TorchRefsMode |
| from torch._prims_common.wrappers import _maybe_remove_out_wrapper |
| from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode |
| from torch._subclasses.fake_utils import outputs_alias_inputs |
| from torch.testing import make_tensor |
| from torch.testing._internal import composite_compliance, opinfo |
| from torch.testing._internal.common_device_type import ( |
| deviceCountAtLeast, |
| instantiate_device_type_tests, |
| onlyCPU, |
| onlyCUDA, |
| onlyNativeDeviceTypesAnd, |
| OpDTypes, |
| ops, |
| skipMeta, |
| ) |
| from torch.testing._internal.common_dtype import ( |
| all_types_and_complex_and, |
| floating_and_complex_types_and, |
| integral_types_and, |
| ) |
| from torch.testing._internal.common_methods_invocations import ( |
| BinaryUfuncInfo, |
| op_db, |
| ops_and_refs, |
| python_ref_db, |
| ReductionOpInfo, |
| ReductionPythonRefInfo, |
| skip, |
| skipOps, |
| SpectralFuncInfo, |
| UnaryUfuncInfo, |
| xfail, |
| ) |
| from torch.testing._internal.common_utils import ( |
| clone_input_helper, |
| first_sample, |
| IS_CI, |
| IS_FBCODE, |
| is_iterable_of_tensors, |
| IS_SANDCASTLE, |
| IS_WINDOWS, |
| noncontiguous_like, |
| parametrize, |
| run_tests, |
| set_default_dtype, |
| skipIfTorchInductor, |
| slowTest, |
| suppress_warnings, |
| TEST_WITH_ASAN, |
| TEST_WITH_ROCM, |
| TEST_WITH_TORCHDYNAMO, |
| TEST_WITH_TORCHINDUCTOR, |
| TEST_WITH_UBSAN, |
| TestCase, |
| unMarkDynamoStrictTest, |
| ) |
| from torch.utils._python_dispatch import TorchDispatchMode |
| from torch.utils._pytree import tree_map |
| |
| |
| assert torch.get_default_dtype() == torch.float32 |
| |
| # variant testing is only done with torch.float and torch.cfloat to avoid |
| # excessive test times and maximize signal to noise ratio |
| _variant_ops = partial( |
| ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat) |
| ) |
| |
| # Get all the operators that have a ref in their OpInfo entry (testing infra), |
| # except for elementwise unary operators (separately implemented in test/test_unary_ufuncs.py), |
| # elementwise binary operators (separately implemented in test_binary_ufuncs.py), |
| # reduction operations (separately implemented in test_reductions.py), |
| # and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py) |
| _ref_test_ops = tuple( |
| filter( |
| lambda op: not isinstance( |
| op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo) |
| ) |
| and op.ref is not None, |
| op_db, |
| ) |
| ) |
| |
| |
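| # Selects reduction Python refs that support out=, list torch.int16 among their |
| # dtypes, and take a keyword-only `dtype` argument; used by test_out_integral_dtype. |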
| def reduction_dtype_filter(op): |
| if ( |
| not isinstance(op, ReductionPythonRefInfo) |
| or not op.supports_out |
| or torch.int16 not in op.dtypes |
| ): |
| return False |
| return "dtype" in inspect.getfullargspec(op.op).kwonlyargs |
| |
| |
| # Create a list of operators from ops_and_refs that don't have a NumPy ref |
| # to compare against. If both the CPU and CUDA results are compared to NumPy, |
| # then they do not need to be compared to each other |
| _ops_and_refs_with_no_numpy_ref = [op for op in ops_and_refs if op.ref is None] |
| |
| aten = torch.ops.aten |
| |
| |
| # Tests that apply to all operators and aren't related to any particular |
| # system |
| @unMarkDynamoStrictTest |
| class TestCommon(TestCase): |
| exact_dtype = True |
| |
| # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI |
| @classmethod |
| def tearDownClass(cls): |
| super().tearDownClass() |
| |
| if IS_CI: |
| err_msg = ( |
| "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries." |
| "This is OK for testing, but be sure to set the dtypes manually before landing your PR!" |
| ) |
| # Assure no opinfo entry has dynamic_dtypes |
| filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db)) |
| for op in filtered_ops: |
| fmt_str = opinfo.utils.str_format_dynamic_dtype(op) |
| err_msg += "\n" + fmt_str |
| |
| assert len(filtered_ops) == 0, err_msg |
| |
| # Validates that each OpInfo works correctly on different CUDA devices |
| @onlyCUDA |
| @deviceCountAtLeast(2) |
| @ops(op_db, allowed_dtypes=(torch.float32, torch.long)) |
| def test_multiple_devices(self, devices, dtype, op): |
| for cuda_device_str in devices: |
| cuda_device = torch.device(cuda_device_str) |
| # NOTE: only tests on first sample |
| samples = op.sample_inputs(cuda_device, dtype) |
| sample = first_sample(self, samples) |
| result = op(sample.input, *sample.args, **sample.kwargs) |
| |
| if isinstance(result, torch.Tensor): |
| self.assertTrue(result.device == cuda_device) |
| elif is_iterable_of_tensors(result): |
| self.assertTrue(all(t.device == cuda_device for t in result)) |
| else: |
| self.skipTest( |
| "Skipped! Only supports single tensor or iterable of tensor outputs." |
| ) |
| |
| def test_pointwise_tag_coverage(self): |
| pytorch_dir = os.path.abspath(__file__ + "/../../") |
| files = [ |
| "aten/src/ATen/native/UnaryOps.cpp", |
| "aten/src/ATen/native/BinaryOps.cpp", |
| "aten/src/ATen/native/PointwiseOps.cpp", |
| "aten/src/ATen/native/TensorCompare.cpp", |
| ] |
| |
| allowed_functions = ( |
| # reduction version of these operators |
| "aten.max.default", |
| "aten.max.dim", |
| "aten.max.dim_max", |
| "aten.max.names_dim", |
| "aten.max.names_dim_max", |
| "aten.max.unary_out", |
| "aten.min.default", |
| "aten.min.dim", |
| "aten.min.dim_min", |
| "aten.min.names_dim", |
| "aten.min.names_dim_min", |
| "aten.min.unary_out", |
| # not pointwise |
| "aten.isin.Tensor_Tensor", |
| "aten.isin.Tensor_Tensor_out", |
| "aten.isin.Tensor_Scalar", |
| "aten.isin.Tensor_Scalar_out", |
| "aten.isin.Scalar_Tensor", |
| "aten.isin.Scalar_Tensor_out", |
| "aten.mode.default", |
| "aten.mode.dimname", |
| "aten.mode.dimname_out", |
| "aten.mode.values", |
| ) |
| |
| regex = re.compile(r"DEFINE_DISPATCH\(.*_stub") |
| |
| def get_opoverloadpacket_from_dispatch(kernel): |
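| # Maps a DEFINE_DISPATCH stub name to an aten OpOverloadPacket name by trying, |
| # in order: the name itself, its dunder form (__name__), a special_ prefix, |
| # and finally the name with its trailing _<suffix> stripped. |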
| if hasattr(torch.ops.aten, kernel): |
| return kernel |
| if hasattr(torch.ops.aten, f"__{kernel}__"): |
| return f"__{kernel}__" |
| if hasattr(torch.ops.aten, f"special_{kernel}"): |
| return f"special_{kernel}" |
| if "_" in kernel: |
| kernel_split = kernel.split("_") |
| new_kernel = "_".join(kernel_split[:-1]) |
| if hasattr(torch.ops.aten, new_kernel): |
| return new_kernel |
| |
| # could not find op from kernel dispatch string |
| self.assertTrue(False) |
| |
| for file_name in files: |
| with open(os.path.join(pytorch_dir, file_name)) as f: |
| lines = f.read() |
| matches = regex.findall(lines) |
| for match in matches: |
| kernel = match[len("DEFINE_DISPATCH(") : -len("_stub")] |
| |
| # trigamma has no op definition, but it is defined with DEFINE_DISPATCH |
| if kernel == "trigamma": |
| continue |
| |
| kernel = get_opoverloadpacket_from_dispatch(kernel) |
| overloadpacket = getattr(torch.ops.aten, kernel) |
| |
| for overload_name in overloadpacket.overloads(): |
| overload = getattr(overloadpacket, overload_name) |
| |
| if not torch._C._dispatch_has_kernel(overload.name()): |
| continue |
| |
| # TODO: tags are not propagated to generated overloads, |
| # and there's no way of specifying them |
| if torch.Tag.generated in overload.tags: |
| continue |
| |
| if str(overload) in allowed_functions: |
| continue |
| |
| self.assertTrue(torch.Tag.pointwise in overload.tags) |
| |
| # Tests that the function and its (ndarray-accepting) reference produce the same |
| # values on the tensors from sample_inputs func for the corresponding op. |
| # This test runs in double and complex double precision because |
| # NumPy does computation internally using double precision for many functions |
| # resulting in possible equality check failures. |
| # skip windows case on CPU due to https://github.com/pytorch/pytorch/issues/129947 |
| @onlyNativeDeviceTypesAnd(["hpu"]) |
| @suppress_warnings |
| @ops(_ref_test_ops, allowed_dtypes=(torch.float64, torch.long, torch.complex128)) |
| def test_numpy_ref(self, device, dtype, op): |
| if ( |
| TEST_WITH_TORCHINDUCTOR |
| and op.formatted_name |
| in ("signal_windows_exponential", "signal_windows_bartlett") |
| and dtype == torch.float64 |
| and "cuda" in device |
| or "cpu" in device |
| ): # noqa: E121 |
| raise unittest.SkipTest("XXX: raises tensor-likes are not close.") |
| |
| # Sets the default dtype to NumPy's default dtype of double |
| with set_default_dtype(torch.double): |
| for sample_input in op.reference_inputs(device, dtype): |
| self.compare_with_reference( |
| op, op.ref, sample_input, exact_dtype=(dtype is not torch.long) |
| ) |
| |
| # Tests that the cpu and gpu results are consistent |
| @onlyCUDA |
| @suppress_warnings |
| @slowTest |
| @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one) |
| def test_compare_cpu(self, device, dtype, op): |
| def to_cpu(arg): |
| if isinstance(arg, torch.Tensor): |
| return arg.to(device="cpu") |
| return arg |
| |
| samples = op.reference_inputs(device, dtype) |
| |
| for sample in samples: |
| cpu_sample = sample.transform(to_cpu) |
| cuda_results = op(sample.input, *sample.args, **sample.kwargs) |
| cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs) |
| |
| # output_process_fn_grad has a very unfortunate name |
| # We use this function in linalg extensively to postprocess the outputs of functions |
| # that are not completely well-defined. Think svd and multiplying the singular vectors by -1. |
| # CPU and CUDA implementations of the SVD can return valid SVDs that are different. |
| # We use this function to compare them. |
| cuda_results = sample.output_process_fn_grad(cuda_results) |
| cpu_results = cpu_sample.output_process_fn_grad(cpu_results) |
| |
| # Use relaxed tolerances because we are running this as a `@slowTest`; |
| # we don't want the periodic tests to fail frequently |
| self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3) |
| |
| # Tests that experimental Python References can propagate shape, dtype, |
| # and device metadata properly. |
| # See https://github.com/pytorch/pytorch/issues/78050 for a discussion of stride propagation. |
| @onlyNativeDeviceTypesAnd(["hpu"]) |
| @ops(python_ref_db) |
| @skipIfTorchInductor("Takes too long for inductor") |
| def test_python_ref_meta(self, device, dtype, op): |
| CHECK_CONJ_SKIPS = { |
| torch._refs.linalg.svd, |
| } |
| |
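| # Instantiates a FakeTensorMode up front; it is entered again below (`with mode:`) |
| # so every sample's fake-tensor run shares the same mode. |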
| with FakeTensorMode() as mode: |
| pass |
| |
| def _to_tensormeta(x): |
| if isinstance(x, torch.Tensor): |
| out = FakeTensor.from_tensor(x, mode) |
| return out |
| return x |
| |
| # TODO: iterate over requires_grad true/false |
| for sample in op.reference_inputs(device, dtype, requires_grad=False): |
| result = op(sample.input, *sample.args, **sample.kwargs) |
| |
| meta_sample = sample.transform(_to_tensormeta) |
| try: |
| with mode: |
| meta_result = op( |
| meta_sample.input, *meta_sample.args, **meta_sample.kwargs |
| ) |
| except torch._subclasses.fake_tensor.UnsupportedFakeTensorException: |
| continue |
| except torch._subclasses.fake_tensor.DataDependentOutputException: |
| continue |
| except torch._subclasses.fake_tensor.UnsupportedOperatorException: |
| continue |
| |
| if isinstance(result, torch.Tensor): |
| self.assertTrue(isinstance(meta_result, FakeTensor)) |
| prims.utils.compare_tensor_meta( |
| result, meta_result, check_conj=op.op not in CHECK_CONJ_SKIPS |
| ) |
| elif isinstance(result, Sequence): |
| for a, b in zip(result, meta_result): |
| if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor): |
| self.assertTrue(isinstance(b, FakeTensor)) |
| prims.utils.compare_tensor_meta( |
| a, b, check_conj=op.op not in CHECK_CONJ_SKIPS |
| ) |
| |
| def _ref_test_helper( |
| self, |
| ctx, |
| device, |
| dtype, |
| op, |
| skip_zero_numel=False, |
| skip_zero_dim=False, |
| skip_bfloat=False, |
| skip_view_consistency=False, |
| ): |
| # NOTE: this helper works by comparing the reference's results against the |
| # torch op's results and, when they differ, checking that the reference is |
| # at least as close to a higher-precision computation as the torch op is |
| ex = None |
| for sample in op.reference_inputs(device, dtype, requires_grad=False): |
| if ( |
| isinstance(sample.input, torch.Tensor) |
| and sample.input.numel() == 0 |
| and skip_zero_numel |
| ): |
| continue |
| if ( |
| isinstance(sample.input, torch.Tensor) |
| and sample.input.ndim == 0 |
| and skip_zero_dim |
| ): |
| continue |
| |
| if skip_bfloat and ( |
| ( |
| isinstance(sample.input, torch.Tensor) |
| and sample.input.dtype == torch.bfloat16 |
| ) |
| or any( |
| isinstance(arg, torch.Tensor) and arg.dtype == torch.bfloat16 |
| for arg in sample.args |
| ) |
| ): |
| continue |
| with ctx(): |
| ref_result = op(sample.input, *sample.args, **sample.kwargs) |
| torch_result = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs) |
| |
| for a, b in zip( |
| pytree.tree_leaves(ref_result), pytree.tree_leaves(torch_result) |
| ): |
| if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor): |
| prims.utils.compare_tensor_meta(a, b) |
| if ( |
| getattr(op, "validate_view_consistency", True) |
| and not skip_view_consistency |
| ): |
| msg = ( |
| f"The torch implementation {'returns' if b._is_view() else 'does not return'} " |
| f"a view, while the reference {'does' if a._is_view() else 'does not'}" |
| ) |
| self.assertEqual(a._is_view(), b._is_view(), msg) |
| |
| # Computes the dtype in which the more precise computation would occur |
| precise_dtype = torch.bool |
| if prims.utils.is_integer_dtype(dtype): |
| # Note: bool and integer dtypes do not have more |
| # precise dtypes -- they simply must be close |
| precise_dtype = dtype |
| if prims.utils.is_float_dtype(dtype): |
| precise_dtype = torch.double |
| if prims.utils.is_complex_dtype(dtype): |
| precise_dtype = torch.cdouble |
| |
| # Checks if the results are close |
| try: |
| self.assertEqual( |
| ref_result, |
| torch_result, |
| exact_stride=False, |
| exact_device=True, |
| exact_layout=True, |
| exact_is_coalesced=True, |
| ) |
| except AssertionError as e: |
| # Re-raises the error if the computation already ran in the most precise |
| # dtype, since a more precise comparison is not possible |
| if dtype is precise_dtype: |
| raise e |
| |
| ex = e |
| |
| # Goes to next sample if these results are close |
| if not ex: |
| continue |
| |
| # If the results are not close, checks that the |
| # reference is more accurate than the torch op |
| def _make_precise(x): |
| if isinstance(x, torch.dtype): |
| return precise_dtype |
| if isinstance(x, torch.Tensor) and x.dtype is dtype: |
| return x.to(precise_dtype) |
| return x |
| |
| precise_sample = sample.transform(_make_precise) |
| precise_result = op.torch_opinfo( |
| precise_sample.input, *precise_sample.args, **precise_sample.kwargs |
| ) |
| |
| def _distance(a, b): |
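| # Sums the elementwise absolute differences between two results, |
| # treating positions where both values are NaN as equal. |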
| # Special-cases boolean comparisons |
| if prims.utils.is_boolean_dtype(a.dtype): |
| assert b.dtype is torch.bool |
| return (a ^ b).sum() |
| |
| same = a == b |
| if prims.utils.is_float_dtype(a.dtype) or prims.utils.is_complex_dtype( |
| a.dtype |
| ): |
| same = torch.logical_or( |
| same, torch.logical_and(torch.isnan(a), torch.isnan(b)) |
| ) |
| |
| actual_error = torch.where(same, 0, torch.abs(a - b)).sum() |
| return actual_error |
| |
| ref_distance = 0 |
| for a, b in zip( |
| pytree.tree_leaves(ref_result), pytree.tree_leaves(precise_result) |
| ): |
| ref_distance = ref_distance + _distance(a, b) |
| |
| torch_distance = 0 |
| for a, b in zip( |
| pytree.tree_leaves(torch_result), pytree.tree_leaves(precise_result) |
| ): |
| torch_distance = torch_distance + _distance(a, b) |
| |
| # TODO: consider adding some tolerance to this comparison |
| msg = ( |
| f"Reference result was farther ({ref_distance}) from the precise " |
| f"computation than the torch result was ({torch_distance})!" |
| ) |
| self.assertTrue(ref_distance <= torch_distance, msg=msg) |
| |
| # Reports numerical accuracy discrepancies |
| if ex is not None: |
| msg = "Test passed because the reference was more accurate than the torch operator." |
| warnings.warn(msg) |
| |
| # Tests that experimental Python References perform the same computation |
| # as the operators they reference, when operator calls in the torch |
| # namespace are remapped to the refs namespace (torch.foo becomes refs.foo). |
| @onlyNativeDeviceTypesAnd(["hpu"]) |
| @ops(python_ref_db) |
| @skipIfTorchInductor("Takes too long for inductor") |
| def test_python_ref(self, device, dtype, op): |
| # In this test, primTorch refs call into the refs namespace |
| # For example, a ref with torch.foo in it will call refs.foo instead |
| # Direct calls to refs and prims are not affected |
| if ( |
| TEST_WITH_ROCM |
| and (op.name == "_refs.fft.ihfftn" or op.name == "_refs.fft.ihfft2") |
| and dtype == torch.float16 |
| ): |
| self.skipTest("Skipped on ROCm") |
| self._ref_test_helper(lambda: TorchRefsMode(strict=True), device, dtype, op) |
| |
| # Tests that experimental Python References perform the same computation |
| # as the operators they reference, when operator calls in the torch |
| # namespace are preserved (torch.foo remains torch.foo). |
| @onlyNativeDeviceTypesAnd(["hpu"]) |
| @ops(python_ref_db) |
| @skipIfTorchInductor("Takes too long for inductor") |
| def test_python_ref_torch_fallback(self, device, dtype, op): |
| # In this test, refs call into the torch namespace (after the initial invocation) |
| # For example, a ref with torch.foo in it will call torch.foo instead of refs.foo |
| # Direct calls to refs and prims are not translated |
| if TEST_WITH_ROCM and op.name == "_refs.fft.ihfftn" and dtype == torch.float16: |
| self.skipTest("Skipped on ROCm") |
| self._ref_test_helper(contextlib.nullcontext, device, dtype, op) |
| |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @onlyCUDA |
| @ops(python_ref_db) |
| @parametrize("executor", ["aten"]) |
| @skipIfTorchInductor("Takes too long for inductor") |
| def test_python_ref_executor(self, device, dtype, op, executor): |
| if ( |
| TEST_WITH_ROCM |
| and (op.name == "_refs.fft.ihfftn" or op.name == "_refs.fft.ihfft2") |
| and dtype == torch.float16 |
| ): |
| self.skipTest("Skipped on ROCm") |
| # skip zero-dim tensors for some composites of reduction operations and view |
| skip_zero_dim_ops = [ |
| "_refs.logsumexp", |
| "_refs.log_softmax", |
| "_refs.native_group_norm", |
| "_refs.softmax", |
| "_refs.sum_to_size", |
| "ops.nvprims.view", |
| ] |
| |
| from copy import copy |
| |
| from torch._prims.executor import make_traced |
| |
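| # Copies the OpInfo and swaps its op for a traced version that runs through |
| # the requested executor instead of being called eagerly. |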
| op = copy(op) |
| op.op = partial(make_traced(op.op), executor=executor) |
| self._ref_test_helper(contextlib.nullcontext, device, dtype, op) |
| |
| @skipMeta |
| @onlyNativeDeviceTypesAnd(["hpu"]) |
| @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none) |
| def test_errors(self, device, op): |
| error_inputs = op.error_inputs(device) |
| for ei in error_inputs: |
| si = ei.sample_input |
| with self.assertRaisesRegex(ei.error_type, ei.error_regex): |
| out = op(si.input, *si.args, **si.kwargs) |
| self.assertFalse(isinstance(out, type(NotImplemented))) |
| |
| @skipMeta |
| @onlyNativeDeviceTypesAnd(["hpu"]) |
| @ops( |
| [op for op in op_db if op.error_inputs_sparse_func is not None], |
| dtypes=OpDTypes.none, |
| ) |
| @parametrize( |
| "layout", |
| ( |
| torch.sparse_csr, |
| torch.sparse_csc, |
| torch.sparse_bsr, |
| torch.sparse_bsc, |
| torch.sparse_coo, |
| ), |
| ) |
| def test_errors_sparse(self, device, op, layout): |
| for ei in op.error_inputs_sparse(device, layout): |
| si = ei.sample_input |
| with self.assertRaisesRegex(ei.error_type, ei.error_regex): |
| out = op(si.input, *si.args, **si.kwargs) |
| self.assertFalse(isinstance(out, type(NotImplemented))) |
| |
| @skipMeta |
| @onlyNativeDeviceTypesAnd(["hpu"]) |
| @ops( |
| [op for op in python_ref_db if op.error_inputs_func is not None], |
| dtypes=OpDTypes.none, |
| ) |
| @skipIfTorchInductor("Takes too long for inductor") |
| def test_python_ref_errors(self, device, op): |
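| # Creates a FakeTensorMode used by _to_tensormeta below to convert the |
| # sample tensors into FakeTensors before exercising the error paths. |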
| mode = FakeTensorMode() |
| with mode: |
| pass |
| |
| def _to_tensormeta(x): |
| if isinstance(x, torch.Tensor): |
| return FakeTensor.from_tensor(x, mode) |
| return x |
| |
| error_inputs = op.error_inputs(device) |
| for ei in error_inputs: |
| si = ei.sample_input |
| meta_sample = si.transform(_to_tensormeta) |
| with self.assertRaisesRegex(ei.error_type, ei.error_regex): |
| op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs) |
| |
| # Tests that the function produces the same result when called with |
| # noncontiguous tensors. |
| # TODO: get working with Windows by addressing failing operators |
| # TODO: get working with ASAN by addressing failing operators |
| @unittest.skipIf(IS_WINDOWS, "Skipped under Windows") |
| @onlyNativeDeviceTypesAnd(["hpu"]) |
| @suppress_warnings |
| @ops(op_db, allowed_dtypes=(torch.float32, torch.long, torch.complex64)) |
| def test_noncontiguous_samples(self, device, dtype, op): |
| test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type) |
| sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad) |
| for sample_input in sample_inputs: |
| t_inp, t_args, t_kwargs = ( |
| sample_input.input, |
| sample_input.args, |
| sample_input.kwargs, |
| ) |
| noncontig_sample = sample_input.noncontiguous() |
| n_inp, n_args, n_kwargs = ( |
| noncontig_sample.input, |
| noncontig_sample.args, |
| noncontig_sample.kwargs, |
| ) |
| |
| # validates forward |
| expected = op(t_inp, *t_args, **t_kwargs) |
| actual = op(n_inp, *n_args, **n_kwargs) |
| |
| self.assertEqual(actual, expected) |
| |
| # Validate backward |
| # Short-circuits if the op doesn't support grad in this device x dtype |
| if not test_grad: |
| continue |
| |
| expected = sample_input.output_process_fn_grad(expected) |
| actual = sample_input.output_process_fn_grad(actual) |
| |
| if isinstance(expected, torch.Tensor): |
| grad_for_expected = torch.randn_like(expected) |
| grad_for_actual = noncontiguous_like(grad_for_expected) |
| elif isinstance(expected, Sequence): |
| # Filter output elements that do not require grad |
| expected = [ |
| t |
| for t in expected |
| if isinstance(t, torch.Tensor) and t.requires_grad |
| ] |
| actual = [ |
| n for n in actual if isinstance(n, torch.Tensor) and n.requires_grad |
| ] |
| grad_for_expected = [torch.randn_like(t) for t in expected] |
| grad_for_actual = [noncontiguous_like(n) for n in grad_for_expected] |
| else: |
| # Nothing to do if it returns a scalar or things like that |
| continue |
| |
| # Concatenate inputs into a tuple |
| t_inputs = ( |
| (t_inp,) + t_args |
| if isinstance(t_inp, torch.Tensor) |
| else tuple(t_inp) + t_args |
| ) |
| n_inputs = ( |
| (n_inp,) + n_args |
| if isinstance(n_inp, torch.Tensor) |
| else tuple(n_inp) + n_args |
| ) |
| |
| # Filter the elements that are tensors and require grad |
| t_input_tensors = [ |
| t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad |
| ] |
| n_input_tensors = [ |
| n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad |
| ] |
| |
| self.assertEqual(len(t_input_tensors), len(n_input_tensors)) |
| |
| # Some functions may not use all the inputs to generate gradients. One of the |
| # few examples of this "odd" behaviour is F.hinge_embedding_loss |
| t_grads = torch.autograd.grad( |
| expected, t_input_tensors, grad_for_expected, allow_unused=True |
| ) |
| n_grads = torch.autograd.grad( |
| actual, n_input_tensors, grad_for_actual, allow_unused=True |
| ) |
| |
| msg = "Got different gradients for contiguous / non-contiguous inputs wrt input {}." |
| for i, (t, n) in enumerate(zip(t_grads, n_grads)): |
| self.assertEqual(t, n, msg=msg.format(i)) |
| |
| # Separates one case from the following test_out because many ops don't yet |
| # properly implement the warning for an incorrectly sized out parameter |
| # Case tested here: |
| # - out= with the correct dtype and device, but the wrong shape |
| @ops(ops_and_refs, dtypes=OpDTypes.none) |
| def test_out_warning(self, device, op): |
| if TEST_WITH_TORCHDYNAMO and op.name == "_refs.clamp": |
| self.skipTest("flaky") |
| # Prefers running in float32, but falls back to the first listed supported dtype |
| supported_dtypes = op.supported_dtypes(self.device_type) |
| if len(supported_dtypes) == 0: |
| self.skipTest("Skipped! Op has not supported dtypes on this device.") |
| dtype = ( |
| torch.float32 |
| if torch.float32 in supported_dtypes |
| else next(iter(supported_dtypes)) |
| ) |
| |
| # Ops from python_ref_db point to python decomps that are potentially |
| # wrapped with `torch._prims_common.wrappers.out_wrapper`. Unwrap these |
| # ops before testing to avoid clashing with OpInfo.supports_out |
| if not op.supports_out: |
| op = copy.copy(op) |
| op.op = _maybe_remove_out_wrapper(op.op) |
| |
| samples = op.sample_inputs(device, dtype) |
| for sample in samples: |
| # calls it normally to get the expected result |
| expected = op(sample.input, *sample.args, **sample.kwargs) |
| op_out = partial(op, sample.input, *sample.args, **sample.kwargs) |
| |
| # Short-circuits if output is not a single tensor or an |
| # iterable of tensors |
| if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors( |
| expected, include_empty=True |
| ): |
| self.skipTest( |
| "Skipped! Only supports single tensor or iterable of tensor outputs." |
| ) |
| |
| # Validates the op doesn't support out if it claims not to |
| if not op.supports_out: |
| with self.assertRaises(Exception): |
| assert op_out(out=expected) != NotImplemented |
| return |
| |
| # A wrapper around map that works with single tensors and always |
| # instantiates the map. Used below to apply transforms to |
| # single tensor and iterable tensor outputs. |
| def _apply_out_transform(fn, out): |
| if isinstance(out, torch.Tensor): |
| return fn(out) |
| |
| # assumes (see above) that out is an iterable of tensors |
| return tuple(map(fn, out)) |
| |
| # Extracts strides from a tensor or iterable of tensors into a tuple |
| def _extract_strides(out): |
| if isinstance(out, torch.Tensor): |
| return (out.stride(),) |
| |
| # assumes (see above) that out is an iterable of tensors |
| return tuple(t.stride() for t in out) |
| |
| # Extracts data pointers from a tensor or iterable of tensors into a tuple |
| # NOTE: only extracts on the CPU and CUDA device types since some |
| # device types don't have storage |
| def _extract_data_ptrs(out): |
| if self.device_type != "cpu" and self.device_type != "cuda": |
| return () |
| |
| if isinstance(out, torch.Tensor): |
| return (out.data_ptr(),) |
| |
| # assumes (see above) that out is an iterable of tensors |
| return tuple(t.data_ptr() for t in out) |
| |
| @suppress_warnings |
| def _compare_out(transform, *, compare_strides_and_data_ptrs=True): |
| out = _apply_out_transform(transform, expected) |
| original_strides = _extract_strides(out) |
| original_ptrs = _extract_data_ptrs(out) |
| |
| op_out(out=out) |
| final_strides = _extract_strides(out) |
| final_ptrs = _extract_data_ptrs(out) |
| |
| self.assertEqual(expected, out) |
| |
| if compare_strides_and_data_ptrs: |
| stride_msg = ( |
| f"Strides are not the same! Original strides were {original_strides} " |
| f"and strides are now {final_strides}" |
| ) |
| self.assertEqual(original_strides, final_strides, msg=stride_msg) |
| self.assertEqual(original_ptrs, final_ptrs) |
| |
| # Case Zero: out= with the correct dtype and device, but the wrong shape |
| # Expected behavior: if nonempty, resize with a warning. |
| def _case_zero_transform(t): |
| wrong_shape = list(t.shape) |
| |
| if len(wrong_shape) == 0: |
| # Handles scalar tensor case (empty list) |
| wrong_shape = [2] |
| else: |
| wrong_shape[-1] = wrong_shape[-1] + 1 |
| return make_tensor(wrong_shape, dtype=t.dtype, device=t.device) |
| |
| # Verifies the out values are correct |
| _compare_out(_case_zero_transform, compare_strides_and_data_ptrs=False) |
| |
| # Additionally validates that the appropriate warning is thrown if a nonempty |
| # tensor is resized. |
| def _any_nonempty(out): |
| if isinstance(out, torch.Tensor): |
| return out.numel() > 0 |
| |
| return any(x.numel() > 0 for x in out) |
| |
| out = _apply_out_transform(_case_zero_transform, expected) |
| msg_fail = "Resized a non-empty tensor but did not warn about it." |
| if _any_nonempty(out): |
| with self.assertWarnsRegex( |
| UserWarning, "An output with one or more elements", msg=msg_fail |
| ): |
| op_out(out=out) |
| |
| # Validates ops implement the correct out= behavior |
| # See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch |
| # for a description of the correct behavior |
| # Validates the following cases: |
| # - Case 0: out has the correct shape, dtype, and device but is full of extremal values |
| # - Case 1: out has the correct shape, dtype, and device but is noncontiguous |
| # - Case 2: out has the correct dtype and device, but is zero elements |
| # - Case 3: out has the correct shape and dtype, but is on a different device type |
| # - Case 4: out has the correct shape and device, but a dtype that cannot |
| # "safely" cast to |
| # |
| # Case 3 and 4 are slightly different when the op is a factory function: |
| # - if device, dtype are NOT passed, any combination of dtype/device should be OK for out |
| # - if device, dtype are passed, device and dtype should match |
| @ops(ops_and_refs, dtypes=OpDTypes.any_one) |
| def test_out(self, device, dtype, op): |
| # Prefers running in float32 but has a fallback for the first listed supported dtype |
| samples = op.sample_inputs(device, dtype) |
| |
| # Ops from python_ref_db point to python decomps that are potentially |
| # wrapped with `torch._prims_common.wrappers.out_wrapper`. Unwrap these |
| # ops before testing to avoid clashing with OpInfo.supports_out |
| if not op.supports_out: |
| op = copy.copy(op) |
| op.op = _maybe_remove_out_wrapper(op.op) |
| |
| for sample in samples: |
| # calls it normally to get the expected result |
| expected = op(sample.input, *sample.args, **sample.kwargs) |
| op_out = partial(op, sample.input, *sample.args, **sample.kwargs) |
| |
| # Short-circuits if output is not a single tensor or an |
| # iterable of tensors |
| if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors( |
| expected, include_empty=True |
| ): |
| self.skipTest( |
| "Skipped! Only supports single tensor or iterable of tensor outputs." |
| ) |
| |
| # Validates the op doesn't support out if it claims not to |
| if not op.supports_out: |
| with self.assertRaises(Exception): |
| assert op_out(out=expected) != NotImplemented |
| return |
| |
| # A wrapper around map that works with single tensors and always |
| # instantiates the map. Used below to apply transforms to |
| # single tensor and iterable tensor outputs. |
| def _apply_out_transform(fn, out): |
| if isinstance(out, torch.Tensor): |
| return fn(out) |
| |
| # assumes (see above) that out is an iterable of tensors |
| return tuple(map(fn, out)) |
| |
| # Extracts strides from a tensor or iterable of tensors into a tuple |
| def _extract_strides(out): |
| if isinstance(out, torch.Tensor): |
| return (out.stride(),) |
| |
| # assumes (see above) that out is an iterable of tensors |
| return tuple(t.stride() for t in out) |
| |
| # Extracts data pointers from a tensor or iterable of tensors into a tuple |
| # NOTE: only extracts on the CPU and CUDA device types since some |
| # device types don't have storage |
| def _extract_data_ptrs(out): |
| if self.device_type != "cpu" and self.device_type != "cuda": |
| return () |
| |
| if isinstance(out, torch.Tensor): |
| return (out.data_ptr(),) |
| |
| # assumes (see above) that out is an iterable of tensors |
| return tuple(t.data_ptr() for t in out) |
| |
| def _compare_out(transform, *, compare_strides_and_data_ptrs=True): |
| out = _apply_out_transform(transform, expected) |
| original_strides = _extract_strides(out) |
| original_ptrs = _extract_data_ptrs(out) |
| |
| op_out(out=out) |
| final_strides = _extract_strides(out) |
| final_ptrs = _extract_data_ptrs(out) |
| self.assertEqual(expected, out) |
| |
| if compare_strides_and_data_ptrs: |
| stride_msg = ( |
| "Strides are not the same! " |
| f"Original strides were {original_strides} and strides are now {final_strides}" |
| ) |
| self.assertEqual(original_strides, final_strides, msg=stride_msg) |
| self.assertEqual(original_ptrs, final_ptrs) |
| |
| # Case 0: out= with the correct shape, dtype, and device |
| # but NaN values for floating point and complex tensors, and |
| # maximum values for integer tensors. |
| # Expected behavior: out= values have no effect on the computation. |
| def _case_zero_transform(t): |
| try: |
| info = torch.iinfo(t.dtype) |
| return torch.full_like(t, info.max) |
| except TypeError as te: |
| # for non-integer types fills with NaN |
| return torch.full_like(t, float("nan")) |
| |
| _compare_out(_case_zero_transform) |
| |
| # Case 1: out= with the correct shape, dtype, and device, |
| # but noncontiguous. |
| # Expected behavior: strides are respected and `out` storage is not changed. |
| def _case_one_transform(t): |
| return make_tensor( |
| t.shape, dtype=t.dtype, device=t.device, noncontiguous=True |
| ) |
| |
| _compare_out(_case_one_transform) |
| |
| # Case 2: out= with the correct dtype and device, but has no elements. |
| # Expected behavior: resize without warning. |
| def _case_two_transform(t): |
| return make_tensor((0,), dtype=t.dtype, device=t.device) |
| |
| _compare_out(_case_two_transform, compare_strides_and_data_ptrs=False) |
| |
| # Also validates that no warning is thrown when this out is resized |
| out = _apply_out_transform(_case_two_transform, expected) |
| with warnings.catch_warnings(record=True) as caught: |
| warnings.simplefilter("always") |
| op_out(out=out) |
| |
| # Verifies no warning is a resize warning |
| for w in caught: |
| if "An output with one or more elements" in str(w.message): |
| self.fail( |
| "Resizing an out= argument with no elements threw a resize warning!" |
| ) |
| |
| # Case 3: out= with correct shape and dtype, but wrong device. |
| wrong_device = None |
| if torch.device(device).type != "cpu": |
| wrong_device = "cpu" |
| elif torch.cuda.is_available(): |
| wrong_device = "cuda" |
| |
| factory_fn_msg = ( |
| "\n\nNOTE: If your op is a factory function (i.e., it accepts TensorOptions) you should mark its " |
| "OpInfo with `is_factory_function=True`." |
| ) |
| if wrong_device is not None: |
| |
| def _case_three_transform(t): |
| return make_tensor(t.shape, dtype=t.dtype, device=wrong_device) |
| |
| out = _apply_out_transform(_case_three_transform, expected) |
| |
| if op.is_factory_function and sample.kwargs.get("device", None) is None: |
| op_out(out=out) |
| else: |
| msg_fail = ( |
| f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}." |
| ) + factory_fn_msg |
| with self.assertRaises(RuntimeError, msg=msg_fail): |
| op_out(out=out) |
| |
| # Case 4: out= with correct shape and device, but a dtype |
| # that output cannot be "safely" cast to (long). |
| # Expected behavior: error. |
| # NOTE: this case is filtered by dtype since some ops produce |
| # bool tensors, for example, which can be safely cast to any |
| # dtype. It is applied when a single output tensor has a floating point or |
| # complex dtype, or when an op returns multiple tensors and at least one of |
| # them has a floating point or complex dtype. |
| _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16) |
| if ( |
| isinstance(expected, torch.Tensor) |
| and expected.dtype in _dtypes |
| or ( |
| not isinstance(expected, torch.Tensor) |
| and any(t.dtype in _dtypes for t in expected) |
| ) |
| ): |
| |
| def _case_four_transform(t): |
| return make_tensor(t.shape, dtype=torch.long, device=t.device) |
| |
| out = _apply_out_transform(_case_four_transform, expected) |
| msg_fail = "Expected RuntimeError when doing an unsafe cast!" |
| msg_fail = ( |
| msg_fail |
| if not isinstance(expected, torch.Tensor) |
| else ( |
| "Expected RuntimeError when doing an unsafe cast from a result of dtype " |
| f"{expected.dtype} into an out= with dtype torch.long" |
| ) |
| ) + factory_fn_msg |
| |
| if op.is_factory_function and sample.kwargs.get("dtype", None) is None: |
| op_out(out=out) |
| else: |
| with self.assertRaises(RuntimeError, msg=msg_fail): |
| op_out(out=out) |
| |
| @ops( |
| [ |
| op |
| for op in op_db |
| if op.supports_out and (op.supports_autograd or op.is_factory_function) |
| ], |
| dtypes=OpDTypes.supported, |
| allowed_dtypes=[torch.float, torch.cfloat], |
| ) |
| def test_out_requires_grad_error(self, device, dtype, op): |
| sample = first_sample(self, op.sample_inputs(device, dtype)) |
| |
| # Call op to get prototype for out arguments |
| expect = op(sample.input, *sample.args, **sample.kwargs) |
| any_requires_grad = False |
| |
| def set_requires_grad(x): |
| nonlocal any_requires_grad |
| if isinstance(x, torch.Tensor) and ( |
| x.is_floating_point() or x.is_complex() |
| ): |
| any_requires_grad = True |
| x.requires_grad_(True) |
| return x |
| |
| out = pytree.tree_map_(set_requires_grad, expect) |
| if not any_requires_grad: |
| # Skip ops without any floating point outputs, e.g. isnan |
| return |
| |
| msg = ( |
| "functions with out=... arguments don't support automatic " |
| "differentiation, but one of the arguments requires grad." |
| ) |
| with self.assertRaises(RuntimeError, msg=msg): |
| op(sample.input, *sample.args, **sample.kwargs, out=out) |
| |
| @ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,)) |
| def test_out_integral_dtype(self, device, dtype, op): |
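| # helper runs the op either with an int32 out= tensor or without out=, and |
| # asserts that the "dtype argument and out dtype must match in reduction" |
| # error is raised exactly when expectFail is True. |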
| def helper(with_out, expectFail, op_to_test, inputs, *args, **kwargs): |
| out = None |
| try: |
| if with_out: |
| out = torch.empty(0, dtype=torch.int32, device=device) |
| op_to_test(inputs, *args, out=out, **kwargs) |
| else: |
| out = op_to_test(inputs, *args, **kwargs) |
| self.assertFalse(expectFail) |
| except RuntimeError as err: |
| self.assertEqual( |
| str(err), "dtype argument and out dtype must match in reduction" |
| ) |
| self.assertTrue(expectFail) |
| return out |
| |
| samples = op.sample_inputs(device, dtype) |
| for sample in samples: |
| if "dtype" not in sample.kwargs: |
| helper(False, False, op, sample.input, *sample.args, **sample.kwargs) |
| helper(True, False, op, sample.input, *sample.args, **sample.kwargs) |
| sample.kwargs["dtype"] = torch.int16 |
| helper(False, False, op, sample.input, *sample.args, **sample.kwargs) |
| helper(True, True, op, sample.input, *sample.args, **sample.kwargs) |
| sample.kwargs["dtype"] = torch.int32 |
| helper(False, False, op, sample.input, *sample.args, **sample.kwargs) |
| helper(True, False, op, sample.input, *sample.args, **sample.kwargs) |
| else: |
| helper(False, False, op, sample.input, *sample.args, **sample.kwargs) |
| helper( |
| True, |
| sample.kwargs["dtype"] != torch.int32, |
| op, |
| sample.input, |
| *sample.args, |
| **sample.kwargs, |
| ) |
| |
| # Tests that the forward and backward passes of operations produce the |
| # same values for the cross-product of op variants (method, inplace) |
| # against eager's gold standard op function variant |
| @_variant_ops(op_db) |
| def test_variant_consistency_eager(self, device, dtype, op): |
| # Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases) |
| |
| method = op.method_variant |
| inplace = op.inplace_variant |
| operator = op.operator_variant |
| inplace_operator = op.inplace_operator_variant |
| |
| # list of all inplace ops: inplace variant + alias inplace variants if they exist |
| inplace_ops = [inplace, inplace_operator] |
| variants = [method, inplace, operator, inplace_operator] |
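| # operator variants (e.g. __add__, __iadd__) cannot accept kwargs, so samples |
| # with kwargs are skipped for them below |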
| operators = [operator, inplace_operator] |
| |
| for a_op in op.aliases: |
| variants.append(a_op.op) |
| variants.append(a_op.method_variant) |
| variants.append(a_op.inplace_variant) |
| inplace_ops.append(a_op.inplace_variant) |
| |
| inplace_variants = tuple(filter(None, inplace_ops)) |
| variants = tuple(filter(None, variants)) |
| operators = tuple(filter(None, operators)) |
| |
| _requires_grad = dtype in op.supported_backward_dtypes( |
| torch.device(device).type |
| ) |
| |
| include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex |
| samples = op.sample_inputs( |
| device, |
| dtype, |
| requires_grad=_requires_grad, |
| include_conjugated_inputs=include_conjugated_inputs, |
| ) |
| samples = list(samples) |
| |
| def _test_consistency_helper(samples, variants): |
| for sample in samples: |
| # TODO: Check grad for all Tensors requiring grad if sample.input is TensorList |
| tensor = ( |
| sample.input |
| if isinstance(sample.input, torch.Tensor) |
| else sample.input[0] |
| ) |
| |
| # Computes function forward and backward values |
| tensor.grad = None |
| expected_forward = op(sample.input, *sample.args, **sample.kwargs) |
| expected_grad = None |
| |
| output_process_fn_grad = ( |
| sample.output_process_fn_grad |
| if sample.output_process_fn_grad |
| else lambda x: x |
| ) |
| |
| # Skips inplace variants if the output dtype is not the same as |
| # the input dtype |
| skip_inplace = False |
| if ( |
| isinstance(expected_forward, torch.Tensor) |
| and expected_forward.dtype is not tensor.dtype |
| ): |
| skip_inplace = True |
| |
| # TODO: backward consistency only supported for single tensor outputs |
| # TODO: backward consistency only checked on sample.input, not all |
| # tensor inputs |
| # TODO: update to handle checking grads of all tensor inputs as |
| # derived from each tensor output |
| if isinstance( |
| expected_forward, torch.Tensor |
| ) and dtype in op.supported_backward_dtypes(torch.device(device).type): |
| out = output_process_fn_grad(expected_forward).sum() |
| if out.dtype.is_complex: |
| out = out.abs() |
| out.backward() |
| expected_grad = tensor.grad |
| |
| # Test eager consistency |
| for variant in variants: |
| # Skips inplace ops |
| if variant in inplace_ops and skip_inplace: |
| continue |
| |
| # Compares variant's forward |
| # Note: copies the to-be-modified input when testing the inplace variant |
| tensor.grad = None |
| cloned = ( |
| clone_input_helper(sample.input) |
| if variant in inplace_ops |
| else sample.input |
| ) |
| |
| if variant in inplace_ops and sample.broadcasts_input: |
| with self.assertRaises( |
| RuntimeError, |
| msg=( |
| "inplace variant either incorrectly allowed " |
| f"resizing or you have marked the sample {sample.summary()}" |
| " incorrectly with `broadcasts_self=True" |
| ), |
| ): |
| variant_forward = variant( |
| cloned, *sample.args, **sample.kwargs |
| ) |
| continue |
| |
| if variant in operators and sample.kwargs: |
| # skip samples with kwargs for operator variants |
| continue |
| |
| variant_forward = variant(cloned, *sample.args, **sample.kwargs) |
| self.assertEqual(expected_forward, variant_forward) |
| |
| # Compares variant's backward |
| if expected_grad is not None and ( |
| variant not in inplace_ops or op.supports_inplace_autograd |
| ): |
| out = output_process_fn_grad(variant_forward).sum() |
| if out.dtype.is_complex: |
| out = out.abs() |
| out.backward() |
| self.assertEqual(expected_grad, tensor.grad) |
| |
| _test_consistency_helper(samples, variants) |
| |
| def _test_inplace_preserve_storage(samples, variants): |
| for sample in samples: |
| # Skips inplace variants if the output dtype is not the same as |
| # the input dtype |
| expected_forward = op(sample.input, *sample.args, **sample.kwargs) |
| tensor = ( |
| sample.input |
| if isinstance(sample.input, torch.Tensor) |
| else sample.input[0] |
| ) |
| skip_inplace = False |
| if ( |
| isinstance(expected_forward, torch.Tensor) |
| and expected_forward.dtype is not tensor.dtype |
| ): |
| skip_inplace = True |
| if skip_inplace: |
| return |
| for variant in variants: |
| cloned = ( |
| clone_input_helper(sample.input) |
| if variant in inplace_ops |
| else sample.input |
| ) |
| inp_tensor = ( |
| cloned if isinstance(cloned, torch.Tensor) else cloned[0] |
| ) |
| data_ptr = inp_tensor.data_ptr() |
| if variant in operators and sample.kwargs: |
| # skip samples with kwargs for operator variants |
| continue |
| |
| variant_forward = variant(cloned, *sample.args, **sample.kwargs) |
| # TODO Support non-tensor outputs if they exist for inplace ops |
| if isinstance(variant_forward, torch.Tensor): |
| self.assertEqual( |
| data_ptr, variant_forward.data_ptr(), atol=0, rtol=0 |
| ) |
| else: |
| self.assertTrue( |
| False, |
| "Non-tensor outputs for inplace ops are not supported", |
| ) |
| |
| if len(inplace_ops) > 0: |
| inplace_samples = list( |
| filter(lambda sample: not sample.broadcasts_input, samples) |
| ) |
| _test_inplace_preserve_storage(inplace_samples, inplace_variants) |
| |
| # Reference testing for operations in complex32 against complex64. |
| # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype. |
| @ops(op_db, allowed_dtypes=(torch.complex32,)) |
| def test_complex_half_reference_testing(self, device, dtype, op): |
| if not op.supports_dtype(torch.complex32, device): |
| self.skipTest("Does not support complex32") |
| |
| for sample in op.sample_inputs(device, dtype): |
| actual = op(sample.input, *sample.args, **sample.kwargs) |
| # sample.transform applies the lambda to torch.Tensor and torch.dtype. |
| # However, we only want to apply it to Tensors with dtype `torch.complex32`. |
| transformed_sample = sample.transform( |
| lambda x: x.to(torch.complex64) |
| if isinstance(x, torch.Tensor) and x.dtype is torch.complex32 |
| else x |
| ) |
| expected = op( |
| transformed_sample.input, |
| *transformed_sample.args, |
| **transformed_sample.kwargs, |
| ) |
| # Since the range of chalf is much smaller than that of cfloat, |
| # we get `inf`s easily (eg. with `pow`, `exp`), |
| # so we cast `cfloat` back to `chalf`. |
| expected = tree_map( |
| lambda x: x.to(torch.complex32) |
| if isinstance(x, torch.Tensor) and x.dtype is torch.complex64 |
| else x, |
| expected, |
| ) |
| |
| # `exact_dtype` is False because for ops like real, imag |
| # we get different dtypes for `actual` and `expected` |
| # `chalf` input -> `half` output |
| # `cfloat` input -> `float` output |
| self.assertEqual(actual, expected, exact_dtype=False) |
| |
| @ops(op_db, allowed_dtypes=(torch.bool,)) |
| @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior") |
| def test_non_standard_bool_values(self, device, dtype, op): |
| # Test boolean values other than 0x00 and 0x01 (gh-54789) |
| def convert_boolean_tensors(x): |
| if not isinstance(x, torch.Tensor) or x.dtype != torch.bool: |
| return x |
| |
| # Map False -> 0 and True -> a random value in [2, 254] |
| true_vals = torch.randint( |
| 2, 255, x.shape, dtype=torch.uint8, device=x.device |
| ) |
| false_vals = torch.zeros((), dtype=torch.uint8, device=x.device) |
| x_int = torch.where(x, true_vals, false_vals) |
| |
| ret = x_int.view(torch.bool) |
| self.assertEqual(ret, x) |
| return ret |
| |
| for sample in op.sample_inputs(device, dtype): |
| expect = op(sample.input, *sample.args, **sample.kwargs) |
| |
| transformed = sample.transform(convert_boolean_tensors) |
| actual = op(transformed.input, *transformed.args, **transformed.kwargs) |
| |
| self.assertEqual(expect, actual) |
| |
| # Validates that each OpInfo specifies its forward and backward dtypes |
| # correctly for CPU and CUDA devices |
| @skipMeta |
| @onlyNativeDeviceTypesAnd(["hpu"]) |
| @ops(ops_and_refs, dtypes=OpDTypes.none) |
| def test_dtypes(self, device, op): |
| # Check complex32 support only if the op claims to support it. |
| # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally. |
| device_type = torch.device(device).type |
| include_complex32 = ( |
| (torch.complex32,) |
| if op.supports_dtype(torch.complex32, device_type) |
| else () |
| ) |
| |
| # dtypes to try to backward in |
| allowed_backward_dtypes = floating_and_complex_types_and( |
| *((torch.half, torch.bfloat16) + include_complex32) |
| ) |
| |
| # sets of (un)supported dtypes |
| supported_dtypes = set() |
| unsupported_dtypes = set() |
| supported_backward_dtypes = set() |
| unsupported_backward_dtypes = set() |
| dtype_error: Dict[torch.dtype, Exception] = {} |
| |
| def unsupported(dtype, e): |
| dtype_error[dtype] = e |
| unsupported_dtypes.add(dtype) |
| if dtype in allowed_backward_dtypes: |
| unsupported_backward_dtypes.add(dtype) |
| |
| for dtype in all_types_and_complex_and( |
| *((torch.half, torch.bfloat16, torch.bool) + include_complex32) |
| ): |
| # tries to acquire samples - failure indicates lack of support |
| requires_grad = dtype in allowed_backward_dtypes |
| try: |
| samples = tuple( |
| op.sample_inputs(device, dtype, requires_grad=requires_grad) |
| ) |
| except Exception as e: |
| unsupported(dtype, e) |
| continue |
| |
| for sample in samples: |
| # tries to call operator with the sample - failure indicates |
| # lack of support |
| try: |
| result = op(sample.input, *sample.args, **sample.kwargs) |
| supported_dtypes.add(dtype) |
| except Exception as e: |
| # NOTE: some ops will fail in forward if their inputs |
| # require grad but they don't support computing the gradient |
| # in that type! This is a bug in the op! |
| unsupported(dtype, e) |
| continue |
| |
| # Checks for backward support in the same dtype, if the input has |
| # one or more tensors requiring grad |
| def _tensor_requires_grad(x): |
| if isinstance(x, dict): |
| for v in x.values(): |
| if _tensor_requires_grad(v): |
| return True |
| if isinstance(x, (list, tuple)): |
| for a in x: |
| if _tensor_requires_grad(a): |
| return True |
| if isinstance(x, torch.Tensor) and x.requires_grad: |
| return True |
| |
| return False |
| |
| requires_grad = ( |
| _tensor_requires_grad(sample.input) |
| or _tensor_requires_grad(sample.args) |
| or _tensor_requires_grad(sample.kwargs) |
| ) |
| if not requires_grad: |
| continue |
| |
| try: |
| result = sample.output_process_fn_grad(result) |
| if isinstance(result, torch.Tensor): |
| backward_tensor = result |
| elif isinstance(result, Sequence) and isinstance( |
| result[0], torch.Tensor |
| ): |
| backward_tensor = result[0] |
| else: |
| continue |
| |
| # Note: this grad may not have the same dtype as dtype |
| # For functions like complex (float -> complex) or abs |
| # (complex -> float) the grad tensor will have a |
| # different dtype than the input. |
| # For simplicity, this is still modeled as these ops |
| # supporting grad in the input dtype. |
| grad = torch.randn_like(backward_tensor) |
| backward_tensor.backward(grad) |
| supported_backward_dtypes.add(dtype) |
| except Exception as e: |
| dtype_error[dtype] = e |
| unsupported_backward_dtypes.add(dtype) |
| |
| # Checks that dtypes are listed correctly and generates an informative |
| # error message |
| |
| supported_forward = supported_dtypes - unsupported_dtypes |
| partially_supported_forward = supported_dtypes & unsupported_dtypes |
| unsupported_forward = unsupported_dtypes - supported_dtypes |
| supported_backward = supported_backward_dtypes - unsupported_backward_dtypes |
| partially_supported_backward = ( |
| supported_backward_dtypes & unsupported_backward_dtypes |
| ) |
| unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes |
| |
| device_type = torch.device(device).type |
| |
| claimed_forward = set(op.supported_dtypes(device_type)) |
| supported_but_unclaimed_forward = supported_forward - claimed_forward |
| claimed_but_unsupported_forward = claimed_forward & unsupported_forward |
| |
| claimed_backward = set(op.supported_backward_dtypes(device_type)) |
| supported_but_unclaimed_backward = supported_backward - claimed_backward |
| claimed_but_unsupported_backward = claimed_backward & unsupported_backward |
| |
| # Partially supporting a dtype is not an error, but we print a warning |
| if (len(partially_supported_forward) + len(partially_supported_backward)) > 0: |
| msg = f"Some dtypes for {op.name} on device type {device_type} are only partially supported!\n" |
| if len(partially_supported_forward) > 0: |
| msg = ( |
| msg |
| + f"The following dtypes only worked on some samples during forward: {partially_supported_forward}.\n" |
| ) |
| if len(partially_supported_backward) > 0: |
| msg = ( |
| msg |
| + f"The following dtypes only worked on some samples during backward: {partially_supported_backward}.\n" |
| ) |
| print(msg) |
| |
| if ( |
| len(supported_but_unclaimed_forward) |
| + len(claimed_but_unsupported_forward) |
| + len(supported_but_unclaimed_backward) |
| + len(claimed_but_unsupported_backward) |
| ) == 0: |
| return |
| |
| # Reference operators often support additional dtypes, and that's OK |
| if op in python_ref_db: |
| if ( |
| len(claimed_but_unsupported_forward) |
| + len(claimed_but_unsupported_backward) |
| ) == 0: |
| return |
| |
| # Generates error msg |
| msg = f"The supported dtypes for {op.name} on device type {device_type} are incorrect!\n" |
| if len(supported_but_unclaimed_forward) > 0: |
| msg = ( |
| msg |
| + "The following dtypes worked in forward but are not listed by the OpInfo: " |
| + f"{supported_but_unclaimed_forward}.\n" |
| ) |
| if len(supported_but_unclaimed_backward) > 0: |
| msg = ( |
| msg |
| + "The following dtypes worked in backward but are not listed by the OpInfo: " |
| + f"{supported_but_unclaimed_backward}.\n" |
| ) |
| if len(claimed_but_unsupported_forward) > 0: |
| msg = ( |
| msg |
| + "The following dtypes did not work in forward but are listed by the OpInfo: " |
| + f"{claimed_but_unsupported_forward}.\n" |
| ) |
| if len(claimed_but_unsupported_backward) > 0: |
| msg = ( |
| msg |
| + "The following dtypes did not work in backward " |
| + f"but are listed by the OpInfo: {claimed_but_unsupported_backward}.\n" |
| ) |
| |
| all_claimed_but_unsupported = set.union( |
| claimed_but_unsupported_backward, claimed_but_unsupported_forward |
| ) |
| if all_claimed_but_unsupported: |
| msg += "Unexpected failures raised the following errors:\n" |
| for dtype in all_claimed_but_unsupported: |
| msg += f"{dtype} - {dtype_error[dtype]}\n" |
| |
| self.fail(msg) |
| |
| # Validates that each OpInfo that sets promotes_int_to_float=True does as it says |
| @skipMeta |
| @onlyNativeDeviceTypesAnd(["hpu"]) |
| @ops( |
| (op for op in op_db if op.promotes_int_to_float), |
| allowed_dtypes=integral_types_and(torch.bool), |
| ) |
| def test_promotes_int_to_float(self, device, dtype, op): |
| for sample in op.sample_inputs(device, dtype): |
| output = op(sample.input, *sample.args, **sample.kwargs) |
| if not output.dtype.is_floating_point: |
| self.fail( |
| f"The OpInfo sets `promotes_int_to_float=True`, but {dtype} was promoted to {output.dtype}." |
| ) |
| |
| |
| @unMarkDynamoStrictTest |
| class TestCompositeCompliance(TestCase): |
| # Checks if the operator (if it is composite) is written to support most |
| # backends and Tensor subclasses. See "CompositeImplicitAutograd Compliance" |
| # in aten/src/ATen/native/README.md for more details |
| @unittest.skipIf( |
| IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" |
| ) |
| @ops(op_db, allowed_dtypes=(torch.float,)) |
| def test_operator(self, device, dtype, op): |
| samples = op.sample_inputs(device, dtype, requires_grad=False) |
| |
| for sample in samples: |
| args = [sample.input] + list(sample.args) |
| kwargs = sample.kwargs |
| composite_compliance.check_with_mode(op, args, kwargs, self.assertEqual) |
| composite_compliance.check_all_permutations( |
| op, args, kwargs, self.assertEqual |
| ) |
| |
| @unittest.skipIf( |
| IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" |
| ) |
| @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) |
| def test_backward(self, device, dtype, op): |
| samples = op.sample_inputs(device, dtype, requires_grad=True) |
| |
| for sample in samples: |
| args = [sample.input] + list(sample.args) |
| kwargs = sample.kwargs |
| # We pass assertEqual so that decorators like `toleranceOverride` |
| # actually work (otherwise they silently do nothing!) |
| composite_compliance.check_backward_formula( |
| op.get_op(), |
| args, |
| kwargs, |
| sample.output_process_fn_grad, |
| op.gradcheck_wrapper, |
| self.assertEqual, |
| ) |
| |
| @unittest.skipIf( |
| IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" |
| ) |
| @ops(op_db, allowed_dtypes=(torch.float,)) |
| def test_forward_ad(self, device, dtype, op): |
| if torch.float not in op.supported_backward_dtypes(device): |
| raise unittest.SkipTest("Does not support autograd") |
| |
| if not op.supports_forward_ad: |
| raise unittest.SkipTest("Does not support forward_ad") |
| |
| samples = op.sample_inputs(device, dtype, requires_grad=True) |
| |
| for sample in samples: |
| args = [sample.input] + list(sample.args) |
| kwargs = sample.kwargs |
| # We pass assertEqual so that decorators like `toleranceOverride` |
| # actually work (otherwise they silently do nothing!) |
| composite_compliance.check_forward_ad_formula( |
| op.get_op(), args, kwargs, op.gradcheck_wrapper, self.assertEqual |
| ) |
| |
| @ops(op_db, allowed_dtypes=(torch.float,)) |
| def test_cow_input(self, device, dtype, op): |
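| # Checks that the op does not materialize copy-on-write (COW) inputs: strided inputs are |
| # lazily cloned into COW tensors, the op (and its backward, when supported) is run, and each |
| # input is verified to still be COW (unless the OpInfo allows materialization) and unmutated. |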
| samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd) |
| |
| def is_strided_tensor(arg): |
| return torch.is_tensor(arg) and arg.layout == torch.strided |
| |
| def check_ignore_materialize(idx_or_kw, allow_list): |
| return (allow_list is not None) and (idx_or_kw in allow_list) |
| |
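| # Verifies a single input after the op has run: if the op claims not to materialize COW |
| # inputs, the strided input must still be COW; and if it is still COW, its data must be |
| # bitwise identical to the copy saved before the call. |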
| def check_cow_input( |
| arg, |
| arg_copy, |
| idx_or_kw, |
| backward_or_forward="forward", |
| supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_forward, |
| allow_list=op.allow_cow_input_materialize_forward, |
| ): |
| arg_name = ( |
| f"Argument {idx_or_kw}" |
| if isinstance(idx_or_kw, int) |
| else f"Keyword argument '{idx_or_kw}'" |
| ) + f" during {backward_or_forward} call" |
| |
| if is_strided_tensor(arg): |
| is_cow = torch._C._is_cow_tensor(arg) |
| |
| if supports_cow_input_no_materialize and not check_ignore_materialize( |
| idx_or_kw, allow_list |
| ): |
| self.assertTrue( |
| is_cow, |
| msg=( |
| f"{arg_name} unexpectedly materializes. " |
| f"Either set `supports_cow_input_no_materialize_{backward_or_forward}=False` " |
| "in this operation's OpInfo, add the arg to the OpInfo's " |
| f"`allow_cow_input_materialize_{backward_or_forward}` list, or change the " |
| "implementation to avoid materialization." |
| ), |
| ) |
| |
| if is_cow: |
| self.assertTrue( |
| torch.allclose(arg, arg_copy, rtol=0, atol=0, equal_nan=True), |
| msg=( |
| f"{arg_name} avoided materialization, " |
| "but the operation mutated its data." |
| ), |
| ) |
| |
| for sample in samples: |
| args_raw = [sample.input] + list(sample.args) |
| kwargs_raw = sample.kwargs |
| args_copy = [] |
| args = [] |
| kwargs_copy = {} |
| kwargs = {} |
| |
| # Convert strided tensor inputs to COW tensors and make copies of |
| # all inputs |
| for idx, arg in enumerate(args_raw): |
| if is_strided_tensor(arg): |
| args_copy.append(arg.clone().detach()) |
| args.append(torch._lazy_clone(arg)) |
| else: |
| if torch.is_tensor(arg): |
| args_copy.append(arg.clone().detach()) |
| else: |
| args_copy.append(copy.deepcopy(arg)) |
| args.append(arg) |
| |
| for kw, arg in kwargs_raw.items(): |
| if is_strided_tensor(arg): |
| kwargs_copy[kw] = arg.clone().detach() |
| kwargs[kw] = torch._lazy_clone(arg) |
| else: |
| if torch.is_tensor(arg): |
| kwargs_copy[kw] = arg.clone().detach() |
| else: |
| kwargs_copy[kw] = copy.deepcopy(arg) |
| kwargs[kw] = arg |
| |
| leaf_tensors = composite_compliance.gather_leaf_tensors(args, kwargs) |
| |
| # Call forward op |
| results_raw = op.get_op()(*args, **kwargs) |
| |
| # Check that COW inputs remain COW after the forward op is executed |
| for idx, arg in enumerate(args): |
| check_cow_input(arg, args_copy[idx], idx) |
| |
| for kw, arg in kwargs.items(): |
| check_cow_input(arg, kwargs_copy[kw], kw) |
| |
| # Call backward op if it is supported. This part of the test is |
| # based on `composite_compliance.check_backward_formula` |
| if ( |
| op.supports_autograd |
| and len(leaf_tensors) > 0 |
| and not op.skip_cow_input_backward |
| ): |
| if sample.output_process_fn_grad is not None: |
| results_raw = sample.output_process_fn_grad(results_raw) |
| |
| leaf_results = pytree.tree_leaves(results_raw) |
| results = [ |
| r |
| for r in leaf_results |
| if isinstance(r, torch.Tensor) and r.requires_grad |
| ] |
| |
| all_results_strided = all( |
| is_strided_tensor(result) for result in results |
| ) |
| |
| # Only test backward if the results are strided tensors |
| if all_results_strided: |
| output_grads_raw = [ |
| torch.ones(r.shape, device=r.device, dtype=r.dtype) |
| for r in results |
| ] |
| output_grads_copy = [] |
| output_grads = [] |
| |
| # Convert output grads to COW tensors and make copies |
| for output_grad in output_grads_raw: |
| output_grads_copy.append(output_grad.clone().detach()) |
| output_grads.append(torch._lazy_clone(output_grad)) |
| |
| input_grads = torch.autograd.grad( |
| results, |
| leaf_tensors, |
| output_grads, |
| allow_unused=True, |
| retain_graph=True, |
| ) |
| |
| # Check that COW inputs remain COW after the backward op is executed |
| for idx, arg in enumerate(args): |
| check_cow_input( |
| arg, |
| args_copy[idx], |
| idx, |
| backward_or_forward="backward", |
| supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward, |
| allow_list=op.allow_cow_input_materialize_backward, |
| ) |
| |
| # Check that COW inputs remain COW after the backward op is executed |
| for idx, output_grad in enumerate(output_grads): |
| check_cow_input( |
| output_grad, |
| output_grads_copy[idx], |
| f"output grad {idx}", |
| backward_or_forward="backward", |
| supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward, |
| allow_list=op.allow_cow_input_materialize_backward, |
| ) |
| |
| @ops(op_db, allowed_dtypes=(torch.float,)) |
| def test_view_replay(self, device, dtype, op): |
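| # For each output that is a view of the input, autograd's recorded view functions should be |
| # replayable: out._view_func_unsafe(new_inp) must reconstruct an equivalent view of new_inp, |
| # and out._rev_view_func_unsafe(new_out) must reconstruct the input's metadata as a view of new_out. |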
| def _assert_match_metadata(a, b): |
| self.assertEqual(a.size(), b.size()) |
| self.assertEqual(a.stride(), b.stride()) |
| self.assertEqual(a.storage_offset(), b.storage_offset()) |
| self.assertEqual(a.device, b.device) |
| self.assertEqual(a.dtype, b.dtype) |
| |
| # ensure view replay is enabled |
| with torch.autograd._force_original_view_tracking(True): |
| for sample in op.sample_inputs(device, dtype, requires_grad=False): |
| inp = sample.input |
| outs = op(inp, *sample.args, **sample.kwargs) |
| if not isinstance(outs, (tuple, list)): |
| outs = [outs] |
| |
| # for all outputs that are views of the input, we should be able to replay the |
| # forward and reverse views via a functioning view_func() / rev_view_func(). |
| for out in outs: |
| if not ( |
| isinstance(out, torch.Tensor) |
| and out._is_view() |
| and out._base is inp |
| ): |
| continue |
| |
| # forward view_func |
| new_inp = inp.clone() |
| _assert_match_metadata(new_inp, inp) |
| new_out = out._view_func_unsafe(new_inp) |
| _assert_match_metadata(new_out, out) |
| self.assertEqual(new_out, out) |
| |
| # reverse view_func |
| new_out = out.detach() |
| new_inp = out._rev_view_func_unsafe(new_out) |
| _assert_match_metadata(new_inp, inp) |
| self.assertTrue(new_inp._is_view()) |
| self.assertTrue(new_inp._base is new_out) |
| |
| |
| @unMarkDynamoStrictTest |
| class TestMathBits(TestCase): |
| # Tests that |
| # 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors |
| # produces the same value |
| # 2. The gradients are same in both cases mentioned in (1) |
| # 3. If the operator's inplace variant is supported, tests that the inplace operation |
| # produces the correct value when called on a conjugate/negative view tensor and that the output |
| # has its conj/neg bit set to true |
| # This test only runs for C -> R and C -> C functions |
| # TODO: add tests for `R->C` functions |
| # Note: This test runs for functions that take both tensors and tensorlists as input. |
| def _test_math_view( |
| self, |
| device, |
| dtype, |
| op, |
| samples, |
| math_op_physical, |
| math_op_view, |
| is_bit_set, |
| out_type, |
| ): |
| inplace_variant = op.inplace_variant |
| |
| # Helper function that clones and conjugates/negates the input if it's a tensor, |
| # and otherwise clones the sequence and conjugates/negates its first element. |
| # If a requires_grad argument is provided, the tensor being conjugated/negated will |
| # have its requires_grad set to that value. |
| def clone_and_perform_view(input, **kwargs): |
| if isinstance(input, torch.Tensor): |
| requires_grad = kwargs.get("requires_grad", input.requires_grad) |
| with torch.no_grad(): |
| # Ensure view represents the original sample input |
| input = math_op_physical(input) |
| # Note: the conj/neg view is not taken under no_grad mode since a view |
| # created in no_grad mode is not allowed to be modified. As a workaround, the view is |
| # created here, before the requires_grad field of the input is reset. |
| input = math_op_view(input) |
| assert input.is_leaf |
| return input.requires_grad_(requires_grad) |
| |
| if isinstance(input, Sequence): |
| out = list(map(clone_input_helper, input)) |
| out[0] = clone_and_perform_view(out[0]) |
| return tuple(out) |
| |
| for sample in samples: |
| tensor = ( |
| sample.input |
| if isinstance(sample.input, torch.Tensor) |
| else sample.input[0] |
| ) |
| cloned1 = clone_and_perform_view(sample.input) |
| |
| # Computes the function's forward value with a physically conjugated/negated tensor and |
| # with a conj/neg view tensor, and verifies that the outputs in both cases are equal. |
| expected_forward = op(sample.input, *sample.args, **sample.kwargs) |
| forward_with_mathview = op(cloned1, *sample.args, **sample.kwargs) |
| self.assertEqual(expected_forward, forward_with_mathview) |
| |
| # If the op has an inplace variant, and the input doesn't require broadcasting |
| # and has the same dtype as output, verify that the inplace operation on a conjugated/negated |
| # input produces correct output, and the output tensor has the conj/neg bit set to True |
| if inplace_variant is not None and not sample.broadcasts_input: |
| cloned2 = clone_and_perform_view(tensor, requires_grad=False) |
| if ( |
| isinstance(expected_forward, torch.Tensor) |
| and expected_forward.dtype is tensor.dtype |
| ): |
| inplace_forward = inplace_variant( |
| cloned2, *sample.args, **sample.kwargs |
| ) |
| self.assertTrue(is_bit_set(inplace_forward)) |
| self.assertEqual(inplace_forward, expected_forward) |
| |
| # TODO: backward consistency only supported for single tensor outputs |
| # TODO: backward consistency only checked on sample.input, not all |
| # tensor inputs |
| # TODO: update to handle checking grads of all tensor inputs as |
| # derived from each tensor output |
| if ( |
| isinstance(expected_forward, torch.Tensor) |
| and expected_forward.requires_grad |
| ): |
| output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x) |
| expected_forward = output_process_fn_grad(expected_forward) |
| forward_with_mathview = output_process_fn_grad(forward_with_mathview) |
| |
| tensor = ( |
| sample.input |
| if isinstance(sample.input, torch.Tensor) |
| else sample.input[0] |
| ) |
| expected_forward.sum().abs().backward(retain_graph=True) |
| forward_with_mathview.sum().abs().backward(retain_graph=True) |
| if tensor.grad is not None: |
| cloned1_tensor = ( |
| cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0] |
| ) |
| self.assertEqual(tensor.grad, cloned1_tensor.grad) |
| |
| tensor.grad, cloned1_tensor.grad = None, None |
| |
| # a repeat of the backward check above with an explicit grad output, run only when the output satisfies out_type |
| if out_type(expected_forward): |
| grad = torch.randn_like(expected_forward) |
| expected_forward.backward(grad) |
| forward_with_mathview.backward( |
| math_op_view(math_op_physical(grad)) |
| ) |
| |
| self.assertEqual(tensor.grad, cloned1_tensor.grad) |
| |
| @ops(ops_and_refs, allowed_dtypes=(torch.cfloat,)) |
| def test_conj_view(self, device, dtype, op): |
| if not op.test_conjugated_samples: |
| self.skipTest("Operation doesn't support conjugated inputs.") |
| math_op_physical = torch.conj_physical |
| math_op_view = torch.conj |
| _requires_grad = torch.cfloat in op.supported_backward_dtypes( |
| torch.device(device).type |
| ) |
| is_bit_set = torch.is_conj |
| samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad) |
| self._test_math_view( |
| device, |
| dtype, |
| op, |
| samples, |
| math_op_physical, |
| math_op_view, |
| is_bit_set, |
| torch.is_complex, |
| ) |
| |
| @ops(ops_and_refs, allowed_dtypes=(torch.double,)) |
| def test_neg_view(self, device, dtype, op): |
| if not op.test_neg_view: |
| self.skipTest("Operation not tested with tensors with negative bit.") |
| math_op_physical = torch.neg |
| math_op_view = torch._neg_view |
| is_bit_set = torch.is_neg |
| samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd) |
| self._test_math_view( |
| device, |
| dtype, |
| op, |
| samples, |
| math_op_physical, |
| math_op_view, |
| is_bit_set, |
| lambda x: True, |
| ) |
| |
| @ops(ops_and_refs, allowed_dtypes=(torch.cdouble,)) |
| def test_neg_conj_view(self, device, dtype, op): |
| if not op.test_neg_view: |
| self.skipTest("Operation not tested with tensors with negative bit.") |
| if not op.test_conjugated_samples: |
| self.skipTest("Operation doesn't support conjugated inputs.") |
| |
| def math_op_physical(x): |
| return -x.conj_physical() |
| |
| def math_op_view(x): |
| return torch._neg_view(x).conj() |
| |
| def is_bit_set(x): |
| return torch.is_neg(x) and torch.is_conj(x) |
| |
| _requires_grad = dtype in op.supported_backward_dtypes( |
| torch.device(device).type |
| ) |
| samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad) |
| # Only test one sample |
| samples = itertools.islice(samples, 1) |
| self._test_math_view( |
| device, |
| dtype, |
| op, |
| samples, |
| math_op_physical, |
| math_op_view, |
| is_bit_set, |
| torch.is_complex, |
| ) |
| |
| |
| # The input's size and strides may have been altered by an inplace op. |
| def check_inplace_view(func, input, rs, input_size, input_strides): |
| if func is None: |
| return |
| # TODO: extend this test to cover ops with multiple outputs and ops like native_batch_norm(_legit).out, |
| # which do not necessarily mutate the first input. |
| if isinstance(rs, torch.Tensor) and rs is input: |
| unequal_size = rs.size() != input_size |
| unequal_strides = rs.stride() != input_strides |
| # resize_ should probably have inplace_view tag. Not adding the tag since it |
| # breaks some codegen logic |
| if unequal_size or unequal_strides: |
| if isinstance(func, torch._ops.OpOverloadPacket): |
| func = func.default |
| # Reference: https://github.com/pytorch/pytorch/issues/78759 |
| if func is not torch.ops.aten.resize_.default: |
| # TODO: use self.assertIn when we have separate tests for each tag |
| assert torch.Tag.inplace_view in func.tags |
| |
| |
| # A mode that, when enabled, runs correctness checks to ensure |
| # that operators have the expected tags based on their input and |
| # output tensor properties. |
| class TestTagsMode(TorchDispatchMode): |
| def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
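| # Record the first tensor argument's size and strides before dispatch so that |
| # check_inplace_view can detect whether the op changed them in place. |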
| if isinstance(args[0], torch.Tensor): |
| old_size = args[0].size() |
| old_stride = args[0].stride() |
| rs = func(*args, **kwargs) |
| check_inplace_view(func, args[0], rs, old_size, old_stride) |
| else: |
| rs = func(*args, **kwargs) |
| return rs |
| |
| |
| # Tests that verify the correctness of the tags in `tags.yaml`, which are also accessible through `torch.Tag` |
| @unMarkDynamoStrictTest |
| class TestTags(TestCase): |
| @onlyCPU |
| @ops(ops_and_refs, dtypes=OpDTypes.any_one) |
| def test_tags(self, device, dtype, op): |
| samples = op.sample_inputs(device, dtype, requires_grad=False) |
| for sample in samples: |
| # TODO: Test tags for ops that return a list of tensors |
| input = sample.input |
| if isinstance(input, torch.Tensor): |
| old_size = input.size() |
| old_stride = input.stride() |
| with TestTagsMode(): |
| rs = op(input, *sample.args, **sample.kwargs) |
| # TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761 |
| aten_name = op.aten_name if op.aten_name is not None else op.name |
| opoverloadpacket = getattr(torch.ops.aten, aten_name, None) |
| check_inplace_view(opoverloadpacket, input, rs, old_size, old_stride) |
| |
| |
| class TestSelfKwarg(TestCase): |
| def test_self_kwargs(self): |
| """Verify that we can call the aten ops with all kwargs even if the |
| argument's name is "self" |
| """ |
| torch.ops.aten.reshape.default(self=torch.rand(1, 2), shape=[2]) |
| torch.ops.aten.min.default(self=torch.rand(100)) |
| |
| |
| @unMarkDynamoStrictTest |
| class TestRefsOpsInfo(TestCase): |
| import_paths = [ |
| "_refs", |
| "_refs.special", |
| "_refs.nn.functional", |
| "_refs.fft", |
| "_refs._conversions", |
| ] |
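| # Collect the fully qualified name ("<path>.<op>") of every public symbol exported by the |
| # torch._refs submodules listed above. |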
| module_alls = [ |
| (path, import_module(f"torch.{path}").__all__) for path in import_paths |
| ] |
| ref_ops_names = tuple( |
| itertools.chain.from_iterable( |
| [f"{path}.{op}" for op in module_all] for path, module_all in module_alls |
| ) |
| ) |
| ref_db_names = {ref_op.name for ref_op in python_ref_db} |
| |
| # TODO: References that do not have an entry in python_ref_db |
| skip_ref_ops = { |
| "_refs.alias", |
| "_refs.bitwise_right_shift", |
| "_refs.copy_to", |
| "_refs.empty_permuted", |
| "_refs.empty_strided", |
| "_refs.equal", |
| "_refs.full", |
| "_refs.full_like", |
| "_refs.is_complex", |
| "_refs.to", |
| "_refs.mvlgamma", |
| "_refs.ones", |
| "_refs.ones_like", |
| "_refs.special.expit", |
| "_refs.std_var", |
| "_refs.swap_axes", |
| "_refs.uniform", |
| "_refs.scalar_tensor", |
| "_refs.trunc_divide", |
| "_refs.zero", |
| "_refs.zeros", |
| "_refs.zeros_like", |
| "_refs.rfloordiv", |
| "_refs.rtruediv", |
| "_refs.rpow", |
| # These should be tested with their out-of-place counterparts |
| "_refs.index_add_", |
| "_refs.index_copy_", |
| "_refs.index_fill_", |
| "_refs.native_group_norm", |
| } |
| |
| not_in_decomp_table = { |
| # duplicated in _decomp and _refs |
| "_refs.nn.functional.group_norm", |
| "_refs.nn.functional.mse_loss", |
| "_refs.floor_divide", |
| # duplicated as refs do not have decent support for advanced indexing |
| "_refs.index_copy", |
| "_refs.index_copy_", |
| "_refs.index_add", |
| "_refs.index_add_", |
| # these are not aten ops? |
| "_refs._conversions.bfloat16", |
| "_refs._conversions.bool", |
| "_refs._conversions.byte", |
| "_refs._conversions.char", |
| "_refs._conversions.double", |
| "_refs._conversions.float", |
| "_refs._conversions.half", |
| "_refs._conversions.int", |
| "_refs._conversions.long", |
| "_refs._conversions.short", |
| "_refs._conversions.chalf", |
| "_refs._conversions.cfloat", |
| "_refs._conversions.cdouble", |
| "_refs.broadcast_shapes", |
| "_refs.broadcast_tensors", |
| "_refs.mvlgamma", |
| "_refs.nn.functional.layer_norm", |
| "_refs.nn.functional.tanhshrink", |
| "_refs.nn.functional.triplet_margin_loss", |
| "_refs.rfloordiv", |
| "_refs.rtruediv", |
| "_refs.rpow", |
| # CompositeImplicitAutograd |
| "_refs.allclose", |
| "_refs.atleast_1d", |
| "_refs.atleast_2d", |
| "_refs.atleast_3d", |
| "_refs.broadcast_to", |
| "_refs.chunk", |
| "_refs.column_stack", |
| "_refs.contiguous", |
| "_refs.dsplit", |
| "_refs.dstack", |
| "_refs.fill", |
| "_refs.fill_", |
| "_refs.flatten", |
| "_refs.fliplr", |
| "_refs.flipud", |
| "_refs.float_power", |
| "_refs.hsplit", |
| "_refs.hstack", |
| "_refs.isclose", |
| "_refs.isfinite", |
| "_refs.isreal", |
| "_refs.istft", |
| "_refs.log_softmax", |
| "_refs.movedim", |
| "_refs.narrow", |
| "_refs.nn.functional.dropout", |
| "_refs.nn.functional.l1_loss", |
| "_refs.nn.functional.smooth_l1_loss", |
| "_refs.nn.functional.log_softmax", |
| "_refs.nn.functional.poisson_nll_loss", |
| "_refs.nn.functional.softmax", |
| "_refs.nn.functional.softmin", |
| "_refs.positive", |
| "_refs.ravel", |
| "_refs.reshape", |
| "_refs.softmax", |
| "_refs.special.expit", |
| "_refs.special.log_softmax", |
| "_refs.special.softmax", |
| "_refs.square", |
| "_refs.stft", |
| "_refs.T", |
| "_refs.take_along_dim", |
| "_refs.tensor_split", |
| "_refs.to", |
| "_refs.true_divide", |
| "_refs.trunc_divide", |
| "_refs.vsplit", |
| "_refs.vstack", |
| "_refs.linalg.matrix_norm", |
| "_refs.linalg.norm", |
| "_refs.linalg.svd", |
| "_refs.linalg.svdvals", |
| "_refs.unflatten", |
| "_refs.sum_to_size", |
| # ref implementation missing kwargs |
| "_refs.full_like", # missing "layout" |
| "_refs.scalar_tensor", # missing "layout" |
| # other |
| "_refs.block_diag", # only refs._block_diag_iterable is in decomposition table |
| "_refs.empty", # intentional; direct empty is faster and has less guards |
| "_refs.empty_permuted", # intentional; direct empty is faster and has less guards |
| "_refs.expand_as", |
| "_refs.as_strided", # _prims._as_strided_meta: "reduce() of empty sequence with no initial value" |
| "_refs.copy_to", # torch._C._jit_get_operation: No such operator aten::copy_to |
| "_refs.equal", # 'bool' object has no attribute 'dtype' |
| "_refs.conj", # Calls _prims.conj |
| "_refs.real", |
| "_refs.imag", |
| "_refs.reshape_as", |
| "_refs.view_as", |
| "_refs.view_as_complex", # TorchInductor does not support complex at the moment. |
| # the decompositions for these ops are slightly different |
| # because of out handling |
| "_refs.var_mean", |
| "_refs.std_mean", |
| "_refs.native_layer_norm", |
| } |
| |
| @parametrize("op", ref_ops_names) |
| def test_refs_are_in_python_ref_db(self, op): |
| inplace = op[-1] == "_" |
| if op in self.skip_ref_ops: |
| raise unittest.SkipTest(f"{op} does not have an entry in python_ref_db") |
| elif inplace: |
| self.assertNotIn( |
| op, |
| self.ref_db_names, |
| msg=f"{op} is an in-place operation and should not have an OpInfo", |
| ) |
| else: |
| # Intentionally don't use assertIn to avoid printing the |
| # (very large) container |
| self.assertTrue(op in self.ref_db_names, msg=f"{op} not in ref_db_names") |
| |
| @parametrize("op", ref_ops_names) |
| def test_refs_are_in_decomp_table(self, op): |
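| # Resolve the ref's qualified name to its callable and check whether it is registered in |
| # torch._decomp.decomposition_table, as expected from not_in_decomp_table above. |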
| path = op.split(".") |
| module_path = ".".join(path[:-1]) |
| op_name = path[-1] |
| op_impl = getattr(import_module(f"torch.{module_path}"), op_name) |
| |
| if op in self.not_in_decomp_table: |
| self.assertNotIn( |
| op_impl, |
| torch._decomp.decomposition_table.values(), |
| f"Unexpectedly found {op} in torch._decomp.decomposition_table.values()", |
| ) |
| else: |
| self.assertIn( |
| op_impl, |
| torch._decomp.decomposition_table.values(), |
| f"Did not find {op} in torch._decomp.decomposition_table.values()", |
| ) |
| |
| |
| fake_skips = ( |
| "aminmax", # failing input |
| "cov", # aweights cannot be negtaive |
| "istft", # window overlap add min: 0 |
| "linalg.eigvals", # The tensor has a non-zero number of elements, but its data is not allocated yet |
| "linalg.eigvalsh", # aten::linalg_eigvalsh.out' with arguments from the 'Meta' backend |
| "linalg.matrix_power", # Could not run 'aten::eye.m_out' with arguments from the 'Meta' backend |
| # "linalg.pinv", # Could not run 'aten::pinv.out' with arguments from the 'Meta' backen |
| "linalg.matrix_rank.hermitian", # Could not run 'aten::linalg_eigvalsh.out' with arguments from the 'Meta' backend |
| "linalg.pinv.hermitian", # tensor.mH is only supported on matrices or batches of matrices. Got 1-D tensor |
| "linalg.solve", # Could not run 'aten::linalg_solve' with arguments from the 'Meta' backend |
| "linalg.tensorsolve", # Could not run 'aten::linalg_solve' with arguments from the 'Meta' |
| "lu_solve", # MALLOC ERROR: debug |
| "multinomial", # Could not run 'aten::multinomial' with arguments from the 'Meta' backend |
| "mvlgamma.mvlgamma_p_1", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend |
| "mvlgamma.mvlgamma_p_3", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend |
| "mvlgamma.mvlgamma_p_5", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend |
| "nanmean", # logical_not() got an unexpected keyword argument 'out' |
| "quantile", # quantile() q values must be in the range [0, 1] |
| "nanquantile", # quantile() q values must be in the range [0, 1] |
| "nn.functional.ctc_loss", # The tensor has a non-zero number of elements, but its data is not allocated yet |
| "nn.functional.embedding_bag", # sometimes errors |
| "nn.functional.nll_loss", # sometimes errors |
| "nn.functional.max_pool1d", # The tensor has a non-zero number of elements |
| "to_sparse", # Could not run 'aten::_to_sparse' with arguments from the 'Meta' backend |
| "tensor_split", # The tensor has a non-zero number of elements, but its data is not allocated yet |
| "repeat_interleave", # cannot repeat_interleave a meta tensor without output_size |
| "sparse.sampled.addmm", # sparsity not supported |
| # Cannot infer the total number of classes from meta; there is no way at present to throw DynamicOutputShapeException |
| "nn.functional.one_hot", |
| "narrow", # Fails only for one overload with DataDependentOutputException (hence skip). |
| ) |
| |
| fake_autocast_device_skips = defaultdict(dict) |
| |
| # TODO: investigate/fix |
| fake_autocast_device_skips["cpu"] = {"linalg.pinv"} |
| fake_autocast_device_skips["cuda"] = {"linalg.pinv", "pinverse"} |
| |
| |
| dynamic_output_op_tests = ( |
| "argwhere", |
| "bincount", |
| "combinations", |
| "linalg.lstsq", |
| "masked_select", |
| "nonzero", |
| "unique_consecutive", |
| "unique", |
| "linalg.lstsq.grad_oriented", |
| ) |
| |
| # Ops that have dynamic output shapes that we can handle when |
| # allow_dynamic_shape_ops is True in fake tensor shape environment. |
| supported_dynamic_output_op_tests = ( |
| "nonzero", |
| "unique", |
| "repeat_interleave", |
| "masked_select", |
| ) |
| |
| # some inputs invoke dynamic output shape operators, some do not |
| sometimes_dynamic_output_op_test = ("__getitem__", "index_select") |
| |
| data_dependent_op_tests = ( |
| "equal", |
| "corrcoef", |
| "nn.functional.gaussian_nll_loss", |
| "allclose", |
| ) |
| |
| aliasing_failures = ("histogramdd",) |
| |
| fake_backward_skips = { |
| "linalg.cond", |
| "linalg.matrix_norm", |
| "linalg.norm", |
| "linalg.svd", |
| "linalg.svdvals", |
| "pca_lowrank", |
| "roll", |
| "svd_lowrank", |
| "sgn", |
| } |
| |
| fake_backward_xfails = {skip(s) for s in fake_backward_skips} | { |
| xfail("fft.ihfftn"), # Mismatch in aten._conj_physical.default |
| xfail("fft.ihfft2"), # Mismatch in aten._conj_physical.default |
| skip("nn.functional.ctc_loss"), |
| } |
| |
| fake_autocast_backward_xfails = { |
| skip("nn.functional.binary_cross_entropy"), |
| skip("sparse.sampled_addmm"), |
| skip("linalg.pinv"), |
| skip("linalg.pinv", "hermitian"), |
| skip("linalg.pinv", "singular"), |
| skip("pinverse"), |
| } |
| |
| |
| @unMarkDynamoStrictTest |
| class TestFakeTensor(TestCase): |
| def setUp(self): |
| # Turn on FakeTensor caching and cross-checking for these tests: |
| cache_enabled = unittest.mock.patch( |
| "torch._dynamo.config.fake_tensor_cache_enabled", True |
| ) |
| cache_enabled.start() |
| self.addCleanup(cache_enabled.stop) |
| |
| cache_crosscheck = unittest.mock.patch( |
| "torch._dynamo.config.fake_tensor_cache_crosscheck_enabled", True |
| ) |
| cache_crosscheck.start() |
| self.addCleanup(cache_crosscheck.stop) |
| |
| def _test_fake_helper(self, device, dtype, op, context): |
| name = op.name |
| if op.variant_test_name: |
| name += "." + op.variant_test_name |
| if name in fake_skips or "sparse" in name or "jiterator" in name: |
| self.skipTest("Skip failing test") |
| |
| samples = op.sample_inputs(device, dtype, requires_grad=False) |
| for sample in samples: |
| mode = FakeTensorMode() |
| |
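| # In addition to the default FakeTensorMode above, build a second mode backed by a ShapeEnv |
| # that allows dynamic-output-shape ops; it is exercised below for ops listed in |
| # supported_dynamic_output_op_tests. |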
| from torch.fx.experimental.symbolic_shapes import ShapeEnv |
| |
| allow_dynamic_output_shape_shape_env = ShapeEnv( |
| allow_dynamic_output_shape_ops=True |
| ) |
| |
| allow_dynamic_output_shape_mode = FakeTensorMode( |
| shape_env=allow_dynamic_output_shape_shape_env |
| ) |
| |
| try: |
| with context(): |
| res = op(sample.input, *sample.args, **sample.kwargs) |
| except Exception: |
| continue |
| |
| def run_with_fake_mode_and_verify(fake_mode, match_results=True): |
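| # Maps the sample's tensors to fake tensors, runs the op under the given FakeTensorMode, |
| # and (when match_results is True) compares the fake outputs against the real results |
| # computed above; fake-tensor exceptions are only tolerated for ops in the lists above. |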
| def map_to_fake(e): |
| if isinstance(e, torch.Tensor): |
| return fake_mode.from_tensor(e) |
| else: |
| return e |
| |
| input = tree_map(map_to_fake, sample.input) |
| args = tree_map(map_to_fake, sample.args) |
| kwargs = tree_map(map_to_fake, sample.kwargs) |
| |
| try: |
| with context(): |
| with fake_mode: |
| res_fake = op(input, *args, **kwargs) |
| |
| if not match_results: |
| return |
| |
| for fake_out, real_out in zip( |
| pytree.tree_leaves(res_fake), pytree.tree_leaves(res) |
| ): |
| if not isinstance(fake_out, torch.Tensor): |
| self.assertTrue(not isinstance(real_out, torch.Tensor)) |
| self.assertEqual(fake_out, real_out) |
| continue |
| |
| self.assertTrue(isinstance(fake_out, FakeTensor)) |
| # if you see a shape exception here, you may need to add |
| # a `dynamic_output_shape` tag to an operator |
| |
| if op.op not in [ |
| torch.ops.aten._efficient_attention_forward, |
| torch.ops.aten._flash_attention_forward, |
| ]: |
| # prims/decomps must correctly model strides, |
| # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325 |
| |
| # note: the excluded ops have intentionally incorrect device; |
| # see "Note [Seed and Offset]" (_meta_registrations.py) |
| prims.utils.compare_tensor_meta(fake_out, real_out, True) |
| |
| if name not in aliasing_failures: |
| fake_aliasing = outputs_alias_inputs( |
| (input, args, kwargs), res_fake |
| ) |
| real_aliasing = outputs_alias_inputs( |
| (sample.input, sample.args, sample.kwargs), res |
| ) |
| self.assertEqual(fake_aliasing, real_aliasing) |
| |
| self.assertTrue( |
| name not in dynamic_output_op_tests |
| and name not in data_dependent_op_tests |
| ) |
| |
| except torch._subclasses.fake_tensor.UnsupportedFakeTensorException: |
| pass |
| except torch._subclasses.fake_tensor.UnsupportedOperatorException: |
| pass |
| except torch._subclasses.fake_tensor.DynamicOutputShapeException: |
| self.assertTrue( |
| name in dynamic_output_op_tests |
| or name in sometimes_dynamic_output_op_test |
| ) |
| self.assertTrue( |
| fake_mode.shape_env is None |
| or not fake_mode.shape_env.allow_dynamic_output_shape_ops |
| or name not in supported_dynamic_output_op_tests |
| ) |
| except torch._subclasses.fake_tensor.DataDependentOutputException: |
| self.assertTrue(name in data_dependent_op_tests) |
| |
| run_with_fake_mode_and_verify(mode) |
| if name in supported_dynamic_output_op_tests: |
| run_with_fake_mode_and_verify( |
| allow_dynamic_output_shape_mode, match_results=False |
| ) |
| |
| @ops(op_db, dtypes=OpDTypes.any_one) |
| def test_pointwise_ops(self, device, dtype, op): |
| name = op.name |
| if op.variant_test_name: |
| name += "." + op.variant_test_name |
| if name in fake_skips or "sparse" in name or "jiterator" in name: |
| self.skipTest("Skip failing test") |
| |
| test_self = self |
| |
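| # A dispatch mode that, for every op tagged torch.Tag.pointwise, checks that each tensor |
| # output has the shape obtained by broadcasting the shapes of all tensor inputs. |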
| class TestPointwiseMode(TorchDispatchMode): |
| def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
| kwargs = kwargs or {} |
| |
| out = func(*args, **kwargs) |
| |
| if torch.Tag.pointwise in func.tags: |
| shapes = [] |
| for inp in pytree.arg_tree_leaves(*args, **kwargs): |
| if isinstance(inp, torch.Tensor): |
| shapes.append(inp.shape) |
| |
| out_shape = torch._refs._broadcast_shapes(*shapes) |
| |
| for out_elem in pytree.tree_leaves(out): |
| if isinstance(out_elem, torch.Tensor): |
| test_self.assertEqual(out_elem.shape, out_shape) |
| |
| return out |
| |
| samples = op.sample_inputs(device, dtype, requires_grad=False) |
| for sample in samples: |
| mode = FakeTensorMode() |
| |
| def map_to_fake(e): |
| if isinstance(e, torch.Tensor): |
| return mode.from_tensor(e) |
| else: |
| return e |
| |
| input = tree_map(map_to_fake, sample.input) |
| args = tree_map(map_to_fake, sample.args) |
| kwargs = tree_map(map_to_fake, sample.kwargs) |
| |
| try: |
| op(input, *args, **kwargs) |
| except Exception: |
| continue |
| |
| with TestPointwiseMode(): |
| with mode: |
| op(input, *args, **kwargs) |
| |
| @ops(op_db, dtypes=OpDTypes.any_one) |
| def test_fake(self, device, dtype, op): |
| self._test_fake_helper(device, dtype, op, contextlib.nullcontext) |
| |
| @ops(op_db, dtypes=OpDTypes.any_one) |
| def test_fake_autocast(self, device, dtype, op): |
| device_type = torch.device(device).type |
| if op.name in fake_autocast_device_skips[device_type]: |
| self.skipTest("Skip failing test") |
| |
| def context_fn(): |
| return torch.amp.autocast(device_type) |
| |
| self._test_fake_helper(device, dtype, op, context_fn) |
| |
| def _test_fake_crossref_helper(self, device, dtype, op, context): |
| samples = op.sample_inputs(device, dtype, requires_grad=True) |
| |
| for sample in samples: |
| args = [sample.input] + list(sample.args) |
| kwargs = sample.kwargs |
| |
| # skip these to speed up tests |
| common_skip_ops = ( |
| aten.detach.default, |
| aten.empty_strided.default, |
| aten.copy_.default, |
| aten.is_same_size.default, |
| ) |
| |
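| # CrossRefFakeMode re-runs each dispatched op on fake tensors and cross-checks the results |
| # against the real execution (skipping common_skip_ops) while the expected grads are computed. |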
| # TODO: enable check_aliasing, batch norm fails |
| try: |
| with torch._subclasses.CrossRefFakeMode( |
| ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True |
| ): |
| with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled( |
| False |
| ): |
| composite_compliance.compute_expected_grads( |
| op.get_op(), |
| args, |
| kwargs, |
| sample.output_process_fn_grad, |
| op.gradcheck_wrapper, |
| ) |
| except torch._subclasses.fake_tensor.UnsupportedOperatorException: |
| pass |
| |
| @onlyCUDA |
| @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) |
| @skipOps( |
| "TestFakeTensor", "test_fake_crossref_backward_no_amp", fake_backward_xfails |
| ) |
| def test_fake_crossref_backward_no_amp(self, device, dtype, op): |
| self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext) |
| |
| @onlyCUDA |
| @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) |
| @skipOps( |
| "TestFakeTensor", |
| "test_fake_crossref_backward_amp", |
| fake_backward_xfails | fake_autocast_backward_xfails, |
| ) |
| def test_fake_crossref_backward_amp(self, device, dtype, op): |
| self._test_fake_crossref_helper(device, dtype, op, torch.cuda.amp.autocast) |
| |
| @ops([op for op in ops_and_refs if op.is_factory_function]) |
| def test_strided_layout(self, device, dtype, op): |
| samples = op.sample_inputs(device, dtype) |
| for sample in samples: |
| kwargs = sample.kwargs.copy() |
| kwargs["layout"] = torch.strided |
| strided_result = op(sample.input, *sample.args, **kwargs) |
| self.assertEqual(strided_result.layout, torch.strided) |
| |
| |
| instantiate_device_type_tests(TestCommon, globals()) |
| instantiate_device_type_tests(TestCompositeCompliance, globals()) |
| instantiate_device_type_tests(TestMathBits, globals()) |
| instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu") |
| instantiate_device_type_tests(TestFakeTensor, globals()) |
| instantiate_device_type_tests(TestTags, globals()) |
| |
| if __name__ == "__main__": |
| TestCase._default_dtype_check_enabled = True |
| run_tests() |