# blob: b0361306443ab8132ec03af2f090ca548b6e56a6 [file] [log] [blame]
# Owner(s): ["module: autograd"]
import contextlib
import warnings
import numpy as np
import torch
from torch.library import _scoped_library, Library
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
@contextlib.contextmanager
def autograd_fallback_mode(mode):
prev = torch._C._get_autograd_fallback_mode()
try:
torch._C._set_autograd_fallback_mode(mode)
yield
finally:
torch._C._set_autograd_fallback_mode(prev)
class TestAutogradFallback(TestCase):
    """Tests for operators that have no autograd kernel registered.

    Each test defines throwaway operators in the ``test_ns`` namespace,
    registers an implementation for the "CPU" key only, and is run once for
    each of the "nothing" and "warn" autograd fallback modes.
    """

    test_ns = "_test_autograd_fallback"

    def tearDown(self):
        """Drop the test namespace and any library created by ``get_lib``.

        The base-class ``tearDown`` is always invoked afterwards so that
        whatever cleanup it defines is not skipped by this override.
        """
        try:
            if hasattr(torch.ops, self.test_ns):
                delattr(torch.ops, self.test_ns)
            if hasattr(self, "lib"):
                del self.lib.m
                del self.lib
        finally:
            super().tearDown()

    def get_op(self, name):
        """Return the ``default`` overload of operator ``name`` in ``test_ns``."""
        return getattr(getattr(torch.ops, self.test_ns), name).default

    def get_lib(self):
        """Create a FRAGMENT library for ``test_ns`` and remember it on ``self``.

        The library is stored as ``self.lib`` so that ``tearDown`` can release
        it after the test.
        """
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.lib = lib
        return lib

    @parametrize("mode", ("nothing", "warn"))
    def test_no_grad(self, mode):
        """No warning is raised and the output does not require grad when
        grad mode is disabled or when no input requires grad."""
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            lib.impl("foo", lambda a, b, c: a + b + c, "CPU")
            op = self.get_op("foo")

            # Warnings are promoted to errors so that any warning fails the test.
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                with torch.no_grad():
                    a = torch.randn([], requires_grad=True)
                    b = torch.randn([], requires_grad=True)
                    out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                a = torch.randn([])
                b = torch.randn([])
                out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel(self, mode):
        """Backward through an op whose kernel detaches its inputs warns in
        "warn" mode, raises in "nothing" mode, and leaves ``b.grad`` unset."""
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            op = self.get_op("foo")

            def foo_impl(a, b, c):
                result = a.detach().numpy() + b.detach().numpy() + c
                return torch.tensor(result)

            lib.impl("foo", foo_impl, "CPU")

            # Some inputs requiring grad
            a = torch.randn([], requires_grad=False)
            b = torch.randn([], requires_grad=True)
            out = op(a, b, 1).sum()
            with self._check_ctx(mode, mode_nothing_raises=True):
                out.backward()
            self.assertIsNone(b.grad)

    def _check_ctx(self, mode, *, mode_nothing_raises=False):
        """Return the context manager that encodes the expected outcome for ``mode``.

        Args:
            mode: "warn" or "nothing".
            mode_nothing_raises: when ``mode`` is "nothing", whether the
                wrapped code is expected to raise a ``RuntimeError`` matching
                "does not require grad".

        Returns:
            For "warn", an ``assertWarnsRegex`` context expecting a
            ``UserWarning`` matching "an autograd kernel was not registered".
            For "nothing", an ``assertRaisesRegex`` context if
            ``mode_nothing_raises`` is true, otherwise a ``nullcontext``.

        Raises:
            ValueError: if ``mode`` is neither "warn" nor "nothing".
        """
        if mode == "warn":
            return self.assertWarnsRegex(
                UserWarning, "an autograd kernel was not registered"
            )
        if mode != "nothing":
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            raise ValueError(f"unexpected autograd fallback mode: {mode!r}")
        if mode_nothing_raises:
            return self.assertRaisesRegex(RuntimeError, "does not require grad")
        return contextlib.nullcontext()

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel_inplace(self, mode):
        """In-place kernels that mutate their arguments under ``no_grad``."""
        with autograd_fallback_mode(mode):
            # input modified in-place gets returned as output
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) self, Tensor(b!) y) -> (Tensor(a!), Tensor(b!))")
            op = self.get_op("foo")

            def foo_impl(x, y):
                with torch.no_grad():
                    x.sin_()
                    y.cos_()
                return x, y

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            w = x.clone()
            v = x.clone()
            y0 = w[0]
            y1 = v[1]
            z0, z1 = op(y0, y1)
            for tensor in [w, v, z0, z1, y0, y1]:
                with self._check_ctx(mode):
                    tensor.sum().backward(retain_graph=True)

            # no outputs: we don't do anything. Maybe we should in the future.
            # This is not a common failure mode.
            lib.define("bar(Tensor(a!) self) -> ()")
            op = self.get_op("bar")

            def bar_impl(x):
                with torch.no_grad():
                    x.sin_()

            lib.impl("bar", bar_impl, "CPU")
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                x = torch.randn([], requires_grad=True)
                y = x.clone()
                # The op returns nothing; it is called only for its in-place effect.
                op(y)
                y.backward()
                self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_cpu_return_self(self, mode):
        """Kernels that hand their input straight back as the output."""
        with autograd_fallback_mode(mode):
            # To be clear, none of these situations are OK and will lead
            # to other problems down the line. We're testing them because
            # it is fairly common to actually do these things.
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

                lib.define("bar(Tensor(a!) self) -> Tensor(a!)")
                lib.impl("bar", lambda x: x, "CPU")
                op = self.get_op("bar")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_composite_registered_to_cpu(self, mode):
        """A kernel composed of differentiable torch ops gives ``cos(x)`` as
        the gradient of ``sin(x).sum()``."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x.sin().sum(), "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x)
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_autograd_function_registered_to_cpu(self, mode):
        """A ``torch.autograd.Function`` registered as the CPU kernel supplies
        the gradient through its own ``backward``."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")

                class NumpySin(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x)
                        return torch.tensor(np.sin(x.cpu().numpy()))

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_autograd_function_registered_to_cpu(self, mode):
        """An in-place ``torch.autograd.Function`` kernel applied to a view:
        gradients are checked through both the view output and its base."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor(a!) self) -> Tensor(a!)")

                class NumpySin_(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        # Save a copy: x itself is overwritten in place below.
                        ctx.save_for_backward(x.clone())
                        x_np = x.detach().numpy()
                        np.sin(x_np, out=x_np)
                        ctx.mark_dirty(x)
                        return x

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin_.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                z = x.clone()
                w = z[0]
                y = op(w)

                expected = torch.zeros_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(
                        y, x, torch.ones_like(y), retain_graph=True
                    )
                    self.assertEqual(gx, expected)

                expected = torch.ones_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(z, x, torch.ones_like(z))
                    self.assertEqual(gx, expected)

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_on_tensor_that_does_not_require_grad(self, mode):
        """In-place ops on a tensor that does not require grad: asking for
        gradients afterwards raises "does not require grad"."""
        # We don't do anything special (that is, we don't rebase history).
        # See NOTE [autograd fallback and in-place operations] for why
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                # Correct usage of (a!)
                lib.define("foo(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def foo_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)
                    return x

                lib.impl("foo", foo_impl, "CPU")
                foo = self.get_op("foo")

                # Incorrect usage of (a!): user doesn't return tensor as-is
                lib.define("bar(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def bar_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)
                    return x_d.clone()

                lib.impl("bar", bar_impl, "CPU")
                bar = self.get_op("bar")

                # User mutated input tensor but didn't return it.
                lib.define("baz(Tensor(a!) self, Tensor other) -> ()")

                def baz_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)

                lib.impl("baz", baz_impl, "CPU")
                baz = self.get_op("baz")

                # Test in-place on non-view
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x.clone()
                        op(z, y)
                        torch.autograd.grad(z, y, torch.ones_like(z), allow_unused=True)

                # Test in-place on view
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x[:]
                        op(z, y)
                        torch.autograd.grad(z, x, torch.ones_like(z), allow_unused=True)

    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_leaf(self, mode):
        """A kernel whose second output is a fresh leaf that requires grad."""
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            lib.impl(
                "foo", lambda a: (a.clone(), a.clone().detach().requires_grad_()), "CPU"
            )
            x = torch.randn(3, requires_grad=True)
            y, z = op(x)
            with self._check_ctx(mode):
                z.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_inputs_outputs(self, mode):
        """``None`` passed as an input and returned as an output."""
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                return None, b.clone()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            # NB: PyTorch dispatcher treats "None" as undefined Tensor.
            y, z = op(None, x)
            with self._check_ctx(mode):
                z.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_grads(self, mode):
        """Backward with both op outputs routed through ``UndefinedGrad``."""
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                return a.sin(), b.cos()

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            y = torch.randn(3)
            w, z = op(x, y)
            w = torch._C._functions.UndefinedGrad()(w)
            z = torch._C._functions.UndefinedGrad()(z)
            with self._check_ctx(mode):
                (z + w).sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_base_does_not_require_grad(self, mode):
        """In-place op on a view whose ``_base`` does not require grad."""
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
            op = self.get_op("foo")

            def foo_impl(a):
                with torch.no_grad():
                    return a.zero_()

            lib.impl("foo", foo_impl, "CPU")
            x = torch.randn(3)
            y = x[:]
            y.requires_grad_()
            w = y[:]
            self.assertTrue(w._base is x)

            # Hook should be registered on w, but not w._base
            op(w)
            with self._check_ctx(mode):
                w.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_mix_of_requires_grad_tensors(self, mode):
        """A kernel returning two outputs computed under ``no_grad`` (``x``,
        ``z``) and one computed with grad enabled (``y``)."""
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                with torch.no_grad():
                    x = a.clone()
                    z = b.clone()
                # Computed outside no_grad so that y stays connected to a and b.
                y = a * b
                return x, y, z

            lib.impl("foo", foo_impl, "CPU")
            a = torch.randn(3, requires_grad=True)
            b = torch.randn(3, requires_grad=True)
            x, y, z = op(a, b)

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    x, (a, b), torch.ones_like(x), allow_unused=True, retain_graph=True
                )

            with self._check_ctx(mode, mode_nothing_raises=False):
                torch.autograd.grad(
                    y, (a, b), torch.ones_like(y), allow_unused=True, retain_graph=True
                )

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    z, (a, b), torch.ones_like(z), allow_unused=True, retain_graph=True
                )

    @parametrize("mode", ("nothing", "warn"))
    def test_supports_tensor_lists(self, mode):
        """An op that takes and returns ``Tensor[]``, with outputs computed
        under ``no_grad``."""
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor[] a) -> Tensor[]")
            op = self.get_op("foo")

            def foo_impl(a):
                x, y, z = a
                with torch.no_grad():
                    return x + y + z, x * y * z

            lib.impl("foo", foo_impl, "CPU")
            x = torch.randn(3, requires_grad=True)
            y = torch.randn(1, requires_grad=True)
            z = torch.randn(2, 1, requires_grad=True)
            a, b = op([x, y, z])
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    a,
                    (x, y, z),
                    torch.ones_like(a),
                    allow_unused=True,
                    retain_graph=True,
                )
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(
                    b,
                    (x, y, z),
                    torch.ones_like(b),
                    allow_unused=True,
                    retain_graph=True,
                )
# Instantiate the tests declared with @parametrize on the class.
instantiate_parametrized_tests(TestAutogradFallback)

# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()