# (source-viewer residue, not part of the module) blob: 64cd71fdee6d67088caed632c8d9c7bbf34db677 [file] [log] [blame]
# Owner(s): ["module: unknown"]
from functools import partial, wraps
from itertools import chain
import torch
from torch.testing._internal.common_utils import \
(TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, OpDTypes)
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
# Module-level side effect: the default dtype is set to float32 as soon as
# this file is imported.
torch.set_default_dtype(torch.float32)

# gradcheck requires double precision, so every test decorated with
# `_gradcheck_ops` hands `ops` only the double-precision real and complex
# dtypes as `allowed_dtypes`.
_gradcheck_ops = partial(
    ops,
    dtypes=OpDTypes.supported,
    allowed_dtypes=[torch.double, torch.cdouble],
)
class TestGradients(TestCase):
    # Not read anywhere in this class; presumably consumed by the TestCase
    # base class -- confirm.
    exact_dtype = True

    def _get_safe_inplace(self, inplace_variant):
        """Return a wrapper that runs ``inplace_variant`` on a clone.

        The first positional argument is cloned before the call so the
        in-place operation never modifies a leaf that requires gradient.
        The wrapper is created with ``functools.wraps``, which leaves the
        original callable reachable through ``__wrapped__``.
        """
        @wraps(inplace_variant)
        def _fn(t, *args, **kwargs):
            scratch = t.clone()
            return inplace_variant(scratch, *args, **kwargs)

        return _fn

    def _check_helper(self, device, dtype, op, variant, check, *,
                      check_forward_ad=False, check_backward_ad=True,
                      check_batched_grad=None, check_batched_forward_grad=False):
        """Run ``gradcheck`` or ``gradgradcheck`` on every sample input of ``op``.

        Args:
            device: device the samples are created on.
            dtype: dtype the samples are created with.
            op: the op description; supplies the samples and gradcheck settings.
            variant: callable under test; the test is skipped when it is None.
            check: 'gradcheck', 'bwgrad_bwgrad' or 'fwgrad_bwgrad'.
            check_forward_ad: forwarded to ``gradcheck``; must be falsy for
                the two gradgrad checks.
            check_backward_ad: forwarded to ``gradcheck`` only.
            check_batched_grad: forwarded to ``gradcheck``; ``None`` means
                "use ``op.check_batched_grad``".
            check_batched_forward_grad: forwarded to ``gradcheck`` only.
        """
        assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
        # NB: check_backward_ad does not affect gradgradcheck (always True)
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        device_type = torch.device(device).type
        if not op.supports_dtype(dtype, device_type):
            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

        def _is_inplace(candidate):
            # Look through a functools.wraps-style wrapper (see
            # _get_safe_inplace) before comparing against the in-place variant.
            underlying = getattr(candidate, "__wrapped__", candidate)
            return underlying is op.get_inplace()

        def _wants_grad(obj):
            return isinstance(obj, torch.Tensor) and obj.requires_grad

        def _recompose(flat_inputs, original, cursor):
            # Rebuilds one original argument, substituting the next entries of
            # `flat_inputs` for every tensor that requires grad.  Returns the
            # rebuilt argument and the advanced cursor.
            if is_iterable_of_tensors(original):
                rebuilt_list = []
                for item in original:
                    if _wants_grad(item):
                        rebuilt_list.append(flat_inputs[cursor])
                        cursor += 1
                    else:
                        rebuilt_list.append(item)
                return rebuilt_list, cursor
            if _wants_grad(original):
                return flat_inputs[cursor], cursor + 1
            return original, cursor

        with_conj = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(device, dtype, requires_grad=True,
                                   include_conjugated_inputs=with_conj)

        for sample in samples:
            # In-place variants are not run on samples that broadcast the input.
            if sample.broadcasts_input and _is_inplace(variant):
                continue

            # Gradcheck expects tensors as its input, but autograd actually
            # supports tensorlists and tensors passed as kwargs.  Collect the
            # tensors requiring grad as a flat tuple here; `fn` below accepts
            # that flat tuple as varargs and puts each tensor back where it
            # came from in the original sample.
            if is_iterable_of_tensors(sample.input):
                leading = sample.input
            else:
                leading = (sample.input,)
            candidates = chain(leading, sample.args, sample.kwargs.values())
            gradcheck_args = tuple(filter(_wants_grad, candidates))

            def fn(*inputs):
                # Same traversal order as above: input, then args, then kwargs.
                rebuilt_input, cursor = _recompose(inputs, sample.input, 0)
                call_args = [rebuilt_input]
                for original_arg in sample.args:
                    rebuilt, cursor = _recompose(inputs, original_arg, cursor)
                    call_args.append(rebuilt)

                call_kwargs = {}
                for key, original_value in sample.kwargs.items():
                    rebuilt, cursor = _recompose(inputs, original_value, cursor)
                    call_kwargs[key] = rebuilt

                output = op.gradcheck_wrapper(variant, *call_args, **call_kwargs)
                post_process = sample.output_process_fn_grad
                return output if post_process is None else post_process(output)

            if check == 'gradcheck':
                if check_batched_grad is None:
                    check_batched_grad = op.check_batched_grad
                passed = gradcheck(
                    fn,
                    gradcheck_args,
                    check_batched_grad=check_batched_grad,
                    check_grad_dtypes=True,
                    nondet_tol=op.gradcheck_nondet_tol,
                    fast_mode=op.gradcheck_fast_mode,
                    check_forward_ad=check_forward_ad,
                    check_backward_ad=check_backward_ad,
                    check_undefined_grad=True,
                    check_batched_forward_grad=check_batched_forward_grad,
                )
                self.assertTrue(passed)
            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
                # Run once with and once without non-contiguous grad outputs.
                for non_contig in (False, True):
                    gradgrad_kwargs = dict(
                        gen_non_contig_grad_outputs=non_contig,
                        check_batched_grad=op.check_batched_gradgrad,
                        check_grad_dtypes=True,
                        nondet_tol=op.gradcheck_nondet_tol,
                        fast_mode=op.gradcheck_fast_mode,
                    )
                    if check == "fwgrad_bwgrad":
                        # Forward-over-reverse only; batched and undefined-grad
                        # checks are switched off for this mode.
                        gradgrad_kwargs.update(
                            check_fwd_over_rev=True,
                            check_rev_over_rev=False,
                            check_batched_grad=False,
                            check_undefined_grad=False,
                        )
                    self.assertTrue(gradgradcheck(fn, gradcheck_args, **gradgrad_kwargs))
            else:
                self.assertTrue(False, msg="Unknown check requested!")

    def _grad_test_helper(self, device, dtype, op, variant, *,
                          check_forward_ad=False, check_backward_ad=True,
                          check_batched_grad=None, check_batched_forward_grad=False):
        """Run ``_check_helper`` in 'gradcheck' mode, forwarding every option."""
        return self._check_helper(
            device, dtype, op, variant, 'gradcheck',
            check_forward_ad=check_forward_ad,
            check_backward_ad=check_backward_ad,
            check_batched_grad=check_batched_grad,
            check_batched_forward_grad=check_batched_forward_grad,
        )

    def _skip_helper(self, op, device, dtype):
        """Skip the current test when ``op`` cannot be differentiated here.

        Skips if ``dtype`` is not a supported backward dtype on the device
        type, or if the op supports neither autograd nor forward AD.
        """
        device_type = torch.device(device).type
        if dtype not in op.supported_backward_dtypes(device_type):
            self.skipTest("Skipped! Op doesn't support autograd for this dtype.")
        if not (op.supports_autograd or op.supports_forward_ad):
            self.skipTest("Skipped! autograd not supported.")

    # Tests that gradients are computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_grad(self, device, dtype, op):
        # This is verified by test_dtypes in test_ops.py
        backward_dtypes = op.supported_backward_dtypes(torch.device(device).type)
        if dtype not in backward_dtypes:
            self.skipTest("Skipped! Dtype is not in supported backward dtypes!")
        self._grad_test_helper(device, dtype, op, op.get_op())

    # Method grad (and gradgrad, see below) tests are disabled since they're
    #   costly and redundant with function grad (and gradgrad) tests
    # @_gradcheck_ops(op_db)
    # def test_method_grad(self, device, dtype, op):
    #     self._skip_helper(op, device, dtype)
    #     self._grad_test_helper(device, dtype, op, op.get_method())

    # Tests gradients of the in-place variant, or, for ops that claim not to
    # support in-place autograd, that using it raises.
    @_gradcheck_ops(op_db)
    def test_inplace_grad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.inplace_variant:
            self.skipTest("Op has no inplace variant!")

        if op.supports_inplace_autograd:
            self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
            return

        # Verifies an operation doesn't support inplace autograd if it claims not to
        inplace = self._get_safe_inplace(op.get_inplace())
        for sample in op.sample_inputs(device, dtype, requires_grad=True):
            if sample.broadcasts_input:
                continue
            # NOTE(review): `sample` is handed over whole instead of being
            # unpacked into input/args/kwargs, so the wrapper calls
            # `sample.clone()`.  If the sample type has no `clone`, that
            # AttributeError alone satisfies assertRaises(Exception) and the
            # op is never exercised -- confirm this is intended.
            with self.assertRaises(Exception):
                result = inplace(sample)
                result.sum().backward()

    # Test that gradients of gradients are computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.supports_gradgrad:
            self.skipTest("Op claims it doesn't support gradgrad. This is not verified.")
        self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')

    # Test that forward-over-reverse gradgrad is computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_fwgrad_bwgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)

        def _run_check():
            self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")

        if op.supports_fwgrad_bwgrad:
            _run_check()
            return

        # The op claims no support: the check itself must raise.
        err_msg = r"Trying to use forward AD with .* that does not support it"
        hint_msg = ("Running forward-over-backward gradgrad for an OP that has does not support it did not "
                    "raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True.")
        with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
            _run_check()

    # Test that gradients of gradients are properly raising
    @_gradcheck_ops(op_db)
    def test_fn_fail_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if op.supports_gradgrad:
            self.skipTest("Skipped! Operation does support gradgrad")

        expected_error = r"derivative for .* is not implemented"
        with self.assertRaisesRegex(RuntimeError, expected_error):
            self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')

    # Method gradgrad (and grad, see above) tests are disabled since they're
    #   costly and redundant with function gradgrad (and grad) tests
    # @_gradcheck_ops(op_db)
    # def test_method_gradgrad(self, device, dtype, op):
    #     self._skip_helper(op, device, dtype)
    #     self._gradgrad_test_helper(device, dtype, op, op.get_method())

    # Test that gradients of gradients are computed correctly for the
    # in-place variant
    @_gradcheck_ops(op_db)
    def test_inplace_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not (op.inplace_variant and op.supports_inplace_autograd):
            self.skipTest("Skipped! Operation does not support inplace autograd.")
        safe_inplace = self._get_safe_inplace(op.get_inplace())
        self._check_helper(device, dtype, op, safe_inplace, "bwgrad_bwgrad")

    def _forward_grad_helper(self, device, dtype, op, variant, is_inplace):
        """Check forward-mode AD for ``variant``, or that it raises if unsupported.

        When ``op.supports_forward_ad`` is truthy the forward-AD gradcheck is
        run; otherwise the same call is expected to raise NotImplementedError.
        ``is_inplace`` selects which batched-forward-grad flag of ``op`` applies.
        """
        # TODO: clean up how attributes are passed to gradcheck from OpInfos
        def _run_forward_check():
            batched_forward = ((op.check_batched_forward_grad and not is_inplace) or
                               (op.check_inplace_batched_forward_grad and is_inplace))
            self._grad_test_helper(
                device, dtype, op, variant,
                check_forward_ad=True,
                check_backward_ad=False,
                check_batched_grad=False,
                check_batched_forward_grad=batched_forward,
            )

        if op.supports_forward_ad:
            _run_forward_check()
            return

        err_msg = r"Trying to use forward AD with .* that does not support it"
        hint_msg = ("Running forward AD for an OP that has does not support it did not "
                    "raise any error. If your op supports forward AD, you should set supports_forward_ad=True")
        with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
            _run_forward_check()

    # Tests forward-mode AD of the function variant
    @_gradcheck_ops(op_db)
    def test_forward_mode_AD(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        self._forward_grad_helper(device, dtype, op, op.get_op(), is_inplace=False)

    # Tests forward-mode AD of the in-place variant
    @_gradcheck_ops(op_db)
    def test_inplace_forward_mode_AD(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not (op.inplace_variant and op.supports_inplace_autograd):
            self.skipTest("Skipped! Operation does not support inplace autograd.")
        safe_inplace = self._get_safe_inplace(op.get_inplace())
        self._forward_grad_helper(device, dtype, op, safe_inplace, is_inplace=True)
# Hand TestGradients and this module's globals to the device-type test
# instantiation helper.
instantiate_device_type_tests(TestGradients, globals())


# Run the tests only when the file is executed as a script.
if __name__ == "__main__":
    run_tests()