| # Owner(s): ["module: decompositions"] |
| |
| import functools |
| import itertools |
| import re |
| import unittest |
| from collections import defaultdict |
| from functools import partial |
| |
| import torch._inductor.decomposition |
| import torch.autograd |
| from torch import Tensor |
| from torch._decomp import core_aten_decompositions, decomposition_table |
| from torch._dispatch.python import enable_python_dispatcher |
| from torch._ops import DispatchKey |
| from torch.testing import make_tensor |
| from torch.testing._internal.common_cuda import tf32_off |
| from torch.testing._internal.common_device_type import ( |
| instantiate_device_type_tests, |
| onlyCPU, |
| onlyCUDA, |
| onlyNativeDeviceTypes, |
| ops, |
| ) |
| from torch.testing._internal.common_methods_invocations import ( |
| op_db, |
| skip, |
| skipOps, |
| xfail, |
| ) |
| from torch.testing._internal.common_modules import module_db, modules |
| from torch.testing._internal.common_utils import ( |
| is_iterable_of_tensors, |
| run_tests, |
| skipIfCrossRef, |
| skipIfTorchDynamo, |
| suppress_warnings, |
| TEST_WITH_ASAN, |
| TEST_WITH_SLOW, |
| TestCase, |
| unMarkDynamoStrictTest, |
| ) |
| from torch.utils import _pytree as pytree |
| from torch.utils._python_dispatch import TorchDispatchMode |
| from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten |
| |
| |
| aten = torch.ops.aten |
| |
| |
| # TODO: this isn't going to work with non-aten namespaces |
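# e.g. overload_to_aten_name(torch.ops.aten.add.Tensor) == "add"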
| def overload_to_aten_name(op): |
| return op._schema.name.split("::")[1] |
| |
| |
| # All operators that can have decomp tests |
| decomposition_names = { |
| overload_to_aten_name(k) |
| for k in decomposition_table |
| if isinstance(k, torch._ops.OpOverload) |
| } |
| core_decomposition_names = { |
| overload_to_aten_name(k) |
| for k in core_aten_decompositions() |
| if isinstance(k, torch._ops.OpOverload) |
| } |
| _decomp_test_ops = [ |
| op |
| for op in op_db |
| if op.aten_name in decomposition_names |
| or op.aten_backward_name in decomposition_names |
| ] |
| _decomp_test_ops_core_autograd = [ |
| op |
| for op in op_db |
| if op.aten_name in core_decomposition_names and op.supports_autograd |
| ] |
| _sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name] |
| |
| |
| def diff_arg(arg, requires_grad=True): |
| def is_differentiable_arg(arg): |
| if requires_grad: |
| return arg.requires_grad |
| else: |
| return arg.is_floating_point() or arg.is_complex() |
| |
| if is_iterable_of_tensors(arg): |
| if all(is_differentiable_arg(a) for a in arg): |
| return True |
| if all(not is_differentiable_arg(a) for a in arg): |
| return False |
| raise RuntimeError("NYI: The test runner can't handle this") |
| return isinstance(arg, Tensor) and is_differentiable_arg(arg) |
| |
| |
# Version of autograd.grad with some differences:
# - pytree inputs are allowed (but the leaves of the pytree all have
#   to be tensors)
# - if an input is not used in any derivative, we return a
#   zero-filled tensor for the result
| def _autograd_grad( |
| outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True |
| ): |
| inputs, inputs_spec = tree_flatten(inputs) |
| diff_inputs = tuple(inp for inp in inputs if inp.requires_grad) |
| if grad_outputs is None: |
| diff_outputs = tuple(out for out in outputs if out.requires_grad) |
| else: |
| diff_grad_outputs = [ |
| (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad |
| ] |
| if len(diff_grad_outputs) == 0: |
| diff_outputs, grad_outputs = (), () |
| else: |
| diff_outputs, grad_outputs = zip(*diff_grad_outputs) |
| grad_inputs = torch.autograd.grad( |
| diff_outputs, |
| diff_inputs, |
| grad_outputs, |
| retain_graph=retain_graph, |
| create_graph=create_graph, |
| allow_unused=True, |
| ) |
| result = [] |
| grad_inputs_iter = iter(grad_inputs) |
| for inp in inputs: |
| if inp.requires_grad: |
| grad_input = next(grad_inputs_iter) |
| if grad_input is None: |
| result.append(torch.zeros_like(inp)) |
| else: |
| result.append(grad_input) |
| else: |
| result.append(torch.zeros_like(inp)) |
| return tree_unflatten(result, inputs_spec) |
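
# Illustrative sketch of _autograd_grad's zero-fill behavior (hypothetical
# tensors), kept as a string so it doesn't run as part of the test:
"""
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
out = a * 2  # `b` does not contribute to `out`
ga, gb = _autograd_grad((out,), (a, b), (torch.ones_like(out),))
# ga == torch.full_like(a, 2.0), and gb is zero-filled rather than None
"""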
| |
| |
| def _as_tuple(val): |
| if isinstance(val, tuple): |
| return val |
| return (val,) |
| |
| |
| def ref_vjp_no_create(f, *primals): |
| result = f(*primals) |
| |
| def wrapped(cotangents): |
| return _autograd_grad( |
| _as_tuple(result), |
| primals, |
| _as_tuple(cotangents), |
| create_graph=False, |
| retain_graph=True, |
| ) |
| |
| return result, wrapped |
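
# Example wiring of ref_vjp_no_create (hypothetical input `x`), kept as a
# string so it doesn't run at import time:
"""
x = torch.randn(3, requires_grad=True)
result, vjp_fn = ref_vjp_no_create(torch.sin, x)
(grad_x,) = vjp_fn(torch.ones_like(result))
# grad_x == torch.cos(x)
"""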
| |
| |
| dtype_precisions = { |
| torch.float16: (0.001, 1e-5), |
| torch.bfloat16: (0.016, 1e-4), |
| torch.float32: (1.3e-6, 1e-5), |
| torch.float64: (1e-7, 1e-7), |
| torch.complex32: (0.001, 1e-5), |
| torch.complex64: (1.3e-6, 1e-5), |
| torch.complex128: (1e-7, 1e-7), |
| } |


# Returns the "default" rtol and atol for comparing scalars or
# tensors of the given dtypes.
| def _getDefaultRtolAndAtol(dtype0, dtype1): |
| rtol = max( |
| dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0] |
| ) |
| atol = max( |
| dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1] |
| ) |
| return rtol, atol |
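
# e.g. comparing float16 results against float32 picks float16's looser
# tolerances: _getDefaultRtolAndAtol(torch.float16, torch.float32) == (0.001, 1e-5)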
| |
| |
| def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs): |
| assert orig.dtype == decomp.dtype, f"{i} Operation: {op}" |
| if orig.numel() == 0 or decomp.numel() == 0: |
| assert orig.numel() == decomp.numel() |
| return |
| assert orig.shape == decomp.shape, f"{i} Operation: {op}" |
| tol_table = { |
| (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5, |
| (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5, |
| (torch.float16, torch.ops.aten.native_layer_norm_backward.default): 1e-3, |
| (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2, |
| (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5, |
| (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5, |
| (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5, |
| (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5, |
| (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5, |
| (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5, |
| (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-4, |
| (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-4, |
| (torch.bfloat16, torch.ops.aten.var_mean.correction): 5e-7, |
| (torch.float16, torch.ops.aten.var_mean.correction): 5e-7, |
| (torch.bfloat16, torch.ops.aten.var_mean.dim): 5e-7, |
| (torch.float16, torch.ops.aten.var_mean.dim): 5e-7, |
| (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2, |
| (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1, |
| (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2, |
| (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1, |
| (torch.float16, torch.ops.aten.hardswish.default): 2e-7, |
| (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7, |
| (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2, |
| (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 5e-2, |
| (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2, |
| (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2, |
| (torch.float16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3, |
| (torch.bfloat16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3, |
| (torch.float16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3, |
| (torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3, |
| (torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3, |
| (torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2, |
| # see https://github.com/pytorch/pytorch/pull/96264 |
| (torch.float16, torch.ops.aten.mv.default): 1e-5, |
| (torch.bfloat16, torch.ops.aten.mv.default): 1e-5, |
| (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5, |
| (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7, |
| } |
| if ref.is_floating_point(): |
| orig_diff = (orig - ref).abs().max() |
| decomp_diff = (decomp - ref).abs().max() |
| atol = tol_table.get((test_dtype, op), 1e-7) |
| if decomp_diff > orig_diff + atol: |
| raise RuntimeError( |
| f"Difference from float64 is larger with decomposition {op.__name__}" |
| f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n" |
| f"atol = {atol}\n" |
| f"args = {args}\n" |
| f"kwargs = {kwargs}" |
| ) |
| else: |
| test_case.assertEqual( |
| orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}" |
| ) |
| |
| |
| def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs): |
| test_case.assertEqual( |
| orig.dtype, |
| decomp.dtype, |
| f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}", |
| ) |
| # Before adding an entry to this table, make sure your decomposition is right :) |
| tol_table = { |
| # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161 |
| (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3), |
| (torch.float32, torch.ops.aten.native_layer_norm_backward.default): ( |
| 1e-3, |
| 1e-3, |
| ), |
| (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6), |
| # This exceeds default tolerances only on CPU, on CUDA it's fine |
| (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5), |
| # Exceeds tolerances on CUDA, likely due to fma |
| (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5), |
| (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5), |
| (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4), |
| (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4), |
| # The decomposition is TOO correct. It computes everything in int64, so sometimes |
| # there's an off-by-one error. See |
| # https://github.com/pytorch/pytorch/issues/81996 |
| # https://github.com/pytorch/pytorch/issues/82230 |
| (torch.int8, torch.ops.aten.linspace.default): (0, 1), |
| (torch.uint8, torch.ops.aten.linspace.default): (0, 1), |
| (torch.int16, torch.ops.aten.linspace.default): (0, 1), |
| (torch.int32, torch.ops.aten.linspace.default): (0, 1), |
| (torch.int64, torch.ops.aten.linspace.default): (0, 1), |
| (torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), |
| (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), |
| (torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), |
| (torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), |
| (torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), |
| (torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), |
| (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), |
| (torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), |
| (torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), |
| (torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), |
| (torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), |
| (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), |
| (torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), |
| (torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), |
| (torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), |
| } |
| if (decomp.dtype, op) in tol_table: |
| rtol, atol = tol_table[(decomp.dtype, op)] |
| else: |
| rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype) |
| test_case.assertEqual( |
| orig, |
| decomp, |
| rtol=rtol, |
| atol=atol, |
| msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}", |
| ) |
| |
| |
# Given f, returns an f' such that:
# - f' takes only positional arguments
# - All arguments to f' are floating-point or complex Tensors
# - All outputs of f' are floating-point or complex Tensors
| def normalize_op_input_output2( |
| f, args, kwargs, output_process_fn_grad=None, requires_grad=True |
| ): |
| flat_args, args_spec = tree_flatten(args) |
| diff_argnums = tuple( |
| i |
| for i, arg in enumerate(flat_args) |
| if diff_arg(arg, requires_grad=requires_grad) |
| ) |
| assert len(diff_argnums) > 0 |
| primals = tuple(flat_args[i] for i in diff_argnums) |
| |
| @functools.wraps(f) |
| def wrapped(*primals): |
| _args = list(flat_args) |
| for num, arg in zip(diff_argnums, primals): |
| _args[num] = arg |
| _args = tree_unflatten(_args, args_spec) |
| result = f(*_args, **kwargs) |
| if output_process_fn_grad is not None: |
| result = output_process_fn_grad(result) |
| if isinstance(result, tuple): |
| # TODO We should check that the integer outputs also agree |
| result = tuple( |
| r |
| for r in result |
| if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex()) |
| ) |
| assert len(result) > 0 |
| return result |
| |
| return wrapped, primals |
| |
| |
| # NB: This also upcasts dtype arguments |
| # TODO: handle complex correctly |
| def upcast_tensor(x, dtype=torch.float32): |
| if isinstance(x, Tensor) and x.dtype.is_floating_point: |
| return x.to(dtype=dtype) |
| elif isinstance(x, torch.dtype) and x in [ |
| torch.float16, |
| torch.bfloat16, |
| torch.float, |
| ]: |
| return dtype |
| else: |
| return x |
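
# e.g. upcast_tensor(torch.float16, dtype=torch.float64) == torch.float64, and a
# half-precision Tensor argument is likewise converted to float64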
| |
| |
| def normalize_op_input_output(f, sample, requires_grad=True): |
| args = tuple([sample.input] + list(sample.args)) |
| return normalize_op_input_output2( |
| f, |
| args, |
| sample.kwargs, |
| sample.output_process_fn_grad, |
| requires_grad=requires_grad, |
| ) |
| |
| |
| CROSS_REF_EXCLUDE_SET = { |
| # CUBLAS_STATUS_NOT_SUPPORTED when calling |
| # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k, |
| # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF, |
| # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec, |
| # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)` |
| ("cuda", torch.bfloat16, "nn.functional.bilinear"), |
    (None, None, "special.ndtr"),  # aten.special_ndtr was not decomposed
    # randomness
    (None, None, "new_empty"),
    (None, None, "empty_like"),
    (None, None, "empty"),
| # AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default. |
| (None, None, "item"), |
| # It's the only in-place op without an out-of-place equivalent in the Python API |
| # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`. |
| (None, None, "zero_"), |
| # No idea what's going on here |
| # In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), []) |
| # in the test, but it seems to pass when tested locally and in the logsumexp test |
| (None, torch.float32, "masked.logsumexp"), |
| (None, torch.float64, "masked.logsumexp"), |
| # exp_vml_cpu not implemented for Half |
    ("cpu", torch.float16, "signal.windows.exponential"),
    ("cpu", torch.float16, "signal.windows.gaussian"),
    # sin_vml_cpu not implemented for Half
    ("cpu", torch.float16, "signal.windows.cosine"),
| # CompositeAutogradImplicit |
| # See https://github.com/pytorch/pytorch/issues/81669 |
| (None, None, "nn.functional.relu6"), |
| # This decomp runs before autograd. |
| (None, None, "nn.functional.rrelu"), |
| (None, None, "meshgrid"), |
| # Decomposition registered as Autograd |
| (None, None, "nn.functional.hardshrink"), |
| (None, None, "nn.functional.softshrink"), |
| # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit) |
| (None, None, "diag"), |
| # _softmax_backward_data's CPU kernel for bfloat16 always return the grad_input as float32 |
| ("cpu", torch.bfloat16, "_softmax_backward_data"), |
| (None, None, "norm"), |
| # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise) |
| (None, None, "native_batch_norm"), |
| (None, None, "_upsample_bilinear2d_aa"), |
| (None, None, "empty_strided"), # aten.empty_strided was not decomposed |
| } |
| |
| CROSS_REF_BACKWARD_EXCLUDE_SET = { |
| # Decomposed backward formula is not as precise |
| ("cpu", torch.bfloat16, "nn.functional.hardswish"), |
| ("cuda", torch.float16, "nn.functional.cross_entropy"), |
| } |
| |
| all_decomposed = set() |
| all_called = defaultdict(int) |
| |
| # Helpful snippet for testing coverage |
| """ |
| import atexit |
| def check_coverage(): |
| print("missing coverage:") |
| print("\n".join(map(str, decomposition_table.keys() - all_decomposed))) |
| atexit.register(check_coverage) |
| """ |
| |
| # Helpful snippet for Horace to create his google sheet :) |
| """ |
| import atexit |
| def dump_ops(): |
| with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g: |
| for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__): |
| f.write(f'{op.__name__}\n') |
| g.write(f'{count}\n') |
| with open('run_decompositions.txt', 'w') as f: |
| for op in sorted([i.__name__ for i in all_decomposed]): |
| f.write(f'{op}\n') |
| |
| atexit.register(dump_ops) |
| """ |
| |
| |
| def any_unsupported(args, kwargs): |
| def test_unsupported(t): |
| if type(t) is torch.Tensor or type(t) is torch.nn.Parameter: |
| # These are all things that we haven't coded decompositions |
| # to handle correctly. Maybe they should. |
| return any( |
| [ |
| t.is_sparse_csr, |
| t.is_sparse, |
| t.is_mkldnn, |
| t.is_quantized, |
| t.is_nested, |
| torch._is_functional_tensor(t), |
| ] |
| ) |
| elif torch.overrides.is_tensor_like(t): |
| # Decompositions will generally change the behavior of Tensor-like |
| # subclasses, so bypass tests in this case too |
| return True |
| else: |
| return False |
| |
| flat_args = pytree.arg_tree_leaves(*args, **kwargs) |
| return any(test_unsupported(x) for x in flat_args) |
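
# e.g. a sample containing a sparse or nested tensor bypasses cross-ref testing:
#   any_unsupported((torch.randn(2).to_sparse(),), {}) is True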
| |
| |
| core_backward_failures = { |
| skip("_softmax_backward_data"), # slow: fails with --timeout=360 secs |
| xfail("addcdiv"), |
| skip("addcmul"), # slow: fails with --timeout=360 secs |
| skip("deg2rad"), # slow: fails with --timeout=360 secs |
| skip("diag_embed"), # slow: fails with --timeout=360 secs |
| skip("frac"), # slow: fails with --timeout=360 secs |
| skip("grid_sampler_2d"), # slow: fails with --timeout=360 secs |
| xfail("lerp"), |
| skip("logaddexp"), # slow: fails with --timeout=360 secs |
| skip("native_dropout_backward"), # slow: fails with --timeout=360 secs |
| xfail("nn.functional.binary_cross_entropy_with_logits"), |
| skip("nn.functional.glu"), # slow: fails with --timeout=360 secs |
| xfail("nn.functional.hardshrink"), |
| xfail("nn.functional.softshrink"), |
| skip("nn.functional.unfold"), # slow: fails with --timeout=360 secs |
| xfail("norm"), |
| xfail("norm", "fro"), |
| xfail("norm", "inf"), |
| xfail("norm", "nuc"), |
| skip("rad2deg"), # slow: fails with --timeout=360 secs |
| skip("renorm"), # slow: fails with --timeout=360 secs |
| skip("rot90"), # slow: fails with --timeout=360 secs |
| skip("rsub"), # slow: fails with --timeout=360 secs |
| skip("sgn"), # slow: fails with --timeout=360 secs |
| skip("special.xlog1py"), # slow: fails with --timeout=360 secs |
| xfail("stack"), |
| skip("tril"), # slow: fails with --timeout=360 secs |
| skip("triu"), # slow: fails with --timeout=360 secs |
| skip("unfold_copy"), # slow: fails with --timeout=360 secs |
| skip("xlogy"), # slow: fails with --timeout=360 secs |
| xfail("zero_"), |
| } |
| if not TEST_WITH_SLOW: |
| core_backward_failures.update( |
| { |
| skip("addr"), # slow: takes 46 sec on A100 |
| skip("baddbmm"), # slow: takes 800+ sec on A100 |
| skip("clamp_min"), # slow: takes 800 sec on A100 |
| skip("clamp_max"), # slow: takes 800 sec on A100 |
| skip("logit"), # slow: takes 44 sec on A100 |
| skip("nn.functional.hardswish"), # slow: takes 60 sec on A100 |
| skip("std_mean"), # slow: takes 170 sec on A100 |
| skip("split", variant_name="list_args"), # slow: takes 118 sec on A100 |
| skip("transpose"), # slow: takes 50 sec on A100 |
| skip("unbind"), # slow: takes 70 sec on A100 |
| skip("unsafe_split"), # slow: takes 49 sec on A100 |
| } |
| ) |
| |
| comprehensive_failures = { |
| xfail( |
| "nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,) |
| ), # off by one error |
| xfail( |
| "nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,) |
| ), # off by one error |
| xfail( |
| "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,) |
| ), # off by one error |
| } |
| |
| |
| @unMarkDynamoStrictTest |
| class TestDecomp(TestCase): |
| longMessage = True |
| |
| # NB: This actually overlaps with test_comprehensive, but it only |
| # runs on things that are definitely decomposed so it's a lot faster |
| # to run |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @onlyNativeDeviceTypes |
| @skipIfCrossRef |
| @suppress_warnings |
| @ops(_decomp_test_ops) |
| def test_quick(self, device, dtype, op): |
| self.do_cross_ref(device, dtype, op, run_all=False) |
| |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures) |
| @onlyNativeDeviceTypes |
| @skipIfCrossRef |
| @suppress_warnings |
| @ops(_decomp_test_ops_core_autograd, allowed_dtypes=(torch.float64,)) |
| def test_quick_core_backward(self, device, dtype, op): |
| for sample_input in op.sample_inputs(device, dtype, requires_grad=True): |
| aten_name = op.decomp_aten_name or op.aten_name |
| args = [sample_input.input] + list(sample_input.args) |
| kwargs = sample_input.kwargs |
| func = partial(op.get_op(), **kwargs) |
| with self.DecompCrossRefMode( |
| self, self.precision, self.rel_tol, dtype, run_all=False |
| ) as mode, enable_python_dispatcher(): |
| torch.autograd.gradcheck(func, args) |
| self.check_decomposed(aten_name, mode) |
| |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @onlyNativeDeviceTypes |
| @skipIfCrossRef |
| @skipOps("TestDecomp", "test_comprehensive", comprehensive_failures) |
| @suppress_warnings |
| @ops(op_db) |
| def test_comprehensive(self, device, dtype, op): |
| self.do_cross_ref(device, dtype, op, run_all=True) |
| |
| def test_uniform(self, device): |
| size = (2, 3, 4, 5) |
| dtype = torch.float32 |
| x = make_tensor(size, dtype=dtype, device=device) |
| low = 0.3 |
| high = 0.9 |
| |
| torch.manual_seed(123) |
| ref = torch.ops.aten.uniform(x, low, high) |
| torch.manual_seed(123) |
| res = torch._decomp.decompositions.uniform(x, low=low, high=high) |
| self.assertEqual(ref, res) |
| |
| def test_broadcasting_index_copy(self, device): |
| x = torch.zeros([1, 10], device=device) |
| xs = torch.ones([2, 10], device=device) |
| |
| def index_copy(xs, x): |
| torch._decomp.decompositions.index_copy_( |
| xs, 0, torch.tensor(0).to(device), x |
| ) |
| |
| index_copy(xs, x) |
| |
| xs_two = torch.ones([2, 10], device=device) |
| xs_two[0] = x |
| |
| self.assertEqual(xs, xs_two) |
| |
| def test_cat_single_input(self, device): |
| decomp_table = torch._inductor.decomposition.select_decomp_table() |
| cat_inductor = decomp_table[torch.ops.aten.cat.default] |
| |
| inp = torch.rand([2048, 2048], device=device) |
| inps = [inp for _ in range(10)] |
| |
| for dim in (-1, 0, 1): |
| self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim)) |
| |
| def test_rrelu_with_noise(self, device): |
| # rrelu_with_noise behavior depends on a) whether elements in the input |
| # are <= 0, and b) whether we're in training mode. Cover all cases: |
| dtype = torch.float64 |
| x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device) |
| lower = 1.0 |
| upper = 4.0 |
| training = False |
| |
| torch.manual_seed(123) |
| noise_ref = torch.zeros(x.shape, dtype=dtype, device=device) |
| ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training) |
| |
| torch.manual_seed(123) |
| noise_res = torch.zeros(x.shape, dtype=dtype, device=device) |
| res = torch._decomp.decompositions.rrelu_with_noise( |
| x, |
| noise_res, |
| lower, |
| upper, |
| training, |
| ) |
| self.assertEqual(ref, res) |
| self.assertEqual(noise_ref, noise_res) |
| |
| # Now with training=True: |
| training = True |
| |
| torch.manual_seed(123) |
| noise_ref = torch.zeros(x.shape, dtype=dtype, device=device) |
| ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training) |
| |
| torch.manual_seed(123) |
| noise_res = torch.zeros(x.shape, dtype=dtype, device=device) |
| res = torch._decomp.decompositions.rrelu_with_noise( |
| x, |
| noise_res, |
| lower, |
| upper, |
| training, |
| ) |
| self.assertEqual(ref, res) |
| self.assertEqual(noise_ref, noise_res) |
| |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @suppress_warnings |
| @tf32_off() |
    # only tests RNNs since we have py dispatcher decomps for them
| @modules( |
| filter( |
| lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), |
| module_db, |
| ) |
| ) |
| def test_rnn_decomp_module(self, device, dtype, module_info, training): |
| module_cls = module_info.module_cls |
| module_inputs = module_info.module_inputs_func( |
| module_info, |
| device=device, |
| dtype=dtype, |
| requires_grad=True, |
| training=training, |
| ) |
| for module_input in module_inputs: |
| if module_input.forward_input is None: |
| continue |
| args, kwargs = ( |
| module_input.constructor_input.args, |
| module_input.constructor_input.kwargs, |
| ) |
| m = module_cls(*args, **kwargs) |
| m.to(device).to(dtype) |
| |
| args, kwargs = ( |
| module_input.forward_input.args, |
| module_input.forward_input.kwargs, |
| ) |
| with self.DecompCrossRefMode( |
| self, self.precision, self.rel_tol, dtype, run_all=True |
| ), enable_python_dispatcher(): |
| decomp_out = m(*args, **kwargs) |
| |
| non_decomp_out = m(*args, **kwargs) |
| # without this check, incorrect decomps at the python dispatcher level can still pass because |
| # they're checking aten decomps at the torch_dispatch level |
| self.assertEqual(decomp_out, non_decomp_out) |
| |
| def test_batch_norm_unflatten_weight_bias(self, device): |
| # https://github.com/pytorch/pytorch/issues/100970 |
| shape = (1, 3, 2, 2) |
| input = torch.randn(shape, device=device) |
| weight = torch.randn((3, 1, 1, 1), device=device) |
| bias = torch.randn(3, device=device) |
| mean = torch.randn(3, device=device) |
| var = torch.randn(3, device=device) |
| res = torch._decomp.decompositions.native_batch_norm( |
| input, weight, bias, mean, var, False, 1, 1e-05 |
| ) |
| self.assertEqual(shape, res[0].shape) |
| |
| def test_arange_graph(self, device): |
| from torch.fx.experimental.proxy_tensor import make_fx |
| |
| def func(x, start): |
| le = x.shape[-1] |
| if start is None: |
| a = torch.arange(le, dtype=torch.float32, device=x.device) |
| else: |
| a = torch.arange(start, le, dtype=torch.float32, device=x.device) |
| return a |
| |
| pattern = r", device = device\(.+\), requires_grad = False" |
| |
| cfunc = make_fx(func, decomposition_table=decomposition_table) |
| fx_g = cfunc(torch.rand(10, device=device), None) |
| fx_g_code = fx_g.code.strip() |
| # Remove device and requires_grad |
| fx_g_code = re.sub(pattern, "", fx_g_code) |
| self.assertExpectedInline( |
| fx_g_code, |
| """\ |
| def forward(self, x_1, start_1): |
| iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64) |
| mul = torch.ops.prims.mul.default(iota, 1); iota = None |
| add = torch.ops.prims.add.default(mul, 0); mul = None |
| convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None |
| return convert_element_type""", |
| ) |
| |
| fx_g = cfunc(torch.rand(10, device=device), 1) |
| fx_g_code = fx_g.code.strip() |
| # Remove device and requires_grad |
| fx_g_code = re.sub(pattern, "", fx_g_code) |
| self.assertExpectedInline( |
| fx_g_code, |
| """\ |
| def forward(self, x_1, start_1): |
| iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64) |
| mul = torch.ops.prims.mul.default(iota, 1); iota = None |
| add = torch.ops.prims.add.default(mul, 1); mul = None |
| convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None |
| return convert_element_type""", |
| ) |
| |
| def test_masked_fill(self, device): |
| from torch.fx.experimental.proxy_tensor import make_fx |
| |
| if torch.device(device).type not in [ |
| "xpu", |
| "cuda", |
| torch._C._get_privateuse1_backend_name(), |
| ]: |
            self.skipTest("only runs on XPU, CUDA, and PrivateUse1 devices.")
| |
| def func(scores, mask, value): |
| return scores.masked_fill(mask, value) |
| |
| scores_t = torch.tensor([1, 2, 3, 4], device=device) |
| mask_t = torch.tensor([True, True, True, True], device=device) |
| value_t = torch.tensor(0, dtype=scores_t.dtype) |
| cfunc = make_fx(func, decomposition_table=decomposition_table) |
| fx_g = cfunc(scores_t, mask_t, value_t) |
| self.assertExpectedInline( |
| fx_g.code.strip(), |
| """\ |
| def forward(self, scores_1, mask_1, value_1): |
| where = torch.ops.prims.where.default(mask_1, value_1, scores_1); mask_1 = value_1 = scores_1 = None |
| return where""", |
| ) |
| |
| class DecompCrossRefMode(TorchDispatchMode): |
| def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all): |
| self.test_case = test_case |
| self.saved_precision = saved_precision |
| self.saved_rel_tol = saved_rel_tol |
| self.test_dtype = dtype |
| self.run_all = run_all |
| |
            # We check the correctness of each decomposition right after running it.
            # So, when we encounter an op with a decomposition, we run both the
            # original function and its decomposition, and ensure they're identical.
| self.called = set() |
| self.decomposed = set() |
| |
| def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
| self.test_case.precision = self.saved_precision |
| self.test_case.rel_tol = self.saved_rel_tol |
| |
| self.called.add(func) |
| all_called[func] += 1 |
| |
| # Stuff we shouldn't bother testing |
| # (TODO: remove detach from the decomp table?) |
| # N.b. Testing in-place ops would need dedicated logic |
| in_place = func.name()[-1] == "_" |
| ignored_ops = [ |
| torch.ops.aten.detach.default, |
| # non-deterministic ops |
| torch.ops.aten.empty.memory_format, |
| torch.ops.aten.empty_like.default, |
| torch.ops.aten.new_empty.default, |
| torch.ops.aten.empty_strided.default, |
| torch.ops.aten.new_empty_strided.default, |
| torch.ops.aten.randn.default, |
| torch.ops.aten.native_dropout.default, |
| ] |
| if ( |
| func not in decomposition_table |
| or func in ignored_ops |
| or torch.Tag.nondeterministic_seeded in func.tags |
| or any_unsupported(args, kwargs) |
| or in_place |
| ): |
| return func(*args, **kwargs) |
| |
| self.decomposed.add(func) |
| all_decomposed.add(func) |
| |
            # We take two main strategies for verifying the correctness and
            # numerical stability of decompositions.
            # The first is simple tolerance checking between decomp_out and pytorch_out.
            # However, for fp16/bf16 and reductions, this becomes very
            # finicky, as there are not many guarantees we can make.
            # So, for fp16/bf16, we instead compare the differences of
            # {decomp_out, pytorch_out_64} and {pytorch_out,
            # pytorch_out_64}. In other words, we compare how far the
            # decomposition and pytorch are from the "ground truth" (i.e.
            # fp64). If the decomposition results in more error, we raise.

            # We also decompose the decomposition recursively for
            # further coverage, as some paths may not be exercised directly by
            # OpInfos (sadly) but only by other ops.
| |
| decomposition = decomposition_table[func] |
| |
| do_relative_check = self.test_dtype in [torch.float16, torch.bfloat16] |
| if self.run_all: |
| # Execute recursively via DFS, to find the root of a possible error first |
| with self: |
| decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs)) |
| else: |
| decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs)) |
| |
            # At this stage we should not be decomposing an in-place op.
            # We'd like decompositions to turn out-of-place ops into out-of-place ops,
            # because decompositions run after functionalization, and we would not
            # like them to de-functionalize the graph, as that would break AOTAutograd.
            # We run the real function *after* the decomposition to make sure that the
            # decomposition does not modify any of the inputs in-place. If it does,
            # real_out should be different from decomp_out, so we should catch this.
| real_out_unflat = func(*args, **kwargs) |
| real_out = pytree.tree_leaves(real_out_unflat) |
| |
| assert len(real_out) == len(decomp_out) |
| |
| if do_relative_check: |
| upcast = partial(upcast_tensor, dtype=torch.float64) |
| real_out_double, _ = tree_flatten( |
| func(*tree_map(upcast, args), **tree_map(upcast, kwargs)) |
| ) |
| for i, (orig, decomp, ref) in enumerate( |
| zip(real_out, decomp_out, real_out_double) |
| ): |
| if not isinstance(orig, torch.Tensor): |
| assert type(orig) == type(decomp) |
| assert orig == decomp |
| continue |
| op_assert_ref( |
| self.test_case, |
| func, |
| self.test_dtype, |
| i, |
| orig, |
| decomp, |
| ref, |
| args, |
| kwargs, |
| ) |
| else: |
| for orig, decomp in zip(real_out, decomp_out): |
| if not isinstance(orig, torch.Tensor): |
| assert type(orig) == type(decomp) |
| assert orig == decomp |
| continue |
| op_assert_equal( |
| self.test_case, |
| func, |
| self.test_dtype, |
| orig, |
| decomp, |
| args, |
| kwargs, |
| ) |
| |
| return real_out_unflat |
| |
| def check_decomposed(self, aten_name, mode): |
| self.assertTrue( |
| any(overload_to_aten_name(c) == aten_name for c in mode.decomposed), |
| msg=( |
| f"aten.{aten_name} was not decomposed, saw calls for: " |
| f"{', '.join(map(str, list(mode.called)))}. If your op is " |
| f"CompositeImplicitAutograd you should skip this test " |
| f"by updating CROSS_REF_EXCLUDE_SET." |
| ), |
| ) |
| |
| @skipIfTorchDynamo("Test does not work with TorchDynamo") |
| def do_cross_ref(self, device, dtype, op, *, run_all): |
| test_keys = [ |
| (torch.device(device).type, dtype, op.name), |
| (None, dtype, op.name), |
| (None, None, op.name), |
| ] |
| if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys): |
| self.skipTest(f"{op.name} in {dtype} not supported") |
| |
| skip_decomp_vjp = any( |
| key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys |
| ) |
| |
| requires_grad = ( |
| op.supports_autograd |
| and dtype in op.supported_backward_dtypes(torch.device(device).type) |
            # TODO: OpInfo really ought to error out for this case, but it's
            # not exercised in test_ops_gradients atm. The problem is not
            # complex32 per se (which is supported by data-movement-only ops)
            # but that when we do backwards we expect other ops like add to work
| ) |
| samples = op.sample_inputs(device, dtype, requires_grad=requires_grad) |
| |
| aten_name = op.decomp_aten_name or op.aten_name |
| |
| func = op.get_op() |
| |
| def run_without_python_dispatcher(mode): |
| return any( |
| isinstance(op, torch._ops.OpOverload) |
| and op.has_kernel_for_dispatch_key( |
| DispatchKey.CompositeImplicitAutograd |
| ) |
| for op in mode.decomposed.union([func]) |
| ) |
| |
| for sample_input in samples: |
| if requires_grad: |
| fn, primals = normalize_op_input_output(func, sample_input) |
| |
                # Once https://github.com/pytorch/pytorch/pull/75965/ lands, I
                # can store the called list on the mode object instance, and no
                # explicit clearing will be necessary as I will create a fresh
                # mode for each region
| with self.DecompCrossRefMode( |
| self, self.precision, self.rel_tol, dtype, run_all |
| ) as mode, enable_python_dispatcher(): |
| decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals) |
| if run_without_python_dispatcher(mode): |
| # without this check, incorrect decomps at the python dispatcher level can still pass because |
| # they're checking aten decomps at the torch_dispatch level. |
| with self.DecompCrossRefMode( |
| self, self.precision, self.rel_tol, dtype, run_all |
| ) as mode: |
| decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals) |
| if aten_name in decomposition_names: |
| self.check_decomposed(aten_name, mode) |
| |
| if not skip_decomp_vjp and ( |
| op.aten_backward_name in decomposition_names or run_all |
| ): |
| cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out) |
| |
| with self.DecompCrossRefMode( |
| self, self.precision, self.rel_tol, dtype, run_all |
| ) as mode, enable_python_dispatcher(): |
| decomp_vjp_fn(cotangents) |
| if run_without_python_dispatcher(mode): |
| # without this check, incorrect decomps at the python dispatcher level can still pass because |
| # they're checking aten decomps at the torch_dispatch level. |
| with self.DecompCrossRefMode( |
| self, self.precision, self.rel_tol, dtype, run_all |
| ) as mode: |
| decomp_vjp_fn(cotangents) |
| if not run_all: |
| self.check_decomposed(op.aten_backward_name, mode) |
| |
| elif aten_name in decomposition_names or run_all: |
| args = [sample_input.input] + list(sample_input.args) |
| kwargs = sample_input.kwargs |
| # A failure here might be because the decomposition for the op is wrong or because a |
| # decomposition used by the particular op is wrong. |
| with self.DecompCrossRefMode( |
| self, self.precision, self.rel_tol, dtype, run_all |
| ) as mode, enable_python_dispatcher(): |
| func(*args, **kwargs) |
| |
| if run_without_python_dispatcher(mode): |
| # without this check, incorrect decomps at the python dispatcher level can still pass because |
| # they're checking aten decomps at the torch_dispatch level. |
| with self.DecompCrossRefMode( |
| self, self.precision, self.rel_tol, dtype, run_all |
| ) as mode: |
| func(*args, **kwargs) |
| |
| if not run_all: |
| self.check_decomposed(aten_name, mode) |
| else: |
| assert op.supports_autograd |
| self.skipTest( |
| "only backwards is decomposed, but dtype doesn't support AD" |
| ) |
| |
| |
| instantiate_device_type_tests(TestDecomp, globals()) |
| |
| |
| class DecompOneOffTests(TestCase): |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @onlyNativeDeviceTypes |
| @skipIfCrossRef |
| def test_contiguous_softmax(self, device): |
| size = (2, 4, 3, 3) |
| stride = (9, 18, 3, 1) |
| dtype = torch.float32 |
| |
| x = torch.randn(size, dtype=dtype, device=device) |
| x = torch.as_strided(x, size, stride) |
| |
| ref = torch.ops.aten._softmax(x, -1, False) |
| res = torch._decomp.decompositions._softmax(x, -1, False) |
| self.assertEqual(ref.stride(), res.stride()) |
| |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @onlyNativeDeviceTypes |
| @skipIfCrossRef |
| def test_contiguous_log_softmax(self, device): |
| size = (2, 4, 3, 3) |
| stride = (9, 18, 3, 1) |
| |
| dtype = torch.float32 |
| x = torch.randn(size, dtype=dtype, device=device) |
| x = torch.as_strided(x, size, stride) |
| |
| ref = torch.ops.aten._log_softmax(x, -1, False) |
| res = torch._decomp.decompositions._log_softmax(x, -1, False) |
| self.assertEqual(ref.stride(), res.stride()) |
| |
| @onlyCUDA |
| def test_exponential_non_inf(self, device): |
| inp = torch.empty((4, 400, 256), device=device) |
| |
| with torch._dynamo.utils.preserve_rng_state(): |
| exp_ref = inp.exponential_() |
| exp = torch._refs.exponential(inp) |
| |
| self.assertEqual(exp, exp_ref) |
| self.assertFalse(exp.isinf().any()) |
| |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @skipIfCrossRef |
| @onlyCUDA |
| def test_amp_batch_norm_backward(self): |
| device = "cuda" |
| grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device) |
| x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device) |
| weight = torch.randn((2,), dtype=torch.float32, device=device) |
| rmean = torch.randn((2,), dtype=torch.float32, device=device) |
| rvar = torch.randn((2,), dtype=torch.float32, device=device) |
| mean = torch.randn((0,), dtype=torch.float32, device=device) |
| |
| ref = torch.ops.aten.native_batch_norm_backward( |
| grad_out, |
| x, |
| weight, |
| rmean, |
| rvar, |
| mean, |
| mean, |
| False, |
| 1e-05, |
| [True, True, True], |
| ) |
| res = torch._decomp.decompositions.native_batch_norm_backward( |
| grad_out, |
| x, |
| weight, |
| rmean, |
| rvar, |
| mean, |
| mean, |
| False, |
| 1e-05, |
| [True, True, True], |
| ) |
| for a, b in zip(ref, res): |
| self.assertEqual(a.stride(), b.stride()) |
| self.assertEqual(a.dtype, b.dtype) |
| |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @onlyNativeDeviceTypes |
| @skipIfCrossRef |
| def test_elu_backward(self, device): |
| size = (2, 4, 3, 3) |
| dtype = torch.float32 |
| grad_out = torch.randn(size, dtype=dtype, device=device) |
| out = torch.randn(size, dtype=dtype, device=device) |
| |
| ref = torch.ops.aten.elu_backward(grad_out, 1.0, 1, 1, True, out) |
| res = torch._decomp.decompositions.elu_backward(grad_out, 1.0, 1, 1, True, out) |
| self.assertEqual(ref, res) |
| |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @onlyNativeDeviceTypes |
| @skipIfCrossRef |
| def test_threshold_backward_dtype(self, device): |
| grad = torch.randint(10, (4,), device=device) |
| input_tensor = torch.randint(10, (4,), device=device) |
| |
| ref = torch.ops.aten.threshold_backward(grad, input_tensor, 1) |
| res = torch._decomp.decompositions.threshold_backward(grad, input_tensor, 1) |
| self.assertEqual(ref.dtype, res.dtype) |
| |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @onlyNativeDeviceTypes |
| @skipIfCrossRef |
| def test_weight_norm_interface(self, device): |
| g = torch.randn((3, 10, 10), device=device) |
| v = torch.randn((1, 1, 10), device=device) |
| |
| ref = torch.ops.aten._weight_norm_interface(g, v, 2) |
| res = torch._decomp.decompositions._weight_norm_interface(g, v, 2) |
| self.assertTrue(torch.allclose(ref[0], res[0])) |
| self.assertTrue(torch.allclose(ref[1], res[1])) |
| |
| inp = torch.rand([30, 10], device=device) |
| inp2 = torch.rand([30, 1], device=device) |
| |
| self.assertEqual( |
| torch.ops.aten._weight_norm_interface(inp, inp2), |
| torch._decomp.decompositions._weight_norm_interface(inp, inp2), |
| ) |
| |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @onlyCPU |
| @skipIfCrossRef |
| @skipOps( |
| "DecompOneOffTests", |
| "test_sdpa", |
| [ |
| xfail( |
| "nn.functional.scaled_dot_product_attention", |
| dtypes=[torch.half], |
| ), |
| ], |
| ) |
| @ops(_sdpa_op_info) |
| def test_sdpa(self, device, dtype, op): |
        # SDPA doesn't support float16; this is aligned with
        # aten/src/ATen/native/transformers/attention.cpp. If we add support for
        # float16 there, we should update this test as well.
| |
| class ScaledDotProductAttention(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| def forward( |
| self, query_layer, key_layer, value_layer, mask=None, is_causal=True |
| ): |
| attn_output = op( |
| query_layer, |
| key_layer, |
| value_layer, |
| attn_mask=mask, |
| dropout_p=0.0, |
| is_causal=is_causal, |
| ) |
| return attn_output |
| |
| query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) |
| key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) |
| value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) |
| masks = [None, torch.ones((1, 1, 100, 100), device=device, dtype=torch.bool)] |
| |
| atol, rtol = dtype_precisions[dtype] |
| |
| for mask in masks: |
| is_causal = mask is None |
| attention = ScaledDotProductAttention() |
| decomposed_res = ( |
| torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu( |
| query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask |
| ) |
| ) |
| eager_res = op( |
| query_layer, |
| key_layer, |
| value_layer, |
| attn_mask=mask, |
| dropout_p=0.0, |
| is_causal=is_causal, |
| ) |
| |
| self.assertTrue( |
| torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol) |
| ) |
| |
| |
| instantiate_device_type_tests(DecompOneOffTests, globals()) |
| |
| |
| class HasDecompTest(TestCase): |
| def setUp(self): |
| super().setUp() |
| self.maxDiff = None |
| |
| @staticmethod |
| def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool: |
| has_tensor_arg = any( |
| "Tensor" in str(a.type) |
| for a in itertools.chain(op._schema.arguments, op._schema.returns) |
| ) |
| if not has_tensor_arg: |
| return False |
| |
| try: |
| # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions |
| return not op.has_kernel_for_dispatch_key( |
| DispatchKey.CompositeImplicitAutograd |
| ) |
| except RuntimeError as e: |
            # has_kernel_for_dispatch_key fails for some jit-registered ops,
            # which shouldn't be relevant here anyway
| if "does not exist" in str(e): |
| return False |
| raise |
| |
| def test_has_decomposition(self): |
| def all_aten_overloads(): |
| for name in torch._C._dispatch_get_all_op_names(): |
| if not name.startswith("aten::"): |
| continue |
| |
| name = name[6:] |
| if "." in name: |
| packet_name, overload_name = name.split(".") |
| else: |
| packet_name, overload_name = name, "default" |
| |
| packet = getattr(aten, packet_name) |
| assert isinstance(packet, torch._ops.OpOverloadPacket) |
| op = getattr(packet, overload_name) |
| yield op |
| |
| # This is for operators that are only registered in some CI |
| # configurations, so would cause the test to fail |
| allow_list = {aten.get_gradients.default} |
| |
| overloads_wanting_decomp = { |
| op for op in all_aten_overloads() if self._can_appear_in_trace(op) |
| } |
| ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys() |
| ops_missing_decomp -= allow_list |
| self.assertExpected( |
| "".join(sorted(op.name() + "\n" for op in ops_missing_decomp)) |
| ) |
| |
| def test_aten_core_operators(self): |
| # If a decomposition isn't included in the core decompositions, |
| # then it must decompose a core ATen operator. |
| # |
| # See NOTE [Core ATen Ops] |
| # |
| # If this test fails then either: |
| # - Add the decomposition to torch._decomp.core_aten_decompositions, |
        #   if the decomposition should be used by inductor (not a core operator).
| # - Run this test again with EXPECTTEST_ACCEPT=1 to update the list of |
| # core ATen operators (and inductor will not use the decomposition). |
| |
| # Some decompositions are registered for CompositeImplicitAutograd |
| # operators, which never appear in AOTAutograd's graph so are never used. |
| useful_decomps = { |
| op |
| for op in decomposition_table.keys() |
| if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op) |
| } |
| core_decomps = torch._decomp.core_aten_decompositions().keys() |
| core_aten_ops = useful_decomps - core_decomps |
| self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops))) |
| |
| |
| if __name__ == "__main__": |
| run_tests() |