| # Owner(s): ["module: custom-operators"] |
| |
| from torch.testing._internal.common_utils import * # noqa: F403 |
| from torch.testing._internal.common_device_type import * # noqa: F403 |
import collections
import itertools
import re
import typing
| |
| import torch._custom_ops as custom_ops |
| |
| import torch.testing._internal.custom_op_db |
| import torch.testing._internal.optests as optests |
| from functorch import make_fx |
| from torch import Tensor |
| from torch._custom_op.impl import custom_op, CustomOp |
| from torch.testing._internal.custom_op_db import custom_op_db |
| from torch.testing._internal.optests.compile_check import operator_compile_check |
| from typing import * # noqa: F403 |
| |
| |
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on Windows")
| class TestCustomOpTesting(TestCase): |
| def setUp(self): |
| self.test_ns = "_test_custom_op" |
| self.libraries = [] |
| |
| def tearDown(self): |
| import torch._custom_op |
| |
| keys = list(torch._custom_op.impl.global_registry.keys()) |
| for key in keys: |
| if not key.startswith(f"{self.test_ns}::"): |
| continue |
| torch._custom_op.impl.global_registry[key]._destroy() |
| if hasattr(torch.ops, self.test_ns): |
            delattr(torch.ops, self.test_ns)
| for lib in self.libraries: |
| del lib.m |
| del self.libraries |
| |
| def ns(self): |
| return getattr(torch.ops, self.test_ns) |
| |
| def lib(self): |
| result = torch.library.Library(self.test_ns, "FRAGMENT") |
| self.libraries.append(result) |
| return result |
| |
| def get_op(self, qualname): |
| ns, name = qualname.split("::") |
| return getattr(getattr(torch.ops, ns), name).default |
| |
| @parametrize("check_gradients", (False, "auto")) |
| @parametrize("dynamic", (True, False)) |
| def test_aot_autograd_check_degenerate_cases( |
| self, device, dynamic, check_gradients |
| ): |
| def simple(x): |
| return x.clone() |
| |
| # Should not raise |
| x = torch.randn(3, device=device) |
| optests.aot_autograd_check( |
| simple, (x,), {}, dynamic=dynamic, check_gradients=check_gradients |
| ) |
| |
| def outputs_dont_require_grad(x): |
| return x.detach() |
| |
| # Should not raise |
| y = torch.randn(3, device=device, requires_grad=True) |
        optests.aot_autograd_check(
            outputs_dont_require_grad,
            (y,),
            {},
            dynamic=dynamic,
            check_gradients=check_gradients,
        )
| |
| def no_outputs(x): |
| return x.detach() |
| |
| # Should not raise |
| x = torch.randn(3, device=device, requires_grad=True) |
| y = torch.randn(3, device=device, requires_grad=False) |
| optests.aot_autograd_check( |
| no_outputs, (x,), {}, dynamic=dynamic, check_gradients=check_gradients |
| ) |
| optests.aot_autograd_check( |
| no_outputs, (y,), {}, dynamic=dynamic, check_gradients=check_gradients |
| ) |
| |
| def test_incorrect_schema_mutation(self, device): |
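        # The schema declares `foo` as functional, but the CPU/CUDA kernels
        # mutate their input in-place; operator_compile_check should report
        # the undeclared mutation.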
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| guard = torch._C._AutoDispatchBelowAutograd() |
| try: |
| return op(x) |
| finally: |
| del guard |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| def foo_impl(x): |
| x.sin_() |
| return x.clone() |
| |
| lib.impl("foo", Foo.apply, "Autograd") |
| lib.impl("foo", foo_impl, "CPU") |
| lib.impl("foo", foo_impl, "CUDA") |
| |
| def f(x): |
| x = x.clone() |
| v = x.view_as(x) |
| y = op(v) |
| return x |
| |
| x = torch.tensor(3.14159 / 3, requires_grad=True, device=device) |
| with self.assertRaisesRegex( |
| RuntimeError, "Argument x is not defined as mutable but was mutated" |
| ): |
| operator_compile_check(f, (x,), {}) |
| |
| def test_incorrect_schema_view(self, device): |
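        # The schema declares no aliasing, but the kernels return a view of
        # the input; the check should flag the undeclared alias.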
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python |
| with torch._C._AutoDispatchBelowAutograd(): |
| with torch._C._ExcludeDispatchKeyGuard( |
| torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView) |
| ): |
| return op(x) |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| def foo_impl(x): |
| return x.view_as(x) |
| |
| def foo_meta(x): |
| return x.view_as(x) |
| |
| lib.impl("foo", Foo.apply, "Autograd") |
| lib.impl("foo", foo_impl, "CPU") |
| lib.impl("foo", foo_meta, "Meta") |
| |
| def f(x): |
| x = x.clone() |
| y = op(x) |
| x.sin_() |
| return y |
| |
| x = torch.tensor(3.14159 / 3, requires_grad=True) |
| with self.assertRaisesRegex( |
| RuntimeError, "Argument x is not defined to alias output but was aliasing" |
| ): |
| operator_compile_check(f, (x,), {}) |
| |
| def test_missing_abstract_impl(self, device): |
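        # No abstract (Meta) impl is registered, so tracing with fake tensors
        # under torch.compile should fail with UnsupportedOperatorException.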
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| with torch._C._AutoDispatchBelowAutograd(): |
| return op(x) |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return 2 * gx |
| |
| def foo_impl(x): |
| return torch.tensor(x.cpu().numpy() ** 2, device=x.device) |
| |
| lib.impl("foo", Foo.apply, "Autograd") |
| lib.impl("foo", foo_impl, "CPU") |
| lib.impl("foo", foo_impl, "CUDA") |
| |
| def f(x): |
| y = op(x) |
| return y.sum(0) |
| |
| x = torch.tensor([0, 1.0], requires_grad=True) |
| with self.assertRaisesRegex( |
| torch._subclasses.fake_tensor.UnsupportedOperatorException, |
| "_test_custom_op.foo.default", |
| ): |
| operator_compile_check(f, (x,), {}) |
| |
| def test_incorrect_abstract_impl(self, device): |
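        # The Meta kernel returns a tensor with an extra dimension compared to
        # the CPU/CUDA kernels, so the traced shapes should not match eager.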
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python |
| guard = torch._C._AutoDispatchBelowAutograd() |
                guard2 = torch._C._ExcludeDispatchKeyGuard(
| torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView) |
| ) |
| try: |
| return op(x) |
| finally: |
| del guard |
| del guard2 |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| def foo_impl(x): |
| return x**2 |
| |
| def foo_meta(x): |
| return x.unsqueeze(1) ** 2 |
| |
| lib.impl("foo", Foo.apply, "Autograd") |
| lib.impl("foo", foo_impl, "CPU") |
| lib.impl("foo", foo_impl, "CUDA") |
| lib.impl("foo", foo_meta, "Meta") |
| |
| def f(x): |
| y = op(x) |
| return y.sum(0) |
| |
| x = torch.tensor([0, 1.0], requires_grad=True) |
| with self.assertRaisesRegex(RuntimeError, "Shapes .* are not equal"): |
| operator_compile_check(f, (x,), {}) |
| |
| def test_missing_functionalization(self, device): |
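        # `foo` mutates its input as declared, but there is no
        # functionalization support for it, so compiling should raise.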
| lib = self.lib() |
| lib.define("foo(Tensor(a!) x) -> Tensor(a!)") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| ctx.mark_dirty(x) |
| with torch._C._AutoDispatchBelowAutograd(): |
| return op(x) |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| def foo_impl(x): |
| return x.sin_() |
| |
| def foo_meta(x): |
| return x |
| |
| lib.impl("foo", Foo.apply, "Autograd") |
| lib.impl("foo", foo_impl, "CPU") |
| lib.impl("foo", foo_impl, "CUDA") |
| lib.impl("foo", foo_meta, "Meta") |
| |
| def f(x): |
| x = x.clone() |
| y = op(x) |
| return y.sum(0) |
| |
| x = torch.tensor([0, 1.0], requires_grad=True) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "Getting these operators to work with functionalization requires some extra work", |
| ): |
| operator_compile_check(f, (x,), {}) |
| |
| def test_autograd_registered_at_backend(self, device): |
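        # The autograd.Function is registered at the CPU/CUDA keys instead of
        # the Autograd key, which should surface as a requires_grad mismatch
        # between eager and compiled outputs.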
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x.clone() |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx * 0.5 |
| |
| lib.impl("foo", Foo.apply, "CPU") |
| lib.impl("foo", Foo.apply, "CUDA") |
| lib.impl("foo", lambda x: x.clone(), "Meta") |
| |
| def f(x): |
| y = op(x) |
| return x + y |
| |
| x = torch.randn([], requires_grad=True) |
| |
| with self.assertRaisesRegex(AssertionError, "mismatched requires_grad-ness"): |
| operator_compile_check(f, (x,), {}) |
| |
| # I'm not sure why this is necessary |
| del lib |
| |
| def test_global_state_mutation(self, device): |
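        # The forward reads and mutates global state, so repeated traces give
        # different results; the check should report the op as not completely
        # traceable.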
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| invoked = 0 |
| |
| @staticmethod |
| def forward(ctx, x): |
| Foo.invoked += 1 |
| return x.clone() * Foo.invoked |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| lib.impl("foo", Foo.apply, "CompositeImplicitAutograd") |
| |
| def f(x): |
| return op(x) |
| |
| x = torch.tensor(3.14159 / 3, requires_grad=True) |
| with self.assertRaisesRegex(AssertionError, "not completely traceable"): |
| operator_compile_check(f, (x,), {}) |
| |
| @ops(custom_op_db, dtypes=OpDTypes.any_one) |
| def test_operator_compile_check_op(self, device, dtype, op): |
| for sample_input in op.sample_inputs( |
| device, dtype, requires_grad=op.supports_autograd |
| ): |
| dynamic_only = op.name in ("NumpyNMSCustomOp", "NumpyNonzeroCustomOp") |
| args = [sample_input.input] + list(sample_input.args) |
| kwargs = sample_input.kwargs |
| operator_compile_check( |
| op.op, |
| args, |
| kwargs, |
| supports_autograd=op.supports_autograd, |
| dynamic_only=dynamic_only, |
| fullgraph=False, # Dynamo graph breaks on CustomOp today |
| ) |
| |
| def test_operator_compile_check_fails_basic(self, device): |
| @custom_op(f"{self.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| ... |
| |
| @foo.impl(["cpu", "cuda"]) |
| def foo_impl(x): |
| return x.sum() |
| |
| x = torch.randn(3, device=device, requires_grad=True) |
| # Triggers the CustomOp autograd NYI error |
| with self.assertRaisesRegex( |
| RuntimeError, "Autograd has not been implemented for operator" |
| ): |
| operator_compile_check( |
| lambda x: self.get_op(f"{self.test_ns}::foo")(x), (x,), {} |
| ) |
| |
| def test_autograd_registration_check_autograd_kernel(self, device): |
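        # Positive case: the autograd.Function is registered at the Autograd
        # key, so autograd_registration_check should pass.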
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| with torch._C._AutoDispatchBelowAutograd(): |
| return op(x) |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| def foo_impl(x): |
| return x.sin() |
| |
| lib.impl("foo", Foo.apply, "Autograd") |
| lib.impl("foo", foo_impl, "CPU") |
| lib.impl("foo", foo_impl, "CUDA") |
| |
| x = torch.randn(3, requires_grad=True, device=device) |
| # Should not raise |
| optests.autograd_registration_check(op, (x,), {}) |
| |
| def test_autograd_registration_check_compositeimplicitautograd(self, device): |
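        # CompositeImplicitAutograd kernels decompose into ops that already
        # support autograd, so this registration should also pass the check.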
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| def foo_impl(x): |
| return x.sin().cos() |
| |
| lib.impl("foo", foo_impl, "CompositeImplicitAutograd") |
| |
| x = torch.randn(3, requires_grad=True, device=device) |
| # Should not raise |
| optests.autograd_registration_check(op, (x,), {}) |
| |
| def test_autograd_registration_check_incorrect_composite(self, device): |
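        # The same composite kernel registered as CompositeExplicitAutograd,
        # with no Autograd kernel, should be reported as incorrectly
        # registered.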
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| def foo_impl(x): |
| return x.sin().cos() |
| |
| lib.impl("foo", foo_impl, "CompositeExplicitAutograd") |
| |
| x = torch.randn(3, requires_grad=True, device=device) |
| with self.assertRaisesRegex(AssertionError, "incorrectly registered"): |
| optests.autograd_registration_check(op, (x,), {}) |
| |
| def test_autograd_registration_check_incorrect(self, device): |
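        # An autograd.Function registered directly at the backend keys (rather
        # than the Autograd key) should also be reported as incorrectly
        # registered.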
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| return torch.sin(x) |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| lib.impl("foo", Foo.apply, "CPU") |
| lib.impl("foo", Foo.apply, "CUDA") |
| |
| x = torch.randn(3, requires_grad=True, device=device) |
| with self.assertRaisesRegex(AssertionError, "incorrectly registered"): |
| optests.autograd_registration_check(op, (x,), {}) |
| |
| def test_assert_raises_regex(self, device): |
| from torch.testing._internal.optests.aot_autograd import assert_raises_regex |
| |
| with assert_raises_regex(RuntimeError, "c"): |
| raise RuntimeError("abcd") |
| with assert_raises_regex(RuntimeError, "c.*"): |
| raise RuntimeError("abcd") |
| with self.assertRaisesRegex(AssertionError, "instead got"): |
| with assert_raises_regex(RuntimeError, "c.*"): |
| raise ValueError("abcd") |
| with self.assertRaisesRegex(AssertionError, "Expected exception"): |
| with assert_raises_regex(RuntimeError, "c.*"): |
| pass |
| with self.assertRaisesRegex(AssertionError, "to match regex"): |
| with assert_raises_regex(RuntimeError, "f"): |
| raise RuntimeError("abcd") |
| |
| |
| class TestCustomOp(TestCase): |
| test_ns = "_test_custom_op" |
| |
| def tearDown(self): |
| import torch._custom_op |
| |
| keys = list(torch._custom_op.impl.global_registry.keys()) |
| for key in keys: |
| if not key.startswith(f"{TestCustomOp.test_ns}::"): |
| continue |
| torch._custom_op.impl.global_registry[key]._destroy() |
| |
| def get_op(self, qualname): |
| ns, name = qualname.split("::") |
| return getattr(getattr(torch.ops, ns), name).default |
| |
| def test_invalid_schemas(self): |
        # Function schema validation goes through torchgen, so this is just a
| # basic test. |
| with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"): |
| custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(") |
| |
| def test_name_must_match(self): |
| with self.assertRaisesRegex(ValueError, "to have name"): |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def baz(x: Tensor) -> Tensor: |
| raise NotImplementedError() |
| |
| def test_unsupported_schemas(self): |
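        # Local stand-in for the function being decorated; each registration
        # below is expected to fail schema validation.
        def foo(x: Tensor) -> Tensor:
            raise NotImplementedError()
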
| with self.assertRaisesRegex(ValueError, "does not support non-functional"): |
| custom_ops.custom_op( |
| f"{TestCustomOp.test_ns}::foo", "(Tensor(a!) x) -> Tensor(a)" |
| )(foo) |
| with self.assertRaisesRegex(ValueError, "does not support view functions"): |
| custom_ops.custom_op( |
| f"{TestCustomOp.test_ns}::foo", "(Tensor(a) x) -> Tensor(a)" |
| )(foo) |
| with self.assertRaisesRegex(ValueError, "no outputs"): |
| custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor x) -> ()")( |
| foo |
| ) |
| with self.assertRaisesRegex(ValueError, "self"): |
| custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor self) -> ()")( |
| foo |
| ) |
| |
| # Tests for the older custom_op API |
| def test_schema_matches_signature(self): |
| with self.assertRaisesRegex(ValueError, "signature to match"): |
| |
| @custom_op(f"{TestCustomOp.test_ns}::blah", "(Tensor y) -> Tensor") |
| def blah(x): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "signature to match"): |
| |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah2", "(Tensor x, *, Tensor y) -> Tensor" |
| ) |
| def blah2(x, y): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "signature to match"): |
| |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah3", |
| "(Tensor x, *, Tensor w, Tensor z) -> Tensor", |
| ) |
| def blah3(x, *, y, z): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "signature to match"): |
| |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah4", |
| "(Tensor x, *, Tensor z, Tensor y) -> Tensor", |
| ) |
| def blah4(x, *, y, z): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "not supported"): |
| |
| @custom_op(f"{TestCustomOp.test_ns}::blah5", "(Tensor x) -> Tensor") |
| def blah5(*args): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "not supported"): |
| |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah6", "(*, Tensor z, Tensor y) -> Tensor" |
| ) |
| def blah6(**kwargs): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "default arguments"): |
| |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah7", "(Tensor x, *, Tensor y) -> Tensor" |
| ) |
| def blah7(x=1, *, y): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "default arguments"): |
| |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah8", "(Tensor x, *, Tensor y) -> Tensor" |
| ) |
| def blah8(x, *, y=1): |
| pass |
| |
| # kwonly-arg works |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah9", "(Tensor x, *, Tensor y) -> Tensor" |
| ) |
| def blah9(x, *, y): |
| pass |
| |
| # Tests for the older custom_op API |
| def test_unsupported_annotation_categories(self): |
| with self.assertRaisesRegex(ValueError, "varargs"): |
| |
| @custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(*args): |
| raise NotImplementedError() |
| |
| del foo |
| |
| with self.assertRaisesRegex(ValueError, "varkwargs"): |
| |
| @custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(**kwargs): |
| raise NotImplementedError() |
| |
| del foo |
| |
| with self.assertRaisesRegex(ValueError, "must have a type annotation"): |
| |
| @custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x): |
| raise NotImplementedError() |
| |
| del foo |
| |
| with self.assertRaisesRegex(ValueError, "default value"): |
| |
| @custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Optional[Tensor] = None): |
| raise NotImplementedError() |
| |
| del foo |
| |
| with self.assertRaisesRegex(ValueError, "default value"): |
| |
| @custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Optional[Tensor] = None): |
| raise NotImplementedError() |
| |
| del foo |
| |
| with self.assertRaisesRegex(ValueError, "either Tensor or a Tuple"): |
| |
| @custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Tensor) -> int: |
| raise NotImplementedError() |
| |
| del foo |
| |
| with self.assertRaisesRegex(ValueError, "either Tensor or a Tuple"): |
| |
| @custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Tensor) -> Tuple[Tensor, int]: |
| raise NotImplementedError() |
| |
| del foo |
| |
| with self.assertRaisesRegex(ValueError, "either Tensor or a Tuple"): |
| |
| @custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Tensor) -> Tuple[Tensor, ...]: |
| raise NotImplementedError() |
| |
| del foo |
| |
| def test_supported_param_types(self): |
| def generate_examples(typ): |
| if typ is int: |
| return [17] |
| if typ is float: |
| return [3.14] |
| if typ is bool: |
| return [True] |
| if typ is str: |
| return ["foo"] |
| if typ is torch.dtype: |
| return [torch.float32] |
| if typ is torch.device: |
| return [torch.device("cpu")] |
| if typ == torch.types.Number: |
| return [2.718] |
| if typ is torch.Tensor: |
| return [torch.tensor(3)] |
| if typ == Optional[torch.types.Number]: |
| return [None, 2.718] |
| origin = typing.get_origin(typ) |
| if origin is Union: |
| args = typing.get_args(typ) |
| assert len(args) == 2 and ( |
| args[0] is type(None) or args[1] is type(None) |
| ) |
| elt = args[0] if args[1] is type(None) else args[1] |
| return generate_examples(elt) + [None] |
| if origin is collections.abc.Sequence: |
| args = typing.get_args(typ) |
| assert len(args) == 1 |
| examples = generate_examples(args[0]) |
                return list(itertools.product(examples, examples))
| raise AssertionError(f"unsupported param type {typ}") |
| |
| for typ in torch._custom_op.impl.SUPPORTED_PARAM_TYPES: |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Tensor, y: typ) -> Tensor: |
| raise NotImplementedError() |
| |
| yeet = None |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=["cpu"]) |
| def foo_cpu(x, y): |
| nonlocal yeet |
| yeet = y |
| return x.clone() |
| |
| try: |
| for example in generate_examples(typ): |
| op = self.get_op(f"{self.test_ns}::foo") |
| op(torch.randn([]), example) |
| self.assertEqual(yeet, example, msg=f"{typ} {example}") |
| yeet = None |
| finally: |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| |
| def test_sequences(self): |
| # Sequence[int] gets automagically turned into int[] in the schema. |
| # This test checks that we actually do support arbitrary sequence types. |
| class MySequence(collections.abc.Sequence): |
| def __init__(self): |
| self._container = [1, 2, 3] |
| |
| def __getitem__(self, idx): |
| return self._container[idx] |
| |
| def __len__(self): |
| return len(self._container) |
| |
| @custom_ops.custom_op(f"{self.test_ns}::foo") |
| def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| called = 0 |
| |
| @custom_ops.impl(f"{self.test_ns}::foo", device_types="cpu") |
| def foo_cpu(x, sizes): |
| nonlocal called |
| called += 1 |
| # Dispatcher will normalize the sequence type into a List |
| self.assertEqual(sizes, [1, 2, 3]) |
| return x.clone() |
| |
| x = torch.randn([]) |
| seq = MySequence() |
| op = self.get_op(f"{self.test_ns}::foo") |
| op(x, seq) |
| self.assertEqual(called, 1) |
| |
| def test_unsupported_param_types(self): |
| # Not comprehensive (it doesn't need to be), just a check that our mechanism works |
| with self.assertRaisesRegex(ValueError, "unsupported type"): |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Tensor, y: List[Optional[int]]) -> Tensor: |
| raise NotImplementedError() |
| |
| del foo |
| |
| with self.assertRaisesRegex(ValueError, "unsupported type"): |
| # int[N] in Dispatcher is a bit wild, so we don't try to support it. |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Tensor, y: Tuple[int, int]) -> Tensor: |
| raise NotImplementedError() |
| |
| del foo |
| |
| with self.assertRaisesRegex(ValueError, "unsupported type"): |
            # We could theoretically support this, but the syntax for supporting
| # int[] is Sequence[int] |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Tensor, y: List[int]) -> Tensor: |
| raise NotImplementedError() |
| |
| del foo |
| |
| with self.assertRaisesRegex(ValueError, "unsupported type"): |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Tensor, y: Callable) -> Tensor: |
| raise NotImplementedError() |
| |
| del foo |
| |
| def test_supported_schemas(self): |
| # All of these should already be tested by PyTorch codegen |
| # (we share the same mechanism), but here's a sanity check. |
| schemas = [ |
| "(Tensor x) -> Tensor", |
| "(Tensor x) -> Tensor y", |
| "(Tensor[] x) -> Tensor y", |
| "(Tensor x) -> (Tensor, Tensor)", |
| "(Tensor x) -> (Tensor y, Tensor z)", |
| "(Tensor x) -> (Tensor y, Tensor z)", |
| ] |
| other_schemas = [ |
| "(Tensor x, Tensor w) -> (Tensor y, Tensor z)", |
| "(Tensor x, Tensor w) -> (Tensor, Tensor)", |
| "(Tensor x, Tensor w) -> Tensor", |
| "(Tensor? x, Tensor w) -> Tensor", |
| "(Tensor? x, Tensor[] w) -> Tensor", |
| "(Tensor x, int[] w) -> Tensor", |
| "(Tensor x, SymInt[] w) -> Tensor", |
| "(Tensor x, Scalar w) -> Tensor", |
| "(Tensor x, float w) -> Tensor", |
| "(Tensor x, float? w) -> Tensor", |
| "(Tensor x, bool[] w) -> Tensor", |
| ] |
| |
| for schema in schemas: |
| custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", schema) |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| for schema in other_schemas: |
| custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar", schema) |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::bar") |
| |
| def test_reserved_ns(self): |
| from torch._custom_op.impl import RESERVED_NS |
| |
| for ns in RESERVED_NS: |
| with self.assertRaisesRegex(ValueError, "is a reserved namespace"): |
| custom_ops.custom_op(f"{ns}::foo", "(Tensor x) -> Tensor") |
| |
| with self.assertRaisesRegex(ValueError, "is a reserved namespace"): |
| |
| @custom_ops.custom_op(f"{ns}::foo2") |
| def foo2(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| def test_private_ctor(self): |
| with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"): |
| CustomOp(None, None, None, None, None) |
| |
| def test_lifetime(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| custom_op = torch._custom_op.impl.get_op(f"{TestCustomOp.test_ns}::foo") |
| |
| # We can't define an op multiple times, |
| with self.assertRaisesRegex(RuntimeError, "multiple times"): |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 |
| raise NotImplementedError() |
| |
| # Unless we delete the original op. |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| |
| # Smoke test |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 |
| raise NotImplementedError() |
| |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| |
| def test_autograd_notimplemented(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 |
| raise NotImplementedError() |
| |
| x = torch.randn(3, requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"): |
| op(x) |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| del foo |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Sequence[torch.Tensor]) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| x = torch.randn(3, requires_grad=True) |
| y = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"): |
| op([y, x]) |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| del foo |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| x = torch.randn(3, requires_grad=True) |
| y = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"): |
| op(y, x) |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| |
| def test_autograd_notimplemented_gradmode(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x, y): |
| return x * y |
| |
| x = torch.randn(3, requires_grad=True) |
| y = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| with torch.no_grad(): |
| # Shouldn't raise, because we are in no_grad |
| op(y, x) |
| |
| def test_impl_cpu(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu") |
| def foo_cpu(x): |
| return x.sin() |
| |
| x = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| result = op(x) |
| self.assertEqual(result, foo_cpu(x)) |
| |
| def test_impl_invalid_devices(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| def foo_impl(x): |
| return x.sin() |
| |
| from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY |
| |
| for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys(): |
| # Smoke test: should not raise error |
| custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=device_type)( |
| foo_impl |
| ) |
| |
| # Not supported by this API: we can either support them in the future |
| # or provide some other CustomOp.def_* function. This depends on how |
| # common the use cases are. |
| for invalid_type in ["hip", "xla", "mkldnn", ["cpu", "hip"]]: |
| with self.assertRaisesRegex(ValueError, "we only support device_type"): |
| custom_ops.impl( |
| f"{TestCustomOp.test_ns}::foo", device_types=invalid_type |
| )(foo_impl) |
| |
| def test_backward_partially_registered(self): |
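        # Registering a backward formula without a matching save_for_backward
        # should produce a clear error once backward() is invoked.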
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x): |
| return x.sin() |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return grad * saved.cos() |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| with self.assertRaisesRegex( |
| RuntimeError, "unable to find a 'save_for_backward'" |
| ): |
| y = op(x) |
| y.backward() |
| |
| def test_save_for_backward_inputs_are_namedtuple(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x): |
| return x.sin() |
| |
| hit = 0 |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| nonlocal hit |
| hit += 1 |
| self.assertTrue(isinstance(inputs, tuple)) |
| self.assertEqual(list(inputs._asdict().keys()), ["x"]) |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"x": grad * saved.cos()} |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x) |
| self.assertEqual(hit, 1) |
| y.backward() |
| self.assertEqual(hit, 1) |
| |
| def test_backward_returns_dict(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return grad * saved.cos() |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x) |
| with self.assertRaisesRegex(RuntimeError, "to be a dict"): |
| y.backward() |
| |
| def test_backward_dict_invalid_keys(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"x": grad * saved.cos(), "y": None} |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x) |
| with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"): |
| y.backward() |
| |
| def test_backward_dict_grad_for_nontensor(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, dim: int) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x, dim): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"x": grad * saved.cos(), "dim": None} |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x, 32) |
| with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"): |
| y.backward() |
| |
| def test_backward_dict_requires_keys_for_input_tensors(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x, y): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"x": grad * saved.cos()} |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x, x) |
| with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"): |
| y.backward() |
| |
| def test_backward_dict_requires_keys_for_input_optional_tensors(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x, y): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"x": grad * saved.cos()} |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x, None) |
| with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"): |
| y.backward() |
| |
| def test_backward_grads_are_tensor_or_none(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"x": (grad * saved.cos(),)} |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x) |
| with self.assertRaisesRegex(RuntimeError, "either None or a Tensor"): |
| y.backward() |
| |
| def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(xs): |
| return xs[0].sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.xs[0] |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"xs": [grad * saved.cos(), None]} |
| |
| xs = [torch.randn([], requires_grad=True) for _ in range(3)] |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(xs) |
| with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"): |
| y.backward() |
| |
| def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(xs): |
| return xs[0].sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.xs[0] |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"xs": [grad * saved.cos(), None, (None,)]} |
| |
| xs = [torch.randn([], requires_grad=True) for _ in range(3)] |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(xs) |
| with self.assertRaisesRegex(RuntimeError, "None or Tensor"): |
| y.backward() |
| |
| def test_backward_tensorlist_input_requires_list_grads(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(xs): |
| return xs[0].sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.xs[0] |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"xs": None} |
| |
| xs = [torch.randn([], requires_grad=True) for _ in range(3)] |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(xs) |
| with self.assertRaisesRegex(RuntimeError, "list of gradients"): |
| y.backward() |
| |
| def test_backward_output_differentiability_type(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| with self.assertRaisesRegex(RuntimeError, "output_differentiability"): |
| |
| @custom_ops.impl_backward( |
| f"{TestCustomOp.test_ns}::foo", output_differentiability=True |
| ) |
| def foo_backward(ctx, saved, grad): |
| return {"xs": None} |
| |
| def test_backward_output_differentiability_numel(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: |
| raise NotImplementedError() |
| |
| with self.assertRaisesRegex(RuntimeError, "output_differentiability"): |
| |
| @custom_ops.impl_backward( |
| f"{TestCustomOp.test_ns}::foo", output_differentiability=[True] |
| ) |
| def foo_backward(ctx, saved, grad): |
| return {"xs": None} |
| |
| @unittest.skipIf(not TEST_CUDA, "requires CUDA") |
| def test_impl_separate(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu") |
| def foo_cpu(x): |
| return x.sin() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cuda") |
| def foo_cuda(x): |
| return x.cos() |
| |
| x = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| result = op(x) |
| self.assertEqual(result, foo_cpu(x)) |
| |
| x_cuda = x.cuda() |
| op = self.get_op(f"{self.test_ns}::foo") |
| result = op(x_cuda) |
| self.assertEqual(result, foo_cuda(x_cuda)) |
| |
| @unittest.skipIf(not TEST_CUDA, "requires CUDA") |
| def test_impl_multiple(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x): |
| return x.cos() |
| |
| op = self.get_op(f"{self.test_ns}::foo") |
| x = torch.randn(3) |
| result = op(x) |
| self.assertEqual(result, foo_impl(x)) |
| |
| x_cuda = x.cuda() |
| result = op(x_cuda) |
| self.assertEqual(result, foo_impl(x_cuda)) |
| |
| def test_impl_meta(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, dim: int) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl_abstract(f"{TestCustomOp.test_ns}::foo") |
| def foo_meta(x, dim): |
| output_shape = list(x.shape) |
| del output_shape[dim] |
| return x.new_empty(output_shape) |
| |
| x = torch.randn(2, 3, device="meta") |
| op = self.get_op(f"{self.test_ns}::foo") |
| result = op(x, 1) |
| self.assertEqual(result.shape, foo_meta(x, 1).shape) |
| |
| def test_duplicate_impl(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, dim: int) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl_abstract(f"{TestCustomOp.test_ns}::foo") |
| def foo_meta(x, dim): |
| output_shape = list(x.shape) |
| del output_shape[dim] |
| return x.new_empty(output_shape) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, r"already has a abstract impl.*at .*test_custom_ops.py:\d+" |
| ): |
| |
| @custom_ops.impl_abstract(f"{TestCustomOp.test_ns}::foo") |
| def foo_meta2(x, dim): |
| output_shape = list(x.shape) |
| del output_shape[dim] |
| return x.new_empty(output_shape) |
| |
| def test_new_data_dependent_symint(self): |
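        # Abstract impls may allocate new data-dependent SymInts via the
        # context object; this exercises the validation of the min/max bounds.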
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl_abstract(f"{TestCustomOp.test_ns}::foo") |
| def foo_meta(x): |
| ctx = torch._custom_op.impl.get_ctx() |
| with self.assertRaisesRegex(ValueError, "greater than or equal to 2"): |
| ctx.create_unbacked_symint(min=1) |
| with self.assertRaisesRegex(ValueError, "greater than or equal to 2"): |
| ctx.create_unbacked_symint(min=-1) |
| with self.assertRaisesRegex(ValueError, "SymInt"): |
| ctx.create_unbacked_symint(max=x.numel()) |
| return torch.clone(x) |
| |
| x = torch.randn(2, 3, device="cpu") |
| op = self.get_op(f"{self.test_ns}::foo") |
| make_fx(op, tracing_mode="symbolic")(x) |
| |
| def test_meta_for_data_dependent_shape_operation(self): |
| x = torch.randn(10, device="meta") |
| with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"): |
| torch.ops._torch_testing.numpy_nonzero(x) |
| |
| def test_basic_make_fx(self): |
| # More serious tests are in our CustomOp opinfo db, |
| # this one is just a sanity check. |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| @custom_ops.impl_abstract(f"{TestCustomOp.test_ns}::foo") |
| def foo_meta(x): |
| return x.sum() |
| |
| x = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| gm = make_fx(op, tracing_mode="symbolic")(x) |
| self.assertTrue(f"{TestCustomOp.test_ns}.foo" in gm.code) |
| |
| def test_not_implemented_error(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| x = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"): |
| op(x) |
| |
| x = torch.randn(3, device="meta") |
| with self.assertRaisesRegex(NotImplementedError, "abstract impl registered"): |
| op(x) |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar") |
| def bar(sizes: Sequence[int]) -> torch.Tensor: |
| raise NotImplementedError() |
| |
| op = self.get_op(f"{self.test_ns}::bar") |
| with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"): |
| op((1, 2, 3)) |
| |
| def test_abstract_registration_location(self): |
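        # Each registered impl records the file:line it was registered at; the
        # abstract impl for numpy_nonzero lives in custom_op_db.py.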
| custom_op = torch._custom_op.impl._find_custom_op( |
| "_torch_testing::numpy_nonzero" |
| ) |
| loc = custom_op._get_impl("abstract").location |
| matches = re.match(r".*custom_op_db.py:\d+", loc) |
| self.assertIsNotNone(matches) |
| |
| def test_data_dependent_basic(self): |
| def f(x): |
| return torch.ops._torch_testing.numpy_nonzero(x) |
| |
| x = torch.randn(5, 5) |
| gm = make_fx(f, tracing_mode="symbolic")(x) |
| self.assertTrue("nonzero" in gm.code) |
| |
| def test_data_dependent_fake_tracing(self): |
| def f(x): |
| return torch.ops._torch_testing.numpy_nonzero(x) |
| |
| x = torch.randn(5, 5) |
| with self.assertRaises( |
| torch._subclasses.fake_tensor.DynamicOutputShapeException |
| ): |
| make_fx(f, tracing_mode="fake")(x) |
| |
| def test_symints(self): |
| def f(x): |
| return torch.ops._torch_testing.numpy_view_copy(x, x.shape) |
| |
| x = torch.randn(2, 3, 4) |
| gm = make_fx(f, tracing_mode="symbolic")(x) |
| result = gm(x) |
| self.assertEqual(result, f(x)) |
| self.assertExpectedInline( |
| gm.code.strip(), |
| """\ |
| def forward(self, x_1): |
| sym_size = torch.ops.aten.sym_size(x_1, 0) |
| sym_size_1 = torch.ops.aten.sym_size(x_1, 1) |
| sym_size_2 = torch.ops.aten.sym_size(x_1, 2) |
| numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size, sym_size_1, sym_size_2]); x_1 = sym_size = sym_size_1 = sym_size_2 = None |
| return numpy_view_copy""", # noqa: B950 |
| ) |
| |
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on Windows")
| def test_data_dependent_compile(self): |
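        # A data-dependent custom op should produce exactly one graph break
        # under torch.compile rather than a hard error.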
| import torch._dynamo.testing |
| from torch._dynamo.utils import counters |
| |
| counters.clear() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch.compile(backend=cnt) |
| def f(x): |
| return torch.ops._torch_testing.numpy_nonzero(x.clone()).clone() |
| |
| f(torch.randn(10)) |
| |
| self.assertEqual( |
| dict(counters["graph_break"]), |
| {"dynamic shape operator: _torch_testing.numpy_nonzero.default": 1}, |
| ) |
| |
| # pre-existing problem: torch.compile(dynamic=True) will, by default, |
| # graph break on data-dependent operations. Eventually we'll make it so |
| # that it never graph breaks on data-dependent operations. |
| @unittest.expectedFailure |
| def test_data_dependent_nms_dynamic_compile(self): |
| import torch._dynamo.testing |
| from torch._dynamo.utils import counters |
| |
| counters.clear() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch.compile(backend=cnt, dynamic=True) |
| def f(x, s, i): |
| return torch.ops._torch_testing.numpy_nms(x.clone(), s, i).clone() |
| |
| f(torch.randn(20, 4), torch.randn(20), 0.1) |
| |
| self.assertEqual(len(counters["graph_break"]), 0) |
| |
| |
| only_for = ("cpu", "cuda") |
| instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for) |
| |
| if __name__ == "__main__": |
| run_tests() |