| # Owner(s): ["module: primTorch"] |
| |
| from collections import defaultdict |
| from torch import Tensor |
| import torch.autograd |
| from torch.utils._python_dispatch import enable_torch_dispatch_mode |
| from torch._decomp import decomposition_table |
| |
| from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten |
| from torch.testing._internal.logging_tensor import no_dispatch |
| from torch.testing._internal.common_utils import ( |
| is_iterable_of_tensors, |
| TestCase, |
| skipIfCrossRef, |
| suppress_warnings, |
| TEST_WITH_ASAN, |
| run_tests, |
| ) |
| from torch.testing._internal.common_device_type import ( |
| onlyNativeDeviceTypes, |
| ops, |
| instantiate_device_type_tests, |
| ) |
| from torch.testing._internal.common_methods_invocations import op_db |
| |
| import itertools |
| import functools |
| from functools import partial |
| import unittest |
| |
| aten = torch.ops.aten |
| |
| |
| # TODO: this isn't going to work with non-aten namespaces |
| def overload_to_aten_name(overload): |
| return overload._schema.name.split("::")[1] |
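
# For example, torch.ops.aten.add.Tensor has schema name "aten::add", so
# overload_to_aten_name(torch.ops.aten.add.Tensor) returns "add".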
| |
| |
| # All operators that can have decomp tests |
| decomposition_names = {overload_to_aten_name(k) for k in decomposition_table} |
| _decomp_test_ops = [ |
| op |
| for op in op_db |
| if op.aten_name in decomposition_names |
| or op.aten_backward_name in decomposition_names |
| ] |
| |
| |
| def diff_arg(arg, requires_grad=True): |
| def is_differentiable_arg(arg): |
| if requires_grad: |
| return arg.requires_grad |
| else: |
| return arg.is_floating_point() or arg.is_complex() |
| |
    if is_iterable_of_tensors(arg):
        if all(is_differentiable_arg(a) for a in arg):
            return True
        if all(not is_differentiable_arg(a) for a in arg):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
| return isinstance(arg, Tensor) and is_differentiable_arg(arg) |
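
# For example, diff_arg(torch.randn(3, requires_grad=True)) is True, while
# diff_arg(torch.randn(3)) is False; with requires_grad=False, any
# floating-point or complex tensor counts as differentiable.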
| |
| |
# Version of autograd.grad with some differences:
#   - pytree inputs are allowed (but all leaves of the pytree must be
#     tensors)
#   - if an input does not contribute to the outputs, we return a
#     zero-filled tensor for its gradient instead of None
| def _autograd_grad( |
| outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True |
| ): |
| inputs, inputs_spec = tree_flatten(inputs) |
| diff_inputs = tuple(inp for inp in inputs if inp.requires_grad) |
| if grad_outputs is None: |
| diff_outputs = tuple(out for out in outputs if out.requires_grad) |
| else: |
| diff_grad_outputs = [ |
| (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad |
| ] |
| if len(diff_grad_outputs) == 0: |
| diff_outputs, grad_outputs = (), () |
| else: |
| diff_outputs, grad_outputs = zip(*diff_grad_outputs) |
| grad_inputs = torch.autograd.grad( |
| diff_outputs, |
| diff_inputs, |
| grad_outputs, |
| retain_graph=retain_graph, |
| create_graph=create_graph, |
| allow_unused=True, |
| ) |
| result = [] |
| grad_inputs_iter = iter(grad_inputs) |
| for inp in inputs: |
| if inp.requires_grad: |
| grad_input = next(grad_inputs_iter) |
| if grad_input is None: |
| result.append(torch.zeros_like(inp)) |
| else: |
| result.append(grad_input) |
| else: |
| result.append(torch.zeros_like(inp)) |
| return tree_unflatten(result, inputs_spec) |
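
# Minimal illustration of the behavior described above, in the spirit of the
# "helpful snippet" strings further down (not executed as part of the tests):
"""
x = torch.randn(3, requires_grad=True)
unused = torch.randn(3, requires_grad=True)
y = (x * 2).sum()
grads = _autograd_grad((y,), {"x": x, "unused": unused})
# grads["x"] is filled with 2s; grads["unused"] is zeros_like(unused) rather
# than None, because "unused" does not participate in computing y
"""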
| |
| |
| def _as_tuple(val): |
| if isinstance(val, tuple): |
| return val |
| return (val,) |
| |
| |
| def ref_vjp_no_create(f, *primals): |
| result = f(*primals) |
| |
| def wrapped(cotangents): |
| return _autograd_grad( |
| _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False |
| ) |
| |
| return result, wrapped |
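
# Sketch of how the test below uses this helper (illustrative only, not
# executed at import time):
"""
x = torch.randn(3, requires_grad=True)
out, vjp_fn = ref_vjp_no_create(torch.sin, x)
(grad_x,) = vjp_fn(torch.ones_like(out))
# grad_x matches torch.cos(x) up to numerical precision
"""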
| |
| |
| dtype_precisions = { |
| torch.float16: (0.001, 1e-5), |
| torch.bfloat16: (0.016, 1e-4), |
| torch.float32: (1.3e-6, 1e-5), |
| torch.float64: (1e-7, 1e-7), |
| torch.complex32: (0.001, 1e-5), |
| torch.complex64: (1.3e-6, 1e-5), |
| torch.complex128: (1e-7, 1e-7), |
| } |


# Returns the "default" rtol and atol for comparing scalars or
# tensors of the given dtypes.
def _getDefaultRtolAndAtol(dtype0, dtype1):
| rtol = max( |
| dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0] |
| ) |
| atol = max( |
| dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1] |
| ) |
| return rtol, atol |
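
# For example, _getDefaultRtolAndAtol(torch.float16, torch.float32) returns
# (1e-3, 1e-5), the looser of the two table entries above.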
| |
| |
| def op_assert_ref(test_case, op, orig, decomp, ref, args, kwargs): |
| assert orig.dtype == decomp.dtype, f"Operation: {op}" |
| if orig.numel() == 0 or decomp.numel() == 0: |
| assert orig.numel() == decomp.numel() |
| return |
| if ref.is_floating_point(): |
| orig_diff = (orig - ref).abs().max() |
| decomp_diff = (decomp - ref).abs().max() |
| atol = 1e-10 |
| if decomp_diff > orig_diff + atol: |
| raise RuntimeError( |
| f"Difference from float64 is larger with decomposition {op.__name__}" |
| f" than original. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n" |
| f"args = {args}\n" |
| f"kwargs = {kwargs}" |
| ) |
| else: |
| test_case.assertEqual( |
| orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}" |
| ) |
| |
| |
| def op_assert_equal(test_case, op, orig, decomp, args, kwargs): |
    test_case.assertEqual(
        orig.dtype,
        decomp.dtype,
        f"Operation: {op}, orig.dtype: {orig.dtype}, "
        f"decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
    )
| # Before adding an entry to this table, make sure your decomposition is right :) |
| tol_table = { |
| # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161 |
| (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3), |
| (torch.float32, torch.ops.aten.native_layer_norm_backward.default): ( |
| 1e-3, |
| 1e-3, |
| ), |
| } |
| if (decomp.dtype, op) in tol_table: |
| rtol, atol = tol_table[(decomp.dtype, op)] |
| else: |
| rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype) |
| |
    test_case.assertEqual(
        orig,
        decomp,
        rtol=rtol,
        atol=atol,
        msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
    )
| |
| |
# Given f, returns an f' such that:
# - f' takes only positional arguments
# - All arguments to f' are floating-point or complex Tensors
# - All outputs of f' are floating-point or complex Tensors
| def normalize_op_input_output2( |
| f, args, kwargs, output_process_fn_grad=None, requires_grad=True |
| ): |
| flat_args, args_spec = tree_flatten(args) |
| diff_argnums = tuple( |
| i |
| for i, arg in enumerate(flat_args) |
| if diff_arg(arg, requires_grad=requires_grad) |
| ) |
| assert len(diff_argnums) > 0 |
| primals = tuple(flat_args[i] for i in diff_argnums) |
| |
| @functools.wraps(f) |
| def wrapped(*primals): |
| _args = list(flat_args) |
| for num, arg in zip(diff_argnums, primals): |
| _args[num] = arg |
| _args = tree_unflatten(_args, args_spec) |
| result = f(*_args, **kwargs) |
| if output_process_fn_grad is not None: |
| result = output_process_fn_grad(result) |
| if isinstance(result, tuple): |
| # TODO: Remove the following hack for namedtuples |
| result = tuple(result) |
| result = tuple( |
| r |
| for r in result |
| if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex()) |
| ) |
| assert len(result) > 0 |
| return result |
| |
| return wrapped, primals |
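
# Illustrative sketch of the wrapper above (assumes both inputs require grad;
# not executed as part of the tests):
"""
x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
wrapped, primals = normalize_op_input_output2(torch.add, (x, y), {})
# primals == (x, y); wrapped(x, y) returns torch.add(x, y). For tuple
# outputs, non-floating-point/non-complex entries are filtered out.
"""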
| |
| |
# NB: This also upcasts dtype arguments
def upcast_tensor(func, x, dtype=torch.float32):
| # Some functions take a dtype as argument, so we need to |
| # manually change that dtype in order to run it with a |
| # higher precision |
| dtype_arg_table = { |
| torch.ops.aten._softmax_backward_data.default, |
| torch.ops.aten._log_softmax_backward_data.default, |
| } |
| |
| if isinstance(x, Tensor) and x.dtype.is_floating_point: |
| return x.to(dtype=dtype) |
| elif ( |
| isinstance(x, torch.dtype) |
| and func in dtype_arg_table |
| and x in [torch.float16, torch.bfloat16] |
| ): |
        # Swap the low-precision dtype argument for the requested upcast dtype
        # so it matches the upcast tensor inputs
        return dtype
| else: |
| return x |
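
# For instance, with dtype=torch.float64 (how it is used below), a bfloat16
# tensor becomes a float64 tensor, and a torch.bfloat16 dtype argument to
# _softmax_backward_data is likewise replaced so the reference runs in fp64.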
| |
| |
| def normalize_op_input_output(f, sample, requires_grad=True): |
| args = tuple([sample.input] + list(sample.args)) |
| return normalize_op_input_output2( |
| f, |
| args, |
| sample.kwargs, |
| sample.output_process_fn_grad, |
| requires_grad=requires_grad, |
| ) |
| |
| |
| CROSS_REF_EXCLUDE_SET = { |
| # CUBLAS_STATUS_NOT_SUPPORTED when calling |
| # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k, |
| # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF, |
| # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec, |
| # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)` |
| ("cuda", torch.bfloat16, "nn.functional.bilinear"), |
| # randomness |
| ("cuda", torch.float16, "nn.functional.dropout"), |
| ("cuda", torch.bfloat16, "nn.functional.dropout"), |
| ("cuda", torch.float64, "nn.functional.dropout"), |
| ("cuda", torch.float32, "nn.functional.dropout"), |
| # decomp has problem even with opmath |
| ("cuda", torch.bfloat16, "nn.functional.layer_norm"), |
| ("cuda", torch.float16, "nn.functional.layer_norm"), |
| ("cuda", torch.bfloat16, "nn.functional.batch_norm"), |
| ("cuda", torch.float16, "nn.functional.batch_norm"), |
| ("cuda", torch.bfloat16, "nn.functional.instance_norm"), |
| ("cuda", torch.float16, "nn.functional.instance_norm"), |
| # doesn't work |
| ("cuda", torch.bfloat16, "nn.functional.embedding"), |
| |
| } |
| |
| all_decomposed = set() |
| all_called = defaultdict(int) |
| |
| # Helpful snippet for testing coverage |
| """ |
| import atexit |
| def check_coverage(): |
| print("missing coverage:") |
| print("\n".join(map(str, decomposition_table.keys() - all_decomposed))) |
| atexit.register(check_coverage) |
| """ |
| |
| # Helpful snippet for Horace to create his google sheet :) |
| """ |
| import atexit |
| def dump_ops(): |
| with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g: |
| for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__): |
| f.write(f'{op.__name__}\n') |
| g.write(f'{count}\n') |
| with open('run_decompositions.txt', 'w') as f: |
| for op in sorted([i.__name__ for i in all_decomposed]): |
| f.write(f'{op}\n') |
| |
| atexit.register(dump_ops) |
| """ |
| |
| |
| def any_unsupported(args, kwargs): |
| def test_unsupported(t): |
| if type(t) is torch.Tensor or type(t) is torch.nn.Parameter: |
| # These are all things that we haven't coded decompositions |
| # to handle correctly. Maybe they should. |
| return any([ |
| t.is_sparse_csr, t.is_sparse, t.is_mkldnn, t.is_quantized, |
| t.is_nested, torch._is_functional_tensor(t), |
| ]) |
| elif torch.overrides.is_tensor_like(t): |
| # Decompositions will generally change the behavior of Tensor-like |
| # subclasses, so bypass tests in this case too |
| return True |
| else: |
| return False |
| |
| flat_args, _ = tree_flatten(args) |
| flat_kwargs, _ = tree_flatten(kwargs) |
| return any(test_unsupported(x) for x in itertools.chain(flat_args, flat_kwargs)) |
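
# For example, a plain dense tensor passes this filter (returns False), while
# a sparse input such as torch.randn(3, 3).to_sparse() makes it return True,
# causing the cross-ref check below to fall back to the original op.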
| |
| |
| class TestDecomp(TestCase): |
| longMessage = True |
| |
    # NB: This overlaps with test_comprehensive, but it only runs on ops that
    # are definitely decomposed, so it is a lot faster to run.
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @onlyNativeDeviceTypes |
| @skipIfCrossRef |
| @suppress_warnings |
| @ops(_decomp_test_ops) |
| def test_quick(self, device, dtype, op): |
| self.do_cross_ref(device, dtype, op, run_all=False) |
| |
| @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") |
| @onlyNativeDeviceTypes |
| @skipIfCrossRef |
| @suppress_warnings |
| @ops(op_db) |
| def test_comprehensive(self, device, dtype, op): |
| self.do_cross_ref(device, dtype, op, run_all=True) |
| |
| def do_cross_ref(self, device, dtype, op, *, run_all): |
| if (torch.device(device).type, dtype, op.name) in CROSS_REF_EXCLUDE_SET or ( |
| None, |
| dtype, |
| op.name, |
| ) in CROSS_REF_EXCLUDE_SET: |
| self.skipTest(f"{op.name} in {dtype} not supported") |
| |
| test_dtype = dtype |
| |
| # We check the correctness of each decomposition right after running it. |
| # So, when we encounter a decomposition, we run the function normally, and |
| # then run the decomposition, and ensure they're identical. |
| called = set() |
| decomposed = set() |
| |
| saved_precision = self.precision |
| saved_rel_tol = self.rel_tol |
| |
| class DecompCrossRefMode(torch.Tensor): |
| @classmethod |
| def __torch_dispatch__(cls, func, types, args=(), kwargs=None): |
| with no_dispatch(): |
| return cls._torch_dispatch(func, types, args, kwargs) |
| |
| @classmethod |
| def _torch_dispatch(cls, func, types, args=(), kwargs=None): |
| self.precision = saved_precision |
| self.rel_tol = saved_rel_tol |
| |
| called.add(func) |
| all_called[func] += 1 |
| |
| # Stuff we shouldn't bother testing |
| # (TODO: remove detach from the decomp table?) |
| if func not in decomposition_table or func in [ |
| torch.ops.aten.detach.default |
| ] or any_unsupported(args, kwargs): |
| return func(*args, **kwargs) |
| |
| decomposed.add(func) |
| all_decomposed.add(func) |
| |
                # We use two main strategies for verifying the correctness and
                # numerical stability of a decomposition.
                # The first is simply tolerance checking between decomp_out and
                # pytorch_out. However, for fp16/bf16 and reductions this
                # becomes very finicky, as there are few guarantees we can make.
                # So, for fp16/bf16, we instead compare the difference of
                # {decomp_out, pytorch_out_64} and {pytorch_out,
                # pytorch_out_64}. In other words, we compare how far the
                # decomposition and pytorch each are from the "ground truth"
                # (i.e. fp64). If the decomposition results in more error, we
                # raise an error.
| |
| decomposition = decomposition_table[func] |
| |
| do_relative_check = test_dtype in [torch.float16, torch.bfloat16] |
| real_out_unflat = func(*args, **kwargs) |
| real_out, _ = tree_flatten(real_out_unflat) |
| decomp_out, _ = tree_flatten(decomposition(*args, **kwargs)) |
| assert len(real_out) == len(decomp_out) |
| |
| if do_relative_check: |
| upcast = partial(upcast_tensor, func, dtype=torch.float64) |
| real_out_double, _ = tree_flatten( |
| func(*tree_map(upcast, args), **tree_map(upcast, kwargs)) |
| ) |
| for orig, decomp, ref in zip(real_out, decomp_out, real_out_double): |
| if orig is None: |
| assert decomp is None |
| continue |
| op_assert_ref(self, func, orig, decomp, ref, args, kwargs) |
| else: |
| for orig, decomp in zip(real_out, decomp_out): |
| if orig is None: |
| assert decomp is None |
| continue |
| op_assert_equal(self, func, orig, decomp, args, kwargs) |
| |
| return real_out_unflat |
| |
| requires_grad = ( |
| op.supports_autograd |
| and dtype in op.supported_backward_dtypes(torch.device(device).type) |
| # TODO: OpInfo really ought to error out for this case, but it's |
| # not exercised in test_ops_gradients atm. The problem is not |
| # complex32 per-se (which is supported by data movement only ops) |
| # but that when we do backwards we expect other ops like add to work |
| and not dtype == torch.complex32 |
| ) |
| samples = op.sample_inputs(device, test_dtype, requires_grad=requires_grad) |
| |
| def check_decomposed(aten_name): |
| self.assertTrue( |
| any(overload_to_aten_name(c) == aten_name for c in decomposed), |
| msg=f"aten.{aten_name} was not decomposed, saw calls for: " |
| + ", ".join(map(str, list(called))), |
| ) |
| |
| aten_name = op.decomp_aten_name or op.aten_name |
| |
| func = op.get_op() |
| for sample_input in samples: |
| if requires_grad: |
| fn, primals = normalize_op_input_output(func, sample_input) |
| |
                # Once https://github.com/pytorch/pytorch/pull/75965/ lands, we
                # can store the called list on the mode object instance, and no
                # explicit clearing will be necessary, as a fresh mode will be
                # created for each region
| decomposed.clear() |
| with enable_torch_dispatch_mode(DecompCrossRefMode): |
| decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals) |
| if aten_name in decomposition_names: |
| check_decomposed(aten_name) |
| |
| if op.aten_backward_name in decomposition_names or run_all: |
| cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out) |
| |
| decomposed.clear() |
| with enable_torch_dispatch_mode(DecompCrossRefMode): |
| decomp_vjp_fn(cotangents) |
| if not run_all: |
| check_decomposed(op.aten_backward_name) |
| |
| elif aten_name in decomposition_names or run_all: |
| args = [sample_input.input] + list(sample_input.args) |
| kwargs = sample_input.kwargs |
| decomposed.clear() |
| with enable_torch_dispatch_mode(DecompCrossRefMode): |
| func(*args, **kwargs) |
| if not run_all: |
| check_decomposed(aten_name) |
| else: |
| assert op.supports_autograd |
| self.skipTest( |
| "only backwards is decomposed, but dtype doesn't support AD" |
| ) |
| |
| |
| instantiate_device_type_tests(TestDecomp, globals()) |
| |
| if __name__ == "__main__": |
| run_tests() |