# Owner(s): ["module: primTorch"]
from collections import defaultdict
from torch import Tensor
import torch.autograd
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch._decomp import decomposition_table
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.testing._internal.logging_tensor import no_dispatch
from torch.testing._internal.common_utils import (
is_iterable_of_tensors,
TestCase,
skipIfCrossRef,
suppress_warnings,
TEST_WITH_ASAN,
run_tests,
)
from torch.testing._internal.common_device_type import (
onlyNativeDeviceTypes,
ops,
instantiate_device_type_tests,
)
from torch.testing._internal.common_methods_invocations import op_db
import itertools
import functools
from functools import partial
import unittest
aten = torch.ops.aten
# TODO: this isn't going to work with non-aten namespaces
def overload_to_aten_name(overload):
return overload._schema.name.split("::")[1]
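# A minimal illustration (not executed) of the mapping above, assuming the
# usual aten overload naming:
#   overload_to_aten_name(torch.ops.aten.add.Tensor)     -> "add"
#   overload_to_aten_name(torch.ops.aten.addmm.default)  -> "addmm"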
# All operators that can have decomp tests
decomposition_names = {overload_to_aten_name(k) for k in decomposition_table}
_decomp_test_ops = [
op
for op in op_db
if op.aten_name in decomposition_names
or op.aten_backward_name in decomposition_names
]
def diff_arg(arg, requires_grad=True):
def is_differentiable_arg(arg):
if requires_grad:
return arg.requires_grad
else:
return arg.is_floating_point() or arg.is_complex()
if is_iterable_of_tensors(arg):
        if all(is_differentiable_arg(a) for a in arg):
            return True
        if all(not is_differentiable_arg(a) for a in arg):
            return False
raise RuntimeError("NYI: The test runner can't handle this")
return isinstance(arg, Tensor) and is_differentiable_arg(arg)
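# For example (illustrative only): diff_arg(torch.randn(3, requires_grad=True))
# is True, while diff_arg(torch.arange(3)) is False because an integer tensor
# never requires grad.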
# Version of autograd.grad with some differences:
# - pytree inputs are allowed (but all leaves of the pytree have to
#   be tensors)
# - if an input does not participate in the derivatives, we return a
#   zero-filled tensor for it instead of None
def _autograd_grad(
outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
inputs, inputs_spec = tree_flatten(inputs)
diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
if grad_outputs is None:
diff_outputs = tuple(out for out in outputs if out.requires_grad)
else:
diff_grad_outputs = [
(out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
]
if len(diff_grad_outputs) == 0:
diff_outputs, grad_outputs = (), ()
else:
diff_outputs, grad_outputs = zip(*diff_grad_outputs)
grad_inputs = torch.autograd.grad(
diff_outputs,
diff_inputs,
grad_outputs,
retain_graph=retain_graph,
create_graph=create_graph,
allow_unused=True,
)
result = []
grad_inputs_iter = iter(grad_inputs)
for inp in inputs:
if inp.requires_grad:
grad_input = next(grad_inputs_iter)
if grad_input is None:
result.append(torch.zeros_like(inp))
else:
result.append(grad_input)
else:
result.append(torch.zeros_like(inp))
return tree_unflatten(result, inputs_spec)
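# A minimal usage sketch (illustrative, not executed): the unused input `y`
# gets a zero-filled gradient instead of None.
"""
x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
out = x * 2  # y does not participate in the computation
gx, gy = _autograd_grad((out,), (x, y), (torch.ones_like(out),))
# gx == torch.full_like(x, 2.0), gy == torch.zeros_like(y)
"""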
def _as_tuple(val):
if isinstance(val, tuple):
return val
return (val,)
def ref_vjp_no_create(f, *primals):
result = f(*primals)
def wrapped(cotangents):
return _autograd_grad(
_as_tuple(result), primals, _as_tuple(cotangents), create_graph=False
)
return result, wrapped
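# A minimal usage sketch (illustrative, not executed): computes f(*primals)
# and returns a callable that maps cotangents to input gradients.
"""
x = torch.randn(3, requires_grad=True)
out, vjp_fn = ref_vjp_no_create(torch.sin, x)
(grad_x,) = vjp_fn(torch.ones_like(out))  # grad_x == torch.cos(x)
"""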
dtype_precisions = {
torch.float16: (0.001, 1e-5),
torch.bfloat16: (0.016, 1e-4),
torch.float32: (1.3e-6, 1e-5),
torch.float64: (1e-7, 1e-7),
torch.complex32: (0.001, 1e-5),
torch.complex64: (1.3e-6, 1e-5),
torch.complex128: (1e-7, 1e-7),
}
# Returns the "default" rtol and atol for comparing scalars or
# tensors of the given dtypes.
def _getDefaultRtolAndAtol(dtype0, dtype1):
rtol = max(
dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0]
)
atol = max(
dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1]
)
return rtol, atol
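# For example, comparing a float16 result against a float32 reference picks
# the looser of the two rows above: rtol=1e-3, atol=1e-5.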
def op_assert_ref(test_case, op, orig, decomp, ref, args, kwargs):
assert orig.dtype == decomp.dtype, f"Operation: {op}"
if orig.numel() == 0 or decomp.numel() == 0:
assert orig.numel() == decomp.numel()
return
if ref.is_floating_point():
orig_diff = (orig - ref).abs().max()
decomp_diff = (decomp - ref).abs().max()
atol = 1e-10
if decomp_diff > orig_diff + atol:
raise RuntimeError(
f"Difference from float64 is larger with decomposition {op.__name__}"
f" than original. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
f"args = {args}\n"
f"kwargs = {kwargs}"
)
else:
test_case.assertEqual(
orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}"
)
def op_assert_equal(test_case, op, orig, decomp, args, kwargs):
test_case.assertEqual(
orig.dtype, decomp.dtype, f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}")
# Before adding an entry to this table, make sure your decomposition is right :)
tol_table = {
# Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
(torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3),
(torch.float32, torch.ops.aten.native_layer_norm_backward.default): (
1e-3,
1e-3,
),
}
if (decomp.dtype, op) in tol_table:
rtol, atol = tol_table[(decomp.dtype, op)]
else:
rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
test_case.assertEqual(orig, decomp, rtol=rtol, atol=atol, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}")
# Given f, returns an f' such that:
# - f' takes only positional arguments
# - All arguments to f' are floating-point Tensors
# - All outputs of f' are floating-point Tensors
def normalize_op_input_output2(
f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
flat_args, args_spec = tree_flatten(args)
diff_argnums = tuple(
i
for i, arg in enumerate(flat_args)
if diff_arg(arg, requires_grad=requires_grad)
)
assert len(diff_argnums) > 0
primals = tuple(flat_args[i] for i in diff_argnums)
@functools.wraps(f)
def wrapped(*primals):
_args = list(flat_args)
for num, arg in zip(diff_argnums, primals):
_args[num] = arg
_args = tree_unflatten(_args, args_spec)
result = f(*_args, **kwargs)
if output_process_fn_grad is not None:
result = output_process_fn_grad(result)
if isinstance(result, tuple):
# TODO: Remove the following hack for namedtuples
result = tuple(result)
result = tuple(
r
for r in result
if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex())
)
assert len(result) > 0
return result
return wrapped, primals
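# A minimal usage sketch (illustrative, not executed; torch.add and the inputs
# below are placeholders):
"""
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
wrapped, primals = normalize_op_input_output2(torch.add, (a, b), {})
wrapped(*primals)  # same as torch.add(a, b), but takes only tensor positionals
"""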
# NB: This also upcasts dtype arguments
def upcast_tensor(func, x, dtype=torch.float32):
# Some functions take a dtype as argument, so we need to
# manually change that dtype in order to run it with a
# higher precision
dtype_arg_table = {
torch.ops.aten._softmax_backward_data.default,
torch.ops.aten._log_softmax_backward_data.default,
}
if isinstance(x, Tensor) and x.dtype.is_floating_point:
return x.to(dtype=dtype)
elif (
isinstance(x, torch.dtype)
and func in dtype_arg_table
and x in [torch.float16, torch.bfloat16]
):
return torch.float64
else:
return x
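# For example, tree_map(partial(upcast_tensor, func, dtype=torch.float64), args)
# sends every floating-point tensor in args to float64 and, for the two softmax
# backward ops listed above, also rewrites a float16/bfloat16 dtype argument to
# float64, so the reference computation runs entirely in higher precision.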
def normalize_op_input_output(f, sample, requires_grad=True):
args = tuple([sample.input] + list(sample.args))
return normalize_op_input_output2(
f,
args,
sample.kwargs,
sample.output_process_fn_grad,
requires_grad=requires_grad,
)
CROSS_REF_EXCLUDE_SET = {
# CUBLAS_STATUS_NOT_SUPPORTED when calling
# `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k,
# (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF,
# (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
# (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
("cuda", torch.bfloat16, "nn.functional.bilinear"),
# randomness
("cuda", torch.float16, "nn.functional.dropout"),
("cuda", torch.bfloat16, "nn.functional.dropout"),
("cuda", torch.float64, "nn.functional.dropout"),
("cuda", torch.float32, "nn.functional.dropout"),
# decomp has problem even with opmath
("cuda", torch.bfloat16, "nn.functional.layer_norm"),
("cuda", torch.float16, "nn.functional.layer_norm"),
("cuda", torch.bfloat16, "nn.functional.batch_norm"),
("cuda", torch.float16, "nn.functional.batch_norm"),
("cuda", torch.bfloat16, "nn.functional.instance_norm"),
("cuda", torch.float16, "nn.functional.instance_norm"),
# doesn't work
("cuda", torch.bfloat16, "nn.functional.embedding"),
}
all_decomposed = set()
all_called = defaultdict(int)
# Helpful snippet for testing coverage
"""
import atexit
def check_coverage():
print("missing coverage:")
print("\n".join(map(str, decomposition_table.keys() - all_decomposed)))
atexit.register(check_coverage)
"""
# Helpful snippet for Horace to create his google sheet :)
"""
import atexit
def dump_ops():
with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
f.write(f'{op.__name__}\n')
g.write(f'{count}\n')
with open('run_decompositions.txt', 'w') as f:
for op in sorted([i.__name__ for i in all_decomposed]):
f.write(f'{op}\n')
atexit.register(dump_ops)
"""
def any_unsupported(args, kwargs):
def test_unsupported(t):
if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
            # These are all tensor kinds that we haven't written the
            # decompositions to handle correctly. Maybe they should be.
return any([
t.is_sparse_csr, t.is_sparse, t.is_mkldnn, t.is_quantized,
t.is_nested, torch._is_functional_tensor(t),
])
elif torch.overrides.is_tensor_like(t):
# Decompositions will generally change the behavior of Tensor-like
# subclasses, so bypass tests in this case too
return True
else:
return False
flat_args, _ = tree_flatten(args)
flat_kwargs, _ = tree_flatten(kwargs)
return any(test_unsupported(x) for x in itertools.chain(flat_args, flat_kwargs))
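# For example (illustrative only): a sparse input such as
# torch.randn(2, 2).to_sparse() makes any_unsupported return True, so the
# cross-ref mode below just calls the original op without checking the
# decomposition.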
class TestDecomp(TestCase):
longMessage = True
# NB: This actually overlaps with test_comprehensive, but it only
# runs on things that are definitely decomposed so it's a lot faster
# to run
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyNativeDeviceTypes
@skipIfCrossRef
@suppress_warnings
@ops(_decomp_test_ops)
def test_quick(self, device, dtype, op):
self.do_cross_ref(device, dtype, op, run_all=False)
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyNativeDeviceTypes
@skipIfCrossRef
@suppress_warnings
@ops(op_db)
def test_comprehensive(self, device, dtype, op):
self.do_cross_ref(device, dtype, op, run_all=True)
def do_cross_ref(self, device, dtype, op, *, run_all):
if (torch.device(device).type, dtype, op.name) in CROSS_REF_EXCLUDE_SET or (
None,
dtype,
op.name,
) in CROSS_REF_EXCLUDE_SET:
self.skipTest(f"{op.name} in {dtype} not supported")
test_dtype = dtype
        # We check the correctness of each decomposition right after running it.
        # So, when we encounter a decomposition, we run the original function,
        # then run the decomposition, and check that the two results agree
        # within tolerance.
called = set()
decomposed = set()
saved_precision = self.precision
saved_rel_tol = self.rel_tol
class DecompCrossRefMode(torch.Tensor):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
with no_dispatch():
return cls._torch_dispatch(func, types, args, kwargs)
@classmethod
def _torch_dispatch(cls, func, types, args=(), kwargs=None):
self.precision = saved_precision
self.rel_tol = saved_rel_tol
called.add(func)
all_called[func] += 1
# Stuff we shouldn't bother testing
# (TODO: remove detach from the decomp table?)
if func not in decomposition_table or func in [
torch.ops.aten.detach.default
] or any_unsupported(args, kwargs):
return func(*args, **kwargs)
decomposed.add(func)
all_decomposed.add(func)
                # We take two main strategies for verifying the correctness and
                # numerical stability of decompositions. The first is simply
                # tolerance checking between decomp_out and pytorch_out.
                # However, for fp16/bf16 and reductions this becomes very
                # finicky, as there are not many guarantees we can make.
                # So, for fp16/bf16 we instead compare the difference between
                # {decomp_out, pytorch_out_64} and {pytorch_out,
                # pytorch_out_64}. In other words, we compare how far the
                # decomposition and pytorch each are from the "ground truth"
                # (i.e. fp64). If the decomposition results in more error, we
                # raise an error.
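                # For example, with an fp64 reference of 1.0, an eager fp16
                # result of 0.999, and a decomposed fp16 result of 0.99, the
                # decomposition has strayed further from the reference, so
                # op_assert_ref below raises.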
decomposition = decomposition_table[func]
do_relative_check = test_dtype in [torch.float16, torch.bfloat16]
real_out_unflat = func(*args, **kwargs)
real_out, _ = tree_flatten(real_out_unflat)
decomp_out, _ = tree_flatten(decomposition(*args, **kwargs))
assert len(real_out) == len(decomp_out)
if do_relative_check:
upcast = partial(upcast_tensor, func, dtype=torch.float64)
real_out_double, _ = tree_flatten(
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
)
for orig, decomp, ref in zip(real_out, decomp_out, real_out_double):
if orig is None:
assert decomp is None
continue
op_assert_ref(self, func, orig, decomp, ref, args, kwargs)
else:
for orig, decomp in zip(real_out, decomp_out):
if orig is None:
assert decomp is None
continue
op_assert_equal(self, func, orig, decomp, args, kwargs)
return real_out_unflat
requires_grad = (
op.supports_autograd
and dtype in op.supported_backward_dtypes(torch.device(device).type)
            # TODO: OpInfo really ought to error out for this case, but it's
            # not exercised in test_ops_gradients atm. The problem is not
            # complex32 per se (which is supported by data-movement-only ops)
            # but that when we run backward we expect other ops like add to work
            and dtype != torch.complex32
)
samples = op.sample_inputs(device, test_dtype, requires_grad=requires_grad)
def check_decomposed(aten_name):
self.assertTrue(
any(overload_to_aten_name(c) == aten_name for c in decomposed),
msg=f"aten.{aten_name} was not decomposed, saw calls for: "
+ ", ".join(map(str, list(called))),
)
aten_name = op.decomp_aten_name or op.aten_name
func = op.get_op()
for sample_input in samples:
if requires_grad:
fn, primals = normalize_op_input_output(func, sample_input)
                # Once https://github.com/pytorch/pytorch/pull/75965/ lands, I can
                # store the called list on the mode object instance, and no
                # explicit clearing will be necessary as I will create a fresh
                # mode for each region.
decomposed.clear()
with enable_torch_dispatch_mode(DecompCrossRefMode):
decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
if aten_name in decomposition_names:
check_decomposed(aten_name)
if op.aten_backward_name in decomposition_names or run_all:
cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
decomposed.clear()
with enable_torch_dispatch_mode(DecompCrossRefMode):
decomp_vjp_fn(cotangents)
if not run_all:
check_decomposed(op.aten_backward_name)
elif aten_name in decomposition_names or run_all:
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
decomposed.clear()
with enable_torch_dispatch_mode(DecompCrossRefMode):
func(*args, **kwargs)
if not run_all:
check_decomposed(aten_name)
else:
assert op.supports_autograd
self.skipTest(
"only backwards is decomposed, but dtype doesn't support AD"
)
instantiate_device_type_tests(TestDecomp, globals())
if __name__ == "__main__":
run_tests()