# Owner(s): ["module: custom-operators"]
import unittest

import torch
from torch.testing._internal.common_utils import *  # noqa: F403
from torch.testing._internal.common_device_type import *  # noqa: F403
from torch.testing._internal.optests.compile_check import operator_compile_check
from torch.testing._internal.custom_op_db import custom_op_db
from torch._custom_op.impl import custom_op


@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
class TestCustomOpTesting(TestCase):
    def setUp(self):
        self.test_ns = '_test_custom_op'
        self.libraries = []

    def tearDown(self):
        import torch._custom_op
        keys = list(torch._custom_op.impl.global_registry.keys())
        for key in keys:
            if not key.startswith(f'{self.test_ns}::'):
                continue
            torch._custom_op.impl.global_registry[key]._destroy()
        if hasattr(torch.ops, self.test_ns):
            del torch.ops._test_custom_op
        for lib in self.libraries:
            del lib.m
        del self.libraries

    def ns(self):
        return getattr(torch.ops, self.test_ns)

    def lib(self):
        result = torch.library.Library(self.test_ns, 'FRAGMENT')
        self.libraries.append(result)
        return result

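    # The schema "foo(Tensor x) -> Tensor" declares x as non-mutable, but the
    # CPU/CUDA kernels mutate it in-place; operator_compile_check should flag this.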
    def test_incorrect_schema_mutation(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                guard = torch._C._AutoDispatchBelowAutograd()
                try:
                    return op(x)
                finally:
                    del guard

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            x.sin_()
            return x.clone()

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        def f(x):
            x = x.clone()
            v = x.view_as(x)
            y = op(v)
            return x

        x = torch.tensor(3.14159 / 3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
                RuntimeError,
                'Argument x is not defined as mutable but was mutated'):
            operator_compile_check(f, (x,), {})

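    # The schema declares no aliasing, but the CPU kernel returns a view of its
    # input; operator_compile_check should detect the undeclared alias.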
    def test_incorrect_schema_view(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                with torch._C._AutoDispatchBelowAutograd():
                    with torch._C._ExcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)):
                        return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.view_as(x)

        def foo_meta(x):
            return x.view_as(x)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_meta, "Meta")

        def f(x):
            x = x.clone()
            y = op(x)
            x.sin_()
            return y

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
                RuntimeError,
                'Argument x is not defined to alias output but was aliasing'):
            operator_compile_check(f, (x,), {})

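    # No Meta/abstract impl is registered, so tracing with FakeTensor (which
    # operator_compile_check does) should raise UnsupportedOperatorException.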
    def test_missing_abstract_impl(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return 2 * gx

        def foo_impl(x):
            return torch.tensor(x.cpu().numpy() ** 2, device=x.device)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        def f(x):
            y = op(x)
            return y.sum(0)

        x = torch.tensor([0, 1.], requires_grad=True)
        with self.assertRaisesRegex(
                torch._subclasses.fake_tensor.UnsupportedOperatorException,
                '_test_custom_op.foo.default'):
            operator_compile_check(f, (x,), {})

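    # The Meta impl returns a tensor with an extra dimension compared to the
    # CPU/CUDA kernels, so the shape cross-check should fail.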
    def test_incorrect_abstract_impl(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                guard = torch._C._AutoDispatchBelowAutograd()
                guard2 = torch._C._ExcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView))
                try:
                    return op(x)
                finally:
                    del guard
                    del guard2

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x ** 2

        def foo_meta(x):
            return x.unsqueeze(1) ** 2

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        def f(x):
            y = op(x)
            return y.sum(0)

        x = torch.tensor([0, 1.], requires_grad=True)
        with self.assertRaisesRegex(
                RuntimeError,
                'Shapes .* are not equal'):
            operator_compile_check(f, (x,), {})

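    # The op mutates its input (Tensor(a!)) but provides nothing for
    # functionalization to use, so compilation should fail with the
    # "extra work" error.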
    def test_missing_functionalization(self, device):
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_dirty(x)
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.sin_()

        def foo_meta(x):
            return x

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        def f(x):
            x = x.clone()
            y = op(x)
            return y.sum(0)

        x = torch.tensor([0, 1.], requires_grad=True)
        with self.assertRaisesRegex(
                RuntimeError,
                'Getting these operators to work with functionalization requires some extra work'):
            operator_compile_check(f, (x,), {})

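    # The autograd.Function is registered at the CPU/CUDA keys instead of the
    # Autograd key, so eager and compiled runs disagree on requires_grad; the
    # check should report mismatched requires_grad-ness.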
    def test_autograd_registered_at_backend(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gx):
                return gx * 0.5

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")
        lib.impl("foo", lambda x: x.clone(), "Meta")

        def f(x):
            y = op(x)
            return x + y

        x = torch.randn([], requires_grad=True)
        with self.assertRaisesRegex(AssertionError, 'mismatched requires_grad-ness'):
            operator_compile_check(f, (x,), {})

        # I'm not sure why this is necessary
        del lib

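    # The forward depends on a mutable class attribute, so repeated runs give
    # different results; operator_compile_check should flag the op as not
    # completely traceable.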
    def test_global_state_mutation(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            invoked = 0

            @staticmethod
            def forward(ctx, x):
                Foo.invoked += 1
                return x.clone() * Foo.invoked

            @staticmethod
            def backward(ctx, gx):
                return gx

        lib.impl("foo", Foo.apply, "CompositeImplicitAutograd")

        def f(x):
            return op(x)

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(AssertionError, "not completely traceable"):
            operator_compile_check(f, (x,), {})

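    # Positive check: every op in custom_op_db should pass
    # operator_compile_check on its sample inputs.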
    @ops(custom_op_db, dtypes=OpDTypes.any_one)
    def test_operator_compile_check_op(self, device, dtype, op):
        for sample_input in op.sample_inputs(device, dtype, requires_grad=op.supports_autograd):
            dynamic_only = op.name in ("NumpyNMSCustomOp", "NumpyNonzeroCustomOp")
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            operator_compile_check(
                op.op, args, kwargs,
                supports_autograd=op.supports_autograd,
                dynamic_only=dynamic_only,
                fullgraph=False,  # Dynamo graph breaks on CustomOp today
            )

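    # A CustomOp with backend impls but no autograd support should hit the
    # "Autograd has not been implemented" error when the input requires grad.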
    def test_operator_compile_check_fails_basic(self, device):
        @custom_op(f'{self.test_ns}::foo')
        def foo(x: torch.Tensor) -> torch.Tensor:
            ...

        @foo.impl(['cpu', 'cuda'])
        def foo_impl(x):
            return x.sum()

        x = torch.randn(3, device=device, requires_grad=True)
        # Triggers the CustomOp autograd NYI error
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented for operator"):
            operator_compile_check(lambda x: foo(x), (x,), {})

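    # Exercise the assert_raises_regex helper itself: matching messages pass,
    # while wrong exception types, missing exceptions, and non-matching
    # regexes raise AssertionError.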
    def test_assert_raises_regex(self, device):
        from torch.testing._internal.optests.aot_autograd import assert_raises_regex
        with assert_raises_regex(RuntimeError, 'c'):
            raise RuntimeError("abcd")
        with assert_raises_regex(RuntimeError, 'c.*'):
            raise RuntimeError("abcd")
        with self.assertRaisesRegex(AssertionError, 'instead got'):
            with assert_raises_regex(RuntimeError, 'c.*'):
                raise ValueError("abcd")
        with self.assertRaisesRegex(AssertionError, 'Expected exception'):
            with assert_raises_regex(RuntimeError, 'c.*'):
                pass
        with self.assertRaisesRegex(AssertionError, 'to match regex'):
            with assert_raises_regex(RuntimeError, 'f'):
                raise RuntimeError("abcd")


only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)

if __name__ == '__main__':
    run_tests()