| # Owner(s): ["module: custom-operators"] |
| |
| import collections |
| import itertools |
| import os |
| import re |
| import subprocess |
| import sys |
| import typing |
| import unittest |
| from typing import * # noqa: F403 |
| |
| import numpy as np |
| |
| import torch._custom_ops as custom_ops |
| import torch.testing._internal.optests as optests |
| import torch.utils._pytree as pytree |
| import torch.utils.cpp_extension |
| from functorch import make_fx |
| from torch import Tensor |
| from torch._custom_op.impl import CustomOp, infer_schema |
| from torch._library.infer_schema import tuple_to_list |
| from torch._utils_internal import get_file_path_2 |
| from torch.testing._internal import custom_op_db |
| from torch.testing._internal.common_cuda import TEST_CUDA |
| from torch.testing._internal.common_device_type import ( |
| instantiate_device_type_tests, |
| OpDTypes, |
| ops, |
| ) |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| IS_WINDOWS, |
| parametrize, |
| run_tests, |
| skipIfTorchDynamo, |
| subtest, |
| TestCase, |
| ) |
| from torch.testing._internal.custom_op_db import numpy_nonzero |
| |
| |
| # Shadowed by `torch.testing._internal.common_utils.custom_op` |
| from torch._custom_op.impl import custom_op # usort: skip |
| |
| |
| def requires_compile(fun): |
    fun = unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on Windows")(fun)
| return fun |
| |
| |
| class CustomOpTestCaseBase(TestCase): |
| test_ns = "_test_custom_op" |
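    # Base fixture: tests register ops under `test_ns`; tearDown destroys any
    # CustomOp entries in the global registry and any Library fragments
    # created via self.lib() so each test starts from a clean dispatcher state.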
| |
| def setUp(self): |
| super().setUp() |
| self.libraries = [] |
| |
| def tearDown(self): |
| super().tearDown() |
| import torch._custom_op |
| |
| keys = list(torch._custom_op.impl.global_registry.keys()) |
| for key in keys: |
| if not key.startswith(f"{self.test_ns}::"): |
| continue |
| torch._custom_op.impl.global_registry[key]._destroy() |
| if hasattr(torch.ops, self.test_ns): |
| delattr(torch.ops, self.test_ns) |
| for lib in self.libraries: |
| lib._destroy() |
| del self.libraries |
| |
| def ns(self): |
| return getattr(torch.ops, self.test_ns) |
| |
| def lib(self): |
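        # Create a FRAGMENT library for the test namespace and track it so
        # tearDown can destroy it.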
| result = torch.library.Library(self.test_ns, "FRAGMENT") # noqa: TOR901 |
| self.libraries.append(result) |
| return result |
| |
| def get_op(self, qualname): |
| return torch._custom_op.impl.get_op(qualname) |
| |
| |
| @requires_compile |
| class TestCustomOpTesting(CustomOpTestCaseBase): |
| @parametrize("check_gradients", (False, "auto")) |
| @parametrize("dynamic", (True, False)) |
| def test_aot_autograd_check_degenerate_cases( |
| self, device, dynamic, check_gradients |
| ): |
| def simple(x): |
| return x.clone() |
| |
| # Should not raise |
| x = torch.randn(3, device=device) |
| optests.aot_autograd_check( |
| simple, (x,), {}, dynamic=dynamic, check_gradients=check_gradients |
| ) |
| |
| def outputs_dont_require_grad(x): |
| return x.detach() |
| |
| # Should not raise |
| y = torch.randn(3, device=device, requires_grad=True) |
        optests.aot_autograd_check(
            outputs_dont_require_grad,
            (y,),
            {},
            dynamic=dynamic,
            check_gradients=check_gradients,
        )
| |
| def no_outputs(x): |
| return x.detach() |
| |
| # Should not raise |
| x = torch.randn(3, device=device, requires_grad=True) |
| y = torch.randn(3, device=device, requires_grad=False) |
| optests.aot_autograd_check( |
| no_outputs, (x,), {}, dynamic=dynamic, check_gradients=check_gradients |
| ) |
| optests.aot_autograd_check( |
| no_outputs, (y,), {}, dynamic=dynamic, check_gradients=check_gradients |
| ) |
| |
| def test_incorrect_schema_mutation(self, device): |
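        # The CPU/CUDA kernel mutates its input in-place even though the schema
        # declares `foo` as functional; opcheck should flag the undeclared mutation.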
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| guard = torch._C._AutoDispatchBelowAutograd() |
| try: |
| return op(x) |
| finally: |
| del guard |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| def foo_impl(x): |
| x.sin_() |
| return x.clone() |
| |
| lib.impl("foo", Foo.apply, "Autograd") |
| lib.impl("foo", foo_impl, "CPU") |
| lib.impl("foo", foo_impl, "CUDA") |
| |
| x = torch.tensor(3.14159 / 3, requires_grad=True, device=device) |
| with self.assertRaisesRegex( |
| optests.OpCheckError, "Argument x is not defined as mutable but was mutated" |
| ): |
| torch.library.opcheck(op, (x,), {}) |
| |
| def test_incorrect_schema_view(self, device): |
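        # The kernel returns a view of its input, but the schema carries no alias
        # annotation; opcheck should flag the undeclared aliasing.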
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python |
| with torch._C._AutoDispatchBelowAutograd(): |
| with torch._C._ExcludeDispatchKeyGuard( |
| torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView) |
| ): |
| return op(x) |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| def foo_impl(x): |
| return x.view_as(x) |
| |
| def foo_meta(x): |
| return x.view_as(x) |
| |
| lib.impl("foo", Foo.apply, "Autograd") |
| lib.impl("foo", foo_impl, "CPU") |
| lib.impl("foo", foo_meta, "Meta") |
| |
| x = torch.tensor(3.14159 / 3, requires_grad=True) |
| with self.assertRaisesRegex( |
| optests.OpCheckError, |
| "Argument x is not defined to alias output but was aliasing", |
| ): |
| torch.library.opcheck(op, (x,), {}) |
| |
| def test_missing_abstract_impl(self, device): |
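        # `foo` has CPU/CUDA kernels that round-trip through NumPy but no
        # abstract/fake impl, so opcheck's FakeTensor checks should fail.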
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| with torch._C._AutoDispatchBelowAutograd(): |
| return op(x) |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return 2 * gx |
| |
| def foo_impl(x): |
| return torch.tensor(x.cpu().numpy() ** 2, device=x.device) |
| |
| lib.impl("foo", Foo.apply, "Autograd") |
| lib.impl("foo", foo_impl, "CPU") |
| lib.impl("foo", foo_impl, "CUDA") |
| |
| x = torch.tensor([0, 1.0], requires_grad=True) |
| with self.assertRaisesRegex( |
| optests.OpCheckError, |
| "_test_custom_op.foo.default", |
| ): |
| torch.library.opcheck(op, (x,), {}) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_incorrect_abstract_impl(self, device): |
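        # The Meta kernel produces a different shape than the real kernels, so
        # opcheck should report the shape mismatch.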
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python |
| guard = torch._C._AutoDispatchBelowAutograd() |
| guard2 = torch._C.ExcludeDispatchKeyGuard( |
| torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView) |
| ) |
| try: |
| return op(x) |
| finally: |
| del guard |
| del guard2 |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| def foo_impl(x): |
| return x**2 |
| |
| def foo_meta(x): |
| return x.unsqueeze(1) ** 2 |
| |
| lib.impl("foo", Foo.apply, "Autograd") |
| lib.impl("foo", foo_impl, "CPU") |
| lib.impl("foo", foo_impl, "CUDA") |
| lib.impl("foo", foo_meta, "Meta") |
| |
| x = torch.tensor([0, 1.0], requires_grad=True) |
| with self.assertRaisesRegex(optests.OpCheckError, "Shapes .* are not equal"): |
| torch.library.opcheck(op, (x,), {}) |
| |
| def test_missing_functionalization(self, device): |
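        # The schema declares an in-place op (alias-annotated output); the
        # functionalization check in opcheck does not support that and should error.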
| lib = self.lib() |
| lib.define("foo(Tensor(a!) x) -> Tensor(a!)") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| ctx.mark_dirty(x) |
| with torch._C._AutoDispatchBelowAutograd(): |
| return op(x) |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| def foo_impl(x): |
| return x.sin_() |
| |
| def foo_meta(x): |
| return x |
| |
| lib.impl("foo", Foo.apply, "Autograd") |
| lib.impl("foo", foo_impl, "CPU") |
| lib.impl("foo", foo_impl, "CUDA") |
| lib.impl("foo", foo_meta, "Meta") |
| |
| x = torch.tensor([0, 1.0]) |
| y = x.clone() |
| with self.assertRaisesRegex( |
| optests.OpCheckError, |
| "We only support functionalizing operators whose outputs do not have alias annotations", |
| ): |
| torch.library.opcheck(op, (y,), {}) |
| |
| def test_autograd_registered_at_backend(self, device): |
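        # The autograd.Function is registered at the CPU/CUDA backend keys instead
        # of the Autograd key; opcheck should report the missing autograd kernel.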
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x.clone() |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx * 0.5 |
| |
| lib.impl("foo", Foo.apply, "CPU") |
| lib.impl("foo", Foo.apply, "CUDA") |
| lib.impl("foo", lambda x: x.clone(), "Meta") |
| |
| x = torch.randn([], requires_grad=True) |
| |
| with self.assertRaisesRegex( |
| torch.testing._internal.optests.OpCheckError, |
| "does not have an autograd kernel", |
| ): |
| torch.library.opcheck(op, (x,), {}) |
| |
| # I'm not sure why this is necessary |
| del lib |
| |
| def test_global_state_mutation(self, device): |
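        # The forward reads and mutates class-level state, so eager PyTorch and
        # AOTAutograd produce different results and opcheck should complain.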
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| invoked = 0 |
| |
| @staticmethod |
| def forward(ctx, x): |
| Foo.invoked += 1 |
| return x.clone() * Foo.invoked |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| lib.impl("foo", Foo.apply, "CompositeImplicitAutograd") |
| |
| x = torch.tensor(3.14159 / 3, requires_grad=True) |
| with self.assertRaisesRegex( |
| optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd" |
| ): |
| torch.library.opcheck(op, (x,), {}) |
| |
| @ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one) |
| def test_opcheck_opinfo(self, device, dtype, op): |
| for sample_input in op.sample_inputs( |
| device, dtype, requires_grad=op.supports_autograd |
| ): |
| args = [sample_input.input] + list(sample_input.args) |
| kwargs = sample_input.kwargs |
| torch.library.opcheck(op.op, args, kwargs) |
| |
| def test_opcheck_fails_basic(self, device): |
| @custom_op(f"{self.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: ... |
| |
| @foo.impl(["cpu", "cuda"]) |
| def foo_impl(x): |
| return x.sum() |
| |
| x = torch.randn(3, device=device, requires_grad=True) |
| # Triggers the CustomOp autograd NYI error |
| with self.assertRaisesRegex( |
| optests.OpCheckError, "Autograd has not been implemented for operator" |
| ): |
| torch.library.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {}) |
| |
| def test_autograd_registration_check_autograd_kernel(self, device): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| with torch._C._AutoDispatchBelowAutograd(): |
| return op(x) |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| def foo_impl(x): |
| return x.sin() |
| |
| lib.impl("foo", Foo.apply, "Autograd") |
| lib.impl("foo", foo_impl, "CPU") |
| lib.impl("foo", foo_impl, "CUDA") |
| |
| x = torch.randn(3, requires_grad=True, device=device) |
| # Should not raise |
| optests.autograd_registration_check(op, (x,), {}) |
| |
| def test_autograd_registration_check_compositeimplicitautograd(self, device): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| def foo_impl(x): |
| return x.sin().cos() |
| |
| lib.impl("foo", foo_impl, "CompositeImplicitAutograd") |
| |
| x = torch.randn(3, requires_grad=True, device=device) |
| # Should not raise |
| optests.autograd_registration_check(op, (x,), {}) |
| |
| def test_autograd_registration_check_incorrect_composite(self, device): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| def foo_impl(x): |
| return x.sin().cos() |
| |
| lib.impl("foo", foo_impl, "CompositeExplicitAutograd") |
| |
| x = torch.randn(3, requires_grad=True, device=device) |
| with self.assertRaisesRegex(AssertionError, "incorrectly registered"): |
| optests.autograd_registration_check(op, (x,), {}) |
| |
| def test_autograd_registration_check_incorrect(self, device): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| op = self.ns().foo.default |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| return torch.sin(x) |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx |
| |
| lib.impl("foo", Foo.apply, "CPU") |
| lib.impl("foo", Foo.apply, "CUDA") |
| |
| x = torch.randn(3, requires_grad=True, device=device) |
| with self.assertRaisesRegex(AssertionError, "incorrectly registered"): |
| optests.autograd_registration_check(op, (x,), {}) |
| |
| def test_assert_raises_regex(self, device): |
| from torch.testing._internal.optests.aot_autograd import assert_raises_regex |
| |
| with assert_raises_regex(RuntimeError, "c"): |
| raise RuntimeError("abcd") |
| with assert_raises_regex(RuntimeError, "c.*"): |
| raise RuntimeError("abcd") |
| with self.assertRaisesRegex(AssertionError, "instead got"): |
| with assert_raises_regex(RuntimeError, "c.*"): |
| raise ValueError("abcd") |
| with self.assertRaisesRegex(AssertionError, "Expected exception"): |
| with assert_raises_regex(RuntimeError, "c.*"): |
| pass |
| with self.assertRaisesRegex(AssertionError, "to match regex"): |
| with assert_raises_regex(RuntimeError, "f"): |
| raise RuntimeError("abcd") |
| |
| |
| class TestCustomOp(CustomOpTestCaseBase): |
| test_ns = "_test_custom_op" |
| |
| @requires_compile |
| def test_functionalize_error(self): |
| with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib: |
| lib.define("foo(Tensor(a!) x) -> Tensor(a!)") |
| |
| def foo(x): |
| return x.sin_() |
| |
| lib.impl("foo", foo, "CompositeExplicitAutograd") |
| foo_op = self.get_op(f"{self.test_ns}::foo") |
| |
| lib.define("bar(Tensor(a) x) -> Tensor(a)") |
| |
| def bar(x): |
| return x.view(-1) |
| |
| lib.impl("bar", bar, "CompositeExplicitAutograd") |
| bar_op = self.get_op(f"{self.test_ns}::bar") |
| |
| msg = r".*We only support functionalizing operators whose outputs do not have alias annotations" |
| |
| x = torch.randn(3) |
| |
| @torch.compile(backend="aot_eager", fullgraph=True) |
| def f(x): |
| return foo_op(x) |
| |
| @torch.compile(backend="aot_eager", fullgraph=True) |
| def g(x): |
| return bar_op(x) |
| |
| with self.assertRaisesRegex(RuntimeError, msg): |
| f(x) |
| with self.assertRaisesRegex(RuntimeError, msg): |
| g(x) |
| |
| def test_invalid_schemas(self): |
        # Function schema validation goes through torchgen, so this is just a
        # basic test.
| with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"): |
| custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(") |
| |
| def test_invalid_qualname(self): |
| with self.assertRaisesRegex(ValueError, "overload"): |
| custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo.Tensor", "() -> ()") |
| |
| def test_name_must_match(self): |
| with self.assertRaisesRegex(ValueError, "to have name"): |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def baz(x: Tensor) -> Tensor: |
| raise NotImplementedError |
| |
| def test_unsupported_schemas(self): |
| with self.assertRaisesRegex(ValueError, "only supports functional"): |
| custom_ops.custom_op( |
| f"{TestCustomOp.test_ns}::foo", "(Tensor(a!) x) -> Tensor(a)" |
| )(foo) |
| with self.assertRaisesRegex(ValueError, "only supports functional"): |
| custom_ops.custom_op( |
| f"{TestCustomOp.test_ns}::foo", "(Tensor(a) x) -> Tensor(a)" |
| )(foo) |
| with self.assertRaisesRegex(ValueError, "only supports functional"): |
| custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor x) -> ()")( |
| foo |
| ) |
| with self.assertRaisesRegex(ValueError, "self"): |
| custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor self) -> ()")( |
| foo |
| ) |
| |
| # Tests for the older custom_op API |
| def test_schema_matches_signature(self): |
| with self.assertRaisesRegex(ValueError, "signature to match"): |
| |
| @custom_op(f"{TestCustomOp.test_ns}::blah", "(Tensor y) -> Tensor") |
| def blah(x): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "signature to match"): |
| |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah2", "(Tensor x, *, Tensor y) -> Tensor" |
| ) |
| def blah2(x, y): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "signature to match"): |
| |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah3", |
| "(Tensor x, *, Tensor w, Tensor z) -> Tensor", |
| ) |
| def blah3(x, *, y, z): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "signature to match"): |
| |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah4", |
| "(Tensor x, *, Tensor z, Tensor y) -> Tensor", |
| ) |
| def blah4(x, *, y, z): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "not supported"): |
| |
| @custom_op(f"{TestCustomOp.test_ns}::blah5", "(Tensor x) -> Tensor") |
| def blah5(*args): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "not supported"): |
| |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah6", "(*, Tensor z, Tensor y) -> Tensor" |
| ) |
| def blah6(**kwargs): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "default arguments"): |
| |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah7", "(Tensor x, *, Tensor y) -> Tensor" |
| ) |
| def blah7(x=1, *, y): |
| pass |
| |
| with self.assertRaisesRegex(ValueError, "default arguments"): |
| |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah8", "(Tensor x, *, Tensor y) -> Tensor" |
| ) |
| def blah8(x, *, y=1): |
| pass |
| |
| # kwonly-arg works |
| @custom_op( |
| f"{TestCustomOp.test_ns}::blah9", "(Tensor x, *, Tensor y) -> Tensor" |
| ) |
| def blah9(x, *, y): |
| pass |
| |
| def test_infer_schema_no_return(self): |
| with self.assertRaisesRegex( |
| ValueError, "No return type annotation was provided. Please add one." |
| ): |
| |
| @torch.library.custom_op("mylib::foo", mutates_args={}) |
| def foo(x: torch.Tensor, y: int): |
| return x * y |
| |
| def test_infer_schema_supported(self): |
| def a(x: Tensor) -> Tensor: |
| return torch.empty([]) |
| |
| self.assertExpectedInline( |
| infer_schema(a, mutates_args=()), """(Tensor x) -> Tensor""" |
| ) |
| |
| def kwonly1(x: Tensor, *, y: int, z: float) -> Tensor: |
| return torch.empty([]) |
| |
| self.assertExpectedInline( |
| infer_schema(kwonly1, mutates_args=()), |
| """(Tensor x, *, SymInt y, float z) -> Tensor""", |
| ) |
| |
| def kwonly2(*, y: Tensor) -> Tensor: |
| return torch.empty([]) |
| |
| self.assertExpectedInline( |
| infer_schema(kwonly2, mutates_args=()), """(*, Tensor y) -> Tensor""" |
| ) |
| |
| def b( |
| x: Tensor, |
| y: int, |
| z: bool, |
| a: float, |
| b: torch.dtype, |
| c: torch.device, |
| d: torch.types.Number, |
| ) -> Tuple[Tensor, int, float, bool]: |
| return torch.empty([]), 1, 0.1, True |
| |
| self.assertExpectedInline( |
| infer_schema(b, mutates_args=()), |
| """(Tensor x, SymInt y, bool z, float a, ScalarType b, Device c, Scalar d) -> (Tensor, SymInt, float, bool)""", |
| ) |
| |
| def c( |
| x: Tensor, |
| y: Sequence[Tensor], |
| z: Optional[Tensor], |
| w: Sequence[Optional[Tensor]], |
| ) -> List[Tensor]: |
| return [torch.empty([])] |
| |
| self.assertExpectedInline( |
| infer_schema(c, mutates_args=()), |
| """(Tensor x, Tensor[] y, Tensor? z, Tensor?[] w) -> Tensor[]""", |
| ) |
| |
| def d(x: Tensor) -> Tuple[List[Tensor], Tensor]: |
| return [torch.empty([])], torch.empty([]) |
| |
| self.assertExpectedInline( |
| infer_schema(d, mutates_args=()), """(Tensor x) -> (Tensor[], Tensor)""" |
| ) |
| |
| def e() -> Tensor: |
| return torch.empty([]) |
| |
| self.assertExpectedInline(infer_schema(e, mutates_args=()), """() -> Tensor""") |
| |
| def f(x: Tensor) -> None: |
| pass |
| |
| self.assertExpectedInline( |
| infer_schema(f, mutates_args=()), """(Tensor x) -> ()""" |
| ) |
| |
| def g( |
| x: Tensor, y: List[Tensor], z: List[Tensor], w: List[Optional[Tensor]] |
| ) -> None: |
| pass |
| |
| self.assertExpectedInline( |
| infer_schema(g, mutates_args=()), |
| """(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()""", |
| ) |
| |
| self.assertExpectedInline( |
| infer_schema(g, mutates_args={"x", "w", "z"}), |
| """(Tensor(a0!) x, Tensor[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""", |
| ) |
| |
| self.assertExpectedInline( |
| infer_schema(g, mutates_args="unknown"), |
| """(Tensor(a0!) x, Tensor(a1!)[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""", |
| ) |
| |
| def h( |
| x: Tensor, |
| a: Optional[int] = None, |
| b: float = 3.14, |
| c: bool = True, |
| d: int = 3, |
| e: str = "foo", |
| f: torch.dtype = torch.float, |
| g: torch.dtype = torch.float32, |
| h: torch.dtype = torch.int, |
| i: torch.device = torch.device("cpu:0"), |
| j: torch.device = "cpu", |
| ) -> None: |
| pass |
| |
| self.assertExpectedInline( |
| infer_schema(h, mutates_args=()), |
| ( |
| """(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """ |
| """ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()""" |
| ), |
| ) |
| |
| def foo_impl(x: torch.Tensor) -> torch.Tensor: |
| return x.sin() |
| |
| schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={}) |
| self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor") |
| |
| def test_infer_schema_unsupported(self): |
| with self.assertRaisesRegex(ValueError, "varargs"): |
| |
| def foo(*args): |
| raise NotImplementedError |
| |
| infer_schema(foo, mutates_args=()) |
| |
| with self.assertRaisesRegex(ValueError, "varkwargs"): |
| |
| def foo(**kwargs): |
| raise NotImplementedError |
| |
| infer_schema(foo, mutates_args=()) |
| |
| with self.assertRaisesRegex(ValueError, "must have a type annotation"): |
| |
| def foo(x): |
| raise NotImplementedError |
| |
| infer_schema(foo, mutates_args=()) |
| |
| with self.assertRaisesRegex(ValueError, "unsupported"): |
| |
| def foo(x: Tensor) -> Tuple[Tensor, ...]: |
| raise NotImplementedError |
| |
| infer_schema(foo, mutates_args=()) |
| |
| with self.assertRaisesRegex(ValueError, "can be mutated"): |
| |
| def foo(x: Tensor, y: int) -> Tensor: |
| raise NotImplementedError |
| |
| infer_schema(foo, mutates_args={"y"}) |
| |
| def _generate_examples(self, typ): |
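        # Produce example values for a supported annotation type; used by the
        # SUPPORTED_PARAM_TYPES / SUPPORTED_RETURN_TYPES smoke tests below.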
| if typ is int: |
| return [17] |
| if typ is float: |
| return [3.14] |
| if typ is bool: |
| return [True] |
| if typ is str: |
| return ["foo"] |
| if typ is torch.dtype: |
| return [torch.float32] |
| if typ is torch.device: |
| return [torch.device("cpu")] |
| if typ == torch.types.Number: |
| return [2.718] |
| if typ is torch.Tensor: |
| return [torch.tensor(3)] |
| if typ == Optional[torch.types.Number]: |
| return [None, 2.718] |
| origin = typing.get_origin(typ) |
| if origin is Union: |
| args = typing.get_args(typ) |
| assert len(args) == 2 and (args[0] is type(None) or args[1] is type(None)) |
| elt = args[0] if args[1] is type(None) else args[1] |
| return self._generate_examples(elt) + [None] |
| if origin is list: |
| args = typing.get_args(typ) |
| assert len(args) == 1 |
| elt = args[0] |
| return [ |
| self._generate_examples(elt), |
| self._generate_examples(elt), |
| self._generate_examples(elt), |
| ] |
| if origin is collections.abc.Sequence: |
| args = typing.get_args(typ) |
| assert len(args) == 1 |
| examples = self._generate_examples(args[0]) |
| return list(itertools.product(examples, examples)) + [] |
        raise NotImplementedError(
            f"test runner cannot generate an instance of type {typ}"
        )
| |
| def test_supported_return_types_single_return(self): |
| for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES: |
| for example in self._generate_examples(typ): |
| try: |
| |
| @custom_ops.custom_op(f"{self.test_ns}::foo") |
| def foo(x: Tensor) -> typ: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{self.test_ns}::foo") |
| def foo_impl(x: Tensor) -> typ: |
| return example |
| |
| op = self.get_op(f"{self.test_ns}::foo") |
| result = op(torch.randn([])) |
| self.assertEqual(result, example, msg=f"{typ} {example}") |
| finally: |
| custom_ops._destroy(f"{self.test_ns}::foo") |
| |
| def test_supported_return_types_multi_return(self): |
| for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES: |
| for example in self._generate_examples(typ): |
| try: |
| |
| @custom_ops.custom_op(f"{self.test_ns}::foo") |
| def foo(x: Tensor) -> Tuple[typ, typ]: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{self.test_ns}::foo") |
| def foo_impl(x: Tensor) -> Tuple[typ, typ]: |
| return (example, example) |
| |
| op = self.get_op(f"{self.test_ns}::foo") |
| result = op(torch.randn([])) |
| expected = (example, example) |
| self.assertEqual(result, expected, msg=f"{typ} {example}") |
| finally: |
| custom_ops._destroy(f"{self.test_ns}::foo") |
| |
| def test_supported_param_types(self): |
| for typ in torch._library.infer_schema.SUPPORTED_PARAM_TYPES: |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Tensor, y: typ) -> Tensor: |
| raise NotImplementedError |
| |
| yeet = None |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=["cpu"]) |
| def foo_cpu(x, y): |
| nonlocal yeet |
| yeet = y |
| return x.clone() |
| |
| try: |
| for example in self._generate_examples(typ): |
| op = self.get_op(f"{self.test_ns}::foo") |
| op(torch.randn([]), example) |
| self.assertEqual(yeet, example, msg=f"{typ} {example}") |
| yeet = None |
| finally: |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| |
| def test_sequences(self): |
| # Sequence[int] gets automagically turned into int[] in the schema. |
| # This test checks that we actually do support arbitrary sequence types. |
| class MySequence(collections.abc.Sequence): |
| def __init__(self) -> None: |
| self._container = [1, 2, 3] |
| |
| def __getitem__(self, idx): |
| return self._container[idx] |
| |
| def __len__(self): |
| return len(self._container) |
| |
| @custom_ops.custom_op(f"{self.test_ns}::foo") |
| def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor: |
| raise NotImplementedError |
| |
| called = 0 |
| |
| @custom_ops.impl(f"{self.test_ns}::foo", device_types="cpu") |
| def foo_cpu(x, sizes): |
| nonlocal called |
| called += 1 |
| # Dispatcher will normalize the sequence type into a List |
| self.assertEqual(sizes, [1, 2, 3]) |
| return x.clone() |
| |
| x = torch.randn([]) |
| seq = MySequence() |
| op = self.get_op(f"{self.test_ns}::foo") |
| op(x, seq) |
| self.assertEqual(called, 1) |
| |
| def test_unsupported_param_types(self): |
| # Not comprehensive (it doesn't need to be), just a check that our mechanism works |
| with self.assertRaisesRegex(ValueError, "unsupported type"): |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Tensor, y: List[Optional[int]]) -> Tensor: |
| raise NotImplementedError |
| |
| del foo |
| |
| with self.assertRaisesRegex(ValueError, "unsupported type"): |
| # int[N] in Dispatcher is a bit wild, so we don't try to support it. |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Tensor, y: Tuple[int, int]) -> Tensor: |
| raise NotImplementedError |
| |
| del foo |
| |
| with self.assertRaisesRegex(ValueError, r"For example, typing.List\[int\]"): |
| # test that we propose a correct and supported type. |
| @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={}) |
| def foo(x: Tensor, y: Tuple[int, int]) -> Tensor: |
| raise NotImplementedError |
| |
| del foo |
| |
| with self.assertRaises(ValueError) as cm: |
| |
| @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={}) |
| def foo(x: Tensor, y: Tuple[int, float]) -> Tensor: |
| raise NotImplementedError |
| |
| del foo |
| |
| self.assertNotIn("example", str(cm.exception), "") |
| |
| with self.assertRaisesRegex(ValueError, "unsupported type"): |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Tensor, y: Callable) -> Tensor: |
| raise NotImplementedError |
| |
| del foo |
| |
| def test_supported_schemas(self): |
| # All of these should already be tested by PyTorch codegen |
| # (we share the same mechanism), but here's a sanity check. |
| schemas = [ |
| "(Tensor x) -> Tensor", |
| "(Tensor x) -> Tensor y", |
| "(Tensor[] x) -> Tensor y", |
| "(Tensor x) -> (Tensor, Tensor)", |
| "(Tensor x) -> (Tensor y, Tensor z)", |
| "(Tensor x) -> (Tensor y, Tensor z)", |
| ] |
| other_schemas = [ |
| "(Tensor x, Tensor w) -> (Tensor y, Tensor z)", |
| "(Tensor x, Tensor w) -> (Tensor, Tensor)", |
| "(Tensor x, Tensor w) -> Tensor", |
| "(Tensor? x, Tensor w) -> Tensor", |
| "(Tensor? x, Tensor[] w) -> Tensor", |
| "(Tensor x, int[] w) -> Tensor", |
| "(Tensor x, SymInt[] w) -> Tensor", |
| "(Tensor x, Scalar w) -> Tensor", |
| "(Tensor x, float w) -> Tensor", |
| "(Tensor x, float? w) -> Tensor", |
| "(Tensor x, bool[] w) -> Tensor", |
| ] |
| |
| for schema in schemas: |
| custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", schema) |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| for schema in other_schemas: |
| custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar", schema) |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::bar") |
| |
| def test_reserved_ns(self): |
| from torch._custom_op.impl import RESERVED_NS |
| |
| for ns in RESERVED_NS: |
| with self.assertRaisesRegex(ValueError, "is a reserved namespace"): |
| custom_ops.custom_op(f"{ns}::foo", "(Tensor x) -> Tensor") |
| |
| with self.assertRaisesRegex(ValueError, "is a reserved namespace"): |
| |
| @custom_ops.custom_op(f"{ns}::foo2") |
| def foo2(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| def test_private_ctor(self): |
| with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"): |
| CustomOp(None, None, None, None, None) |
| |
| def test_lifetime(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| custom_op = torch._custom_op.impl.get_op(f"{TestCustomOp.test_ns}::foo") |
| |
| # We can't define an op multiple times, |
| with self.assertRaisesRegex(RuntimeError, "multiple times"): |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 |
| raise NotImplementedError |
| |
| # Unless we delete the original op. |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| |
| # Smoke test |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 |
| raise NotImplementedError |
| |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| |
| def test_autograd_notimplemented(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 |
| raise NotImplementedError |
| |
| x = torch.randn(3, requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"): |
| op(x) |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| del foo |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: Sequence[torch.Tensor]) -> torch.Tensor: |
| raise NotImplementedError |
| |
| x = torch.randn(3, requires_grad=True) |
| y = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"): |
| op([y, x]) |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| del foo |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| x = torch.randn(3, requires_grad=True) |
| y = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"): |
| op(y, x) |
| custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") |
| |
| def test_autograd_notimplemented_gradmode(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x, y): |
| return x * y |
| |
| x = torch.randn(3, requires_grad=True) |
| y = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| with torch.no_grad(): |
| # Shouldn't raise, because we are in no_grad |
| op(y, x) |
| |
| def test_impl_cpu(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu") |
| def foo_cpu(x): |
| return x.sin() |
| |
| x = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| result = op(x) |
| self.assertEqual(result, foo_cpu(x)) |
| |
| def test_impl_invalid_devices(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| def foo_impl(x): |
| return x.sin() |
| |
| from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY |
| |
| for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys(): |
| # Smoke test: should not raise error |
| custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=device_type)( |
| foo_impl |
| ) |
| |
| # Not supported by this API: we can either support them in the future |
| # or provide some other CustomOp.def_* function. This depends on how |
| # common the use cases are. |
| for invalid_type in ["hip", "xla", "mkldnn", ["cpu", "hip"]]: |
| with self.assertRaisesRegex(ValueError, "we only support device_type"): |
| custom_ops.impl( |
| f"{TestCustomOp.test_ns}::foo", device_types=invalid_type |
| )(foo_impl) |
| |
| def test_backward_partially_registered(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x): |
| return x.sin() |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return grad * saved.cos() |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| with self.assertRaisesRegex( |
| RuntimeError, "unable to find a 'save_for_backward'" |
| ): |
| y = op(x) |
| y.backward() |
| |
| def test_save_for_backward_inputs_are_namedtuple(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x): |
| return x.sin() |
| |
| hit = 0 |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| nonlocal hit |
| hit += 1 |
| self.assertTrue(isinstance(inputs, tuple)) |
| self.assertEqual(list(inputs._asdict().keys()), ["x"]) |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"x": grad * saved.cos()} |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x) |
| self.assertEqual(hit, 1) |
| y.backward() |
| self.assertEqual(hit, 1) |
| |
| def test_backward_returns_dict(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return grad * saved.cos() |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x) |
| with self.assertRaisesRegex(RuntimeError, "to be a dict"): |
| y.backward() |
| |
| def test_backward_dict_invalid_keys(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"x": grad * saved.cos(), "y": None} |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x) |
| with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"): |
| y.backward() |
| |
| def test_backward_dict_grad_for_nontensor(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, dim: int) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x, dim): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"x": grad * saved.cos(), "dim": None} |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x, 32) |
| with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"): |
| y.backward() |
| |
| def test_backward_dict_requires_keys_for_input_tensors(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x, y): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"x": grad * saved.cos()} |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x, x) |
| with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"): |
| y.backward() |
| |
| def test_backward_dict_requires_keys_for_input_optional_tensors(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x, y): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"x": grad * saved.cos()} |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x, None) |
| with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"): |
| y.backward() |
| |
| def test_backward_grads_are_tensor_or_none(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"x": (grad * saved.cos(),)} |
| |
| x = torch.randn([], requires_grad=True) |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(x) |
| with self.assertRaisesRegex(RuntimeError, "either None or a Tensor"): |
| y.backward() |
| |
| def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(xs): |
| return xs[0].sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.xs[0] |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"xs": [grad * saved.cos(), None]} |
| |
| xs = [torch.randn([], requires_grad=True) for _ in range(3)] |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(xs) |
| with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"): |
| y.backward() |
| |
| def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(xs): |
| return xs[0].sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.xs[0] |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"xs": [grad * saved.cos(), None, (None,)]} |
| |
| xs = [torch.randn([], requires_grad=True) for _ in range(3)] |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(xs) |
| with self.assertRaisesRegex(RuntimeError, "None or Tensor"): |
| y.backward() |
| |
| def test_backward_tensorlist_input_requires_list_grads(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(xs): |
| return xs[0].sin() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return inputs.xs[0] |
| |
| @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_backward(ctx, saved, grad): |
| return {"xs": None} |
| |
| xs = [torch.randn([], requires_grad=True) for _ in range(3)] |
| op = self.get_op(f"{self.test_ns}::foo") |
| y = op(xs) |
| with self.assertRaisesRegex(RuntimeError, "list of gradients"): |
| y.backward() |
| |
| def test_backward_output_differentiability_type(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: |
| raise NotImplementedError |
| |
| with self.assertRaisesRegex(RuntimeError, "output_differentiability"): |
| |
| @custom_ops.impl_backward( |
| f"{TestCustomOp.test_ns}::foo", output_differentiability=True |
| ) |
| def foo_backward(ctx, saved, grad): |
| return {"xs": None} |
| |
| def test_backward_output_differentiability_numel(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: |
| raise NotImplementedError |
| |
| with self.assertRaisesRegex(RuntimeError, "output_differentiability"): |
| |
| @custom_ops.impl_backward( |
| f"{TestCustomOp.test_ns}::foo", output_differentiability=[True] |
| ) |
| def foo_backward(ctx, saved, grad): |
| return {"xs": None} |
| |
| def test_backward_output_differentiability_tensorlist(self): |
| @custom_ops.custom_op(f"{self.test_ns}::foo") |
| def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{self.test_ns}::foo") |
| def foo_impl(x): |
| return [x.clone(), x.clone()], x.clone() |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return [] |
| |
| @custom_ops.impl_backward( |
| f"{TestCustomOp.test_ns}::foo", output_differentiability=[False, True] |
| ) |
| def foo_backward(ctx, saved, grad_lst, grad): |
| return {"x": grad} |
| |
| op = self.get_op(f"{self.test_ns}::foo") |
| x = torch.randn(3, requires_grad=True) |
| [a, b], c = op(x) |
| self.assertFalse(a.requires_grad) |
| self.assertFalse(b.requires_grad) |
| self.assertTrue(c.requires_grad) |
| |
| def test_backward_output_differentiability_non_tensor(self): |
| @custom_ops.custom_op(f"{self.test_ns}::foo") |
| def foo(x: Tensor) -> Tuple[Tensor, int]: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{self.test_ns}::foo") |
| def foo_impl(x): |
| return x.clone(), 3 |
| |
| @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") |
| def foo_save_for_backward(inputs, output): |
| return [] |
| |
| @custom_ops.impl_backward( |
| f"{TestCustomOp.test_ns}::foo", output_differentiability=[True, True] |
| ) |
| def foo_backward(ctx, saved, grad0, grad1): |
| return {"x": grad0} |
| |
| op = self.get_op(f"{self.test_ns}::foo") |
| x = torch.randn(3, requires_grad=True) |
| with self.assertRaisesRegex(RuntimeError, "is not a Tensor"): |
| op(x) |
| |
| @unittest.skipIf(not TEST_CUDA, "requires CUDA") |
| def test_impl_separate(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu") |
| def foo_cpu(x): |
| return x.sin() |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cuda") |
| def foo_cuda(x): |
| return x.cos() |
| |
| x = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| result = op(x) |
| self.assertEqual(result, foo_cpu(x)) |
| |
| x_cuda = x.cuda() |
| op = self.get_op(f"{self.test_ns}::foo") |
| result = op(x_cuda) |
| self.assertEqual(result, foo_cuda(x_cuda)) |
| |
| @unittest.skipIf(not TEST_CUDA, "requires CUDA") |
| def test_impl_multiple(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") |
| def foo_impl(x): |
| return x.cos() |
| |
| op = self.get_op(f"{self.test_ns}::foo") |
| x = torch.randn(3) |
| result = op(x) |
| self.assertEqual(result, foo_impl(x)) |
| |
| x_cuda = x.cuda() |
| result = op(x_cuda) |
| self.assertEqual(result, foo_impl(x_cuda)) |
| |
| def test_impl_abstract_overload(self): |
| lib = self.lib() |
| lib.define("sin.blah(Tensor x) -> Tensor") |
| |
| torch.library.impl_abstract( |
| f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib |
| ) |
| |
| op = self.ns().sin.blah |
| x = torch.randn(3, device="meta") |
| op(x) |
| |
| def test_impl_meta(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, dim: int) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) |
| def foo_meta(x, dim): |
| output_shape = list(x.shape) |
| del output_shape[dim] |
| return x.new_empty(output_shape) |
| |
| x = torch.randn(2, 3, device="meta") |
| op = self.get_op(f"{self.test_ns}::foo") |
| result = op(x, 1) |
| self.assertEqual(result.shape, foo_meta(x, 1).shape) |
| |
| def test_duplicate_impl(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor, dim: int) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) |
| def foo_meta(x, dim): |
| output_shape = list(x.shape) |
| del output_shape[dim] |
| return x.new_empty(output_shape) |
| |
| with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"): |
| |
| @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) |
| def foo_meta2(x, dim): |
| output_shape = list(x.shape) |
| del output_shape[dim] |
| return x.new_empty(output_shape) |
| |
| def test_new_data_dependent_symint(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) |
| def foo_meta(x): |
| ctx = torch.library.get_ctx() |
| r = ctx.new_dynamic_size(min=1) |
| with self.assertRaisesRegex(ValueError, "greater than or equal to 0"): |
| ctx.new_dynamic_size(min=-1) |
| with self.assertRaisesRegex(ValueError, "SymInt"): |
| ctx.new_dynamic_size(max=x.numel()) |
| # NB: You must return dynamic sizes! |
| return x.new_empty(r) |
| |
| x = torch.randn(2, 3, device="cpu") |
| op = self.get_op(f"{self.test_ns}::foo") |
| make_fx(op, tracing_mode="symbolic")(x) |
| |
| def test_meta_for_data_dependent_shape_operation(self): |
| x = torch.randn(10, device="meta") |
| with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"): |
| numpy_nonzero(x) |
| |
| def test_basic_make_fx(self): |
        # More serious tests are in our CustomOp opinfo db;
        # this one is just a sanity check.
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) |
| def foo_meta(x): |
| return x.sum() |
| |
| x = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| gm = make_fx(op, tracing_mode="symbolic")(x) |
| self.assertTrue(f"{TestCustomOp.test_ns}.foo" in gm.code) |
| |
| def test_not_implemented_error(self): |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") |
| def foo(x: torch.Tensor) -> torch.Tensor: |
| raise NotImplementedError |
| |
| x = torch.randn(3) |
| op = self.get_op(f"{self.test_ns}::foo") |
| with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"): |
| op(x) |
| |
| x = torch.randn(3, device="meta") |
| with self.assertRaisesRegex(NotImplementedError, "no fake impl or Meta kernel"): |
| op(x) |
| |
| @custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar") |
| def bar(sizes: Sequence[int]) -> torch.Tensor: |
| raise NotImplementedError |
| |
| op = self.get_op(f"{self.test_ns}::bar") |
| with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"): |
| op((1, 2, 3)) |
| |
| def test_data_dependent_basic(self): |
| x = torch.randn(5, 5) |
| gm = make_fx(numpy_nonzero, tracing_mode="symbolic")(x) |
| self.assertTrue("nonzero" in gm.code) |
| |
| def test_data_dependent_fake_tracing(self): |
| x = torch.randn(5, 5) |
| # We've updated to attempt to use unbacked symints even for fake |
| # tracing |
| make_fx(numpy_nonzero, tracing_mode="fake")(x) |
| |
| def test_symints(self): |
| def f(x): |
| return torch.ops._torch_testing.numpy_view_copy(x, x.shape) |
| |
| x = torch.randn(2, 3, 4) |
| gm = make_fx(f, tracing_mode="symbolic")(x) |
| result = gm(x) |
| self.assertEqual(result, f(x)) |
| self.assertExpectedInline( |
| gm.code.strip(), |
| """\ |
| def forward(self, x_1): |
| sym_size_int = torch.ops.aten.sym_size.int(x_1, 0) |
| sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1) |
| sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2) |
| numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None |
| return numpy_view_copy""", # noqa: B950 |
| ) |
| |
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on Windows")
| def test_data_dependent_compile(self): |
| import torch._dynamo.testing |
| from torch._dynamo.utils import counters |
| |
| counters.clear() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch.compile(backend=cnt) |
| def f(x): |
| return numpy_nonzero(x.clone()).clone() |
| |
| f(torch.randn(10)) |
| |
| self.assertEqual(len(counters["graph_break"]), 1) |
| self.assertEqual(next(iter(counters["graph_break"].values())), 1) |
| self.assertExpectedInline( |
| next(iter(counters["graph_break"].keys())).replace(";", "\n"), |
| """\ |
| dynamic shape operator: _torch_testing.numpy_nonzero.default |
| to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""", |
| ) |
| |
| # pre-existing problem: torch.compile(dynamic=True) will, by default, |
| # graph break on data-dependent operations. Eventually we'll make it so |
| # that it never graph breaks on data-dependent operations. |
| @unittest.expectedFailure |
| def test_data_dependent_nms_dynamic_compile(self): |
| import torch._dynamo.testing |
| from torch._dynamo.utils import counters |
| |
| counters.clear() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch.compile(backend=cnt, dynamic=True) |
| def f(x, s, i): |
| return torch.ops._torch_testing.numpy_nms(x.clone(), s, i).clone() |
| |
| f(torch.randn(20, 4), torch.randn(20), 0.1) |
| |
| self.assertEqual(len(counters["graph_break"]), 0) |
| |
| def test_impl_on_existing_op(self): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| qualname = f"{self.test_ns}::foo" |
| |
| @torch._custom_ops.impl(qualname) |
| def foo_impl(x): |
| return x.sin() |
| |
| op = self.get_op(qualname) |
| x = torch.randn(3) |
| result = op(x) |
| self.assertEqual(result, x.sin()) |
| |
| @parametrize( |
| "key", ["CPU", "CUDA", "CompositeImplicitAutograd", "CompositeExplicitAutograd"] |
| ) |
| def test_impl_on_existing_op_with_cpu_registration(self, key): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| qualname = f"{self.test_ns}::foo" |
| |
| def foo_impl(x): |
| return x.sin() |
| |
| lib.impl("foo", foo_impl, key) |
| op = self.get_op(qualname) |
| |
| with self.assertRaisesRegex(RuntimeError, "already has an implementation"): |
| custom_ops.impl(qualname, func=foo_impl) |
| |
| def test_abstract_impl_on_existing_op(self): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| qualname = f"{self.test_ns}::foo" |
| |
| @torch.library.impl_abstract(qualname, lib=self.lib()) |
| def foo_impl(x): |
| return x.sin() |
| |
| op = self.get_op(qualname) |
| with torch._subclasses.FakeTensorMode(): |
| x = torch.randn(3) |
| result = op(x) |
| self.assertEqual(result.shape, x.shape) |
| self.assertEqual(result.stride(), x.stride()) |
| |
| def test_abstract_impl_on_existing_op_with_meta(self): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| qualname = f"{self.test_ns}::foo" |
| |
| def foo_impl(x): |
| return x.sin() |
| |
| lib.impl("foo", foo_impl, "Meta") |
| op = self.get_op(qualname) |
| |
| with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"): |
| torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib()) |
| |
| def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| qualname = f"{self.test_ns}::foo" |
| |
| def foo_impl(x): |
| return x.sin() |
| |
| lib.impl("foo", foo_impl, "CompositeImplicitAutograd") |
| op = self.get_op(qualname) |
| |
| with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"): |
| torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib()) |
| |
| def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| qualname = f"{self.test_ns}::foo" |
| |
| def foo_impl(x): |
| return x.sin() |
| |
| lib.impl("foo", foo_impl, "CompositeExplicitAutograd") |
| op = self.get_op(qualname) |
| |
| torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib()) |
| with torch._subclasses.FakeTensorMode(): |
| x = torch.randn(10) |
| result = op(x) |
| self.assertEqual(result.shape, ()) |
| |
| def _test_backward_impl_raises(self, qualname, err_regex): |
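        # Helper: both impl_save_for_backward and impl_backward should refuse to
        # register for `qualname`, with an error matching `err_regex`.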
| with self.assertRaisesRegex(RuntimeError, err_regex): |
| |
| @custom_ops.impl_save_for_backward(qualname) |
| def foo2(x): |
| return |
| |
| with self.assertRaisesRegex(RuntimeError, err_regex): |
| |
| @custom_ops.impl_backward(qualname) |
| def foo3(x): |
| return |
| |
| def test_backward_impl_on_existing_op_incorrect_schema_views(self): |
| lib = self.lib() |
| lib.define("foo(Tensor(a) x) -> Tensor(a)") |
| qualname = f"{self.test_ns}::foo" |
| self._test_backward_impl_raises(qualname, "operator that returns views") |
| |
| def test_backward_impl_on_existing_op_incorrect_schema_mutable(self): |
| lib = self.lib() |
| lib.define("foo(Tensor(a!) x) -> Tensor") |
| qualname = f"{self.test_ns}::foo" |
| self._test_backward_impl_raises(qualname, "non-functional") |
| |
| def test_backward_impl_on_existing_op_incorrect_schema_no_output(self): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> ()") |
| qualname = f"{self.test_ns}::foo" |
| self._test_backward_impl_raises(qualname, "no returns") |
| |
| def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| qualname = f"{self.test_ns}::foo" |
| lib.impl("foo", lambda x: x.sin().cos(), "CompositeImplicitAutograd") |
| self._test_backward_impl_raises(qualname, "CompositeImplicitAutograd") |
| |
| @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"]) |
| def test_backward_impl_on_existing_op_with_key(self, key): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| qualname = f"{self.test_ns}::foo" |
| lib.impl("foo", lambda x: x.sin().cos(), key) |
| self._test_backward_impl_raises(qualname, key) |
| |
| def test_is_functional_schema(self): |
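| # Functional here means: no argument is mutated, no output aliases an input, and the op returns at least one value. |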
| tests = { |
| "foo(Tensor x) -> Tensor": True, |
| "foo(Tensor(a) x) -> Tensor": True, |
| "foo(Tensor(a!) x) -> Tensor": False, |
| "foo(Tensor(a) x) -> Tensor(a)": False, |
| "foo(Tensor x) -> ()": False, |
| } |
| for schema_str, expected in tests.items(): |
| res = torch._library.utils.is_functional_schema(schema_str) |
| self.assertEqual(res, expected) |
| |
| from torchgen.model import FunctionSchema |
| |
| schema = FunctionSchema.parse(schema_str) |
| res = torch._library.utils.is_functional_schema(schema) |
| self.assertEqual(res, expected) |
| |
| schema = torch._C.parse_schema(schema_str) |
| res = torch._library.utils.is_functional_schema(schema) |
| self.assertEqual(res, expected) |
| |
| def test_incorrect_schema_types(self): |
| with torch.library._scoped_library("mylib", "FRAGMENT") as lib: |
| with self.assertRaisesRegex(RuntimeError, "unknown type specifier"): |
| lib.define("foo12(Tensor a) -> asdfasdf") |
| with self.assertRaisesRegex(RuntimeError, "unknown type specifier"): |
| lib.define("foo12(asdf a) -> Tensor") |
| with self.assertRaisesRegex(RuntimeError, "Use `SymInt` or `int`"): |
| lib.define("foo12(int64_t a) -> Tensor") |
| with self.assertRaisesRegex(RuntimeError, "Use `float`"): |
| lib.define("foo12(double a) -> Tensor") |
| |
| def test_is_tensorlist_like_type(self): |
| tensorlists = [ |
| # Tensor[] |
| torch.ops.aten.where.default._schema.returns[0].type, |
| # Tensor?[] |
| torch.ops.aten.index.Tensor._schema.arguments[1].type, |
| # Tensor[]? |
| torch._C.parse_schema("foo(Tensor[]? x) -> ()").arguments[0].type, |
| # Tensor?[]? |
| torch._C.parse_schema("foo(Tensor?[]? x) -> ()").arguments[0].type, |
| ] |
| non_tensorlists = [ |
| # Tensor |
| torch.ops.aten.sin.default._schema.arguments[0].type, |
| # IntList |
| torch.ops.aten.sum.dim_IntList._schema.arguments[1].type, |
| ] |
| for a in tensorlists: |
| self.assertTrue(torch._library.utils.is_tensorlist_like_type(a)) |
| for a in non_tensorlists: |
| self.assertFalse(torch._library.utils.is_tensorlist_like_type(a)) |
| |
| def test_backward_impl_on_existing_op(self): |
| lib = self.lib() |
| lib.define("foo(Tensor x) -> Tensor") |
| qualname = f"{self.test_ns}::foo" |
| |
| @custom_ops.impl(qualname) |
| def foo_impl(x): |
| with torch.no_grad(): |
| return x.sin() |
| |
| @custom_ops.impl_save_for_backward(qualname) |
| def foo_save_for_backward(inputs, output): |
| return inputs.x |
| |
| @custom_ops.impl_backward(qualname) |
| def foo_backward(ctx, saved, grad_out): |
| return {"x": grad_out * saved.cos()} |
| |
| op = self.get_op(qualname) |
| x = torch.randn([], requires_grad=True) |
| y = op(x) |
| (gx,) = torch.autograd.grad(y, x) |
| self.assertEqual(gx, x.cos()) |
| |
| @parametrize( |
| "tags", |
| [ |
| subtest(torch.Tag.pointwise, "single"), |
| subtest((torch.Tag.pointwise,), "tuple"), |
| subtest([torch.Tag.pointwise], "list"), |
| ], |
| ) |
| def test_define_with_tags(self, tags): |
| lib = self.lib() |
| torch.library.define( |
| f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib, tags=tags |
| ) |
| actual = self.ns().foo.default.tags |
| self.assertTrue(isinstance(actual, list)) |
| self.assertEqual(actual, [torch.Tag.pointwise]) |
| |
| def test_builtin_aten_ops_are_pt2_compliant(self): |
| for op in [torch.ops.aten.sin.default, torch.ops.aten.sum.dim_IntList]: |
| self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) |
| |
| def test_builtin_torchscript_ops(self): |
| for op in [torch.ops.aten.sub.complex, torch.ops.aten.mul.complex]: |
| self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) |
| |
| def test_autogen_aten_ops_are_pt2_compliant(self): |
| for op in [torch.ops.aten.fill.Tensor_out]: |
| self.assertIn(torch.Tag.generated, op.tags) |
| self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) |
| |
| def test_resolve_packet(self): |
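| # _jit_resolve_packet picks the overload name whose schema matches the given args/kwargs. |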
| x = torch.randn(3) |
| result = torch._C._jit_resolve_packet("aten::sum", x) |
| self.assertEqual(result, "default") |
| |
| result = torch._C._jit_resolve_packet("aten::sum", x, dim=1) |
| self.assertEqual(result, "dim_IntList") |
| |
| with self.assertRaisesRegex(RuntimeError, "failed to match any schema"): |
| result = torch._C._jit_resolve_packet("aten::sum", x, x, x) |
| |
| def test_define_bad_schema(self): |
| lib = self.lib() |
| with self.assertRaisesRegex(ValueError, "expected schema to look like"): |
| torch.library.define(f"{self.test_ns}::foo", "foo(Tensor x) -> Tensor") |
| |
| def test_define_and_impl(self): |
| lib = self.lib() |
| torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) |
| |
| @torch.library.impl(f"{self.test_ns}::foo", "CPU", lib=lib) |
| def f(x): |
| return torch.from_numpy(np.sin(x.numpy())) |
| |
| x = torch.randn(3) |
| y = self.ns().foo(x) |
| assert torch.allclose(y, x.sin()) |
| |
| def test_define_validation(self): |
| with self.assertRaisesRegex(ValueError, "namespace"): |
| torch.library.define("foo", "(Tensor x) -> Tensor") |
| |
| def test_legacy_define(self): |
| lib = self.lib() |
| |
| @torch.library.define(lib, "foo(Tensor x) -> Tensor") |
| def f(x): |
| return torch.from_numpy(np.sin(x.numpy())) |
| |
| x = torch.randn(3) |
| y = self.ns().foo(x) |
| assert torch.allclose(y, x.sin()) |
| |
| def test_impl_function(self): |
| lib = self.lib() |
| torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) |
| |
| def f(x): |
| return torch.from_numpy(np.sin(x.numpy())) |
| |
| torch.library.impl(f"{self.test_ns}::foo", "CPU", f, lib=lib) |
| x = torch.randn(3) |
| y = self.ns().foo(x) |
| assert torch.allclose(y, x.sin()) |
| |
| def test_legacy_impl(self): |
| lib = self.lib() |
| torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) |
| |
| @torch.library.impl(lib, "foo", "CPU") |
| def f(x): |
| return torch.from_numpy(np.sin(x.numpy())) |
| |
| x = torch.randn(3) |
| y = self.ns().foo(x) |
| assert torch.allclose(y, x.sin()) |
| |
| def test_defined_in_python(self): |
| self.assertFalse(torch.ops.aten.sin.default._defined_in_python) |
| self.assertFalse(torch.ops.aten.sum.dim_IntList._defined_in_python) |
| |
| lib = self.lib() |
| torch.library.define("{self._test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) |
| ns = self.ns() |
| self.assertTrue(ns.foo.default._defined_in_python) |
| |
| torch.library.define( |
| "{self._test_ns}::bar.overload", "(Tensor x) -> Tensor", lib=lib |
| ) |
| self.assertTrue(ns.bar.overload._defined_in_python) |
| |
| def _test_impl_device(self, name, types, device): |
| lib = self.lib() |
| torch.library.define(f"{self.test_ns}::{name}", "(Tensor x) -> Tensor", lib=lib) |
| |
| @torch.library.impl(f"{self.test_ns}::{name}", types) |
| def f(x): |
| x_np = x.cpu().numpy() |
| y = torch.from_numpy(np.sin(x_np)) |
| return y.to(device=x.device) |
| |
| x = torch.randn(3, device=device) |
| y = getattr(self.ns(), name)(x) |
| assert torch.allclose(y, x.sin()) |
| |
| def test_impl_device_cpu(self): |
| self._test_impl_device("foo1", "default", "cpu") |
| self._test_impl_device("foo2", ["cpu"], "cpu") |
| self._test_impl_device("foo3", ["cpu", "cuda"], "cpu") |
| |
| @unittest.skipIf(not TEST_CUDA, "requires cuda") |
| def test_impl_device_cuda(self): |
| self._test_impl_device("foo4", "default", "cuda") |
| self._test_impl_device("foo5", ["cuda"], "cuda") |
| self._test_impl_device("foo6", ["cpu", "cuda"], "cuda") |
| |
| def test_impl_device_function(self): |
| lib = self.lib() |
| torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) |
| |
| def f(x): |
| x_np = x.cpu().numpy() |
| y = torch.from_numpy(np.sin(x_np)) |
| return y.to(device=x.device) |
| |
| torch.library.impl(f"{self.test_ns}::foo", "default", f, lib=lib) |
| x = torch.randn(3) |
| y = self.ns().foo(x) |
| assert torch.allclose(y, x.sin()) |
| |
| def test_impl_device_invalid(self): |
| with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"): |
| torch.library.impl("blah::blah", "somethingsomething") |
| |
| def test_autograd_function_backed_op(self): |
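| # Inline C++ extension: the op is backed by a torch::autograd::Function whose backward passes gradients through unchanged. |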
| cpp_source = """ |
| struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { |
| static constexpr bool is_traceable = true; |
| |
| static torch::Tensor forward( |
| torch::autograd::AutogradContext* ctx, |
| const torch::Tensor& x) { |
| return x; |
| } |
| |
| static torch::autograd::variable_list backward( |
| torch::autograd::AutogradContext *ctx, |
| torch::autograd::variable_list grad_output) { |
| return grad_output; |
| } |
| }; |
| |
| torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) { |
| return CustomOpAutogradFunction::apply(x); |
| } |
| |
| TORCH_LIBRARY(mylib, m) { |
| m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); |
| } |
| """ |
| |
| module = torch.utils.cpp_extension.load_inline( |
| name="mylib", |
| cpp_sources=cpp_source, |
| functions="custom_op_backed_by_autograd_fn", |
| verbose=True, |
| ) |
| |
| x = torch.ones(2, 2, requires_grad=True) |
| temp = x.clone().detach() |
| out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x) |
| loss = out.sum() |
| loss.backward() |
| self.assertEqual(x.grad, temp) |
| |
| |
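| # Helper: registers an op whose schema claims a fresh output but whose kernel returns a view of the input, i.e. the schema is intentionally incorrect. |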
| def op_with_incorrect_schema(testcase, name): |
| lib = testcase.lib() |
| lib.define(f"{name}(Tensor x) -> Tensor") |
| qualname = f"{testcase.test_ns}::{name}" |
| lib.impl(name, lambda x: x[:], "CompositeExplicitAutograd") |
| return testcase.get_op(qualname) |
| |
| |
| class MiniOpTest(CustomOpTestCaseBase): |
| test_ns = "mini_op_test" |
| |
| def _init_op_delayed_backward_error(self): |
| name = "delayed_error" |
| qualname = f"{self.test_ns}::{name}" |
| lib = self.lib() |
| lib.define(f"{name}(Tensor x) -> Tensor") |
| lib.impl(name, lambda x: x.clone(), "CompositeExplicitAutograd") |
| op = self.get_op(qualname) |
| |
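| # Wrap the op in an autograd.Function whose backward raises, so the error only surfaces when backward() actually runs. |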
| class Op(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| with torch._C._AutoDispatchBelowAutograd(): |
| return op(x) |
| |
| @staticmethod |
| def backward(ctx, grad): |
| raise NotImplementedError |
| |
| def autograd_impl(x): |
| return Op.apply(x) |
| |
| lib.impl(name, autograd_impl, "Autograd") |
| return op |
| |
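| # An op registered with only a CPU kernel and no abstract (fake) impl. |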
| def _init_op_with_no_abstract_impl(self): |
| name = "no_abstract" |
| qualname = f"{self.test_ns}::{name}" |
| lib = self.lib() |
| lib.define(f"{name}(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,)) |
| lib.impl(name, lambda x: x.clone(), "CPU") |
| return torch._library.utils.lookup_op(qualname) |
| |
| def setUp(self): |
| super().setUp() |
| self._op_with_no_abstract_impl = self._init_op_with_no_abstract_impl() |
| self._op_delayed_backward_error = self._init_op_delayed_backward_error() |
| |
| @optests.dontGenerateOpCheckTests("Testing this API") |
| def test_dont_generate(self): |
| op = op_with_incorrect_schema(self, "incorrect_schema") |
| x = torch.randn(3) |
| op(x) |
| |
| def test_mm(self): |
| x = torch.randn(2, 3, requires_grad=True) |
| y = torch.randn(3, 5) |
| result = torch.ops.aten.mm.default(x, y) |
| self.assertEqual(result, x @ y) |
| |
| def test_mm_meta(self): |
| x = torch.randn(2, 3, requires_grad=True, device="meta") |
| y = torch.randn(3, 5, device="meta") |
| result = torch.ops.aten.mm.default(x, y) |
| self.assertEqual(result.shape, (x @ y).shape) |
| |
| def test_mm_fake(self): |
| with torch._subclasses.fake_tensor.FakeTensorMode(): |
| x = torch.randn(2, 3, requires_grad=True, device="cpu") |
| y = torch.randn(3, 5, device="cpu") |
| result = torch.ops.aten.mm.default(x, y) |
| self.assertEqual(result.shape, (x @ y).shape) |
| |
| def test_mm_errors(self): |
| x = torch.randn(2, 3, requires_grad=True) |
| y = torch.randn(4, 5) |
| with self.assertRaisesRegex(RuntimeError, "cannot be multiplied"): |
| result = torch.ops.aten.mm.default(x, y) |
| |
| def test_nonzero(self): |
| x = torch.tensor([0, 1, 2, 0, 0]) |
| y = torch.ops.aten.nonzero.default(x) |
| self.assertEqual(y, torch.tensor([[1], [2]])) |
| |
| def test_inplace(self): |
| x = torch.randn(3) |
| x_clone = x.clone() |
| y = torch.ops.aten.sin_(x) |
| self.assertEqual(x, x_clone.sin()) |
| |
| def test_incorrect_schema(self): |
| op = op_with_incorrect_schema(self, "incorrect_schema") |
| x = torch.randn(3) |
| op(x) |
| |
| def test_no_abstract(self): |
| op = self._op_with_no_abstract_impl |
| x = torch.randn(3) |
| op(x) |
| |
| def test_delayed_error(self): |
| op = self._op_delayed_backward_error |
| x = torch.randn([], requires_grad=True) |
| y = op(x) |
| with self.assertRaises(NotImplementedError): |
| y.sum().backward() |
| |
| def test_delayed_error_no_requires_grad(self): |
| op = self._op_delayed_backward_error |
| x = torch.randn([]) |
| y = op(x) |
| |
| |
| class TestCustomOpAPI(TestCase): |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_basic(self): |
| @torch.library.custom_op("_torch_testing::add", mutates_args=()) |
| def add(x: Tensor, y: float) -> Tensor: |
| x_np = x.numpy(force=True) |
| out_np = x_np + y |
| return torch.from_numpy(out_np).to(x.device) |
| |
| x = torch.randn(3) |
| y = 3.14 |
| z = add(x, y) |
| self.assertEqual(z, x + y) |
| |
| cpu_called = False |
| |
| @add.register_kernel("cpu") |
| def _(x, y): |
| nonlocal cpu_called |
| cpu_called = True |
| x_np = x.numpy() |
| out_np = x_np + y |
| return torch.from_numpy(out_np) |
| |
| z = add(x, y) |
| self.assertEqual(z, x + y) |
| self.assertTrue(cpu_called) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_no_grad_skips_autograd(self): |
| @torch.library.custom_op("_torch_testing::add", mutates_args=()) |
| def add(x: Tensor, y: float) -> Tensor: |
| x_np = x.numpy(force=True) |
| out_np = x_np + y |
| return torch.from_numpy(out_np).to(x.device) |
| |
| called = 0 |
| |
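| # setup_context counts invocations; under no_grad or with non-requires_grad inputs, the autograd pieces should never run. |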
| def setup_context(ctx, inputs, output): |
| nonlocal called |
| called += 1 |
| |
| def backward(ctx, grad): |
| raise AssertionError("should not be reached") |
| |
| add.register_autograd(backward, setup_context=setup_context) |
| |
| x = torch.randn(3, requires_grad=True) |
| with torch.no_grad(): |
| y = add(x, 2.0) |
| self.assertEqual(called, 0) |
| self.assertEqual(y, x + 2.0) |
| |
| x.requires_grad_(False) |
| y = add(x, 2.0) |
| self.assertEqual(called, 0) |
| self.assertEqual(y, x + 2.0) |
| |
| x = torch.randn(3, requires_grad=True) |
| y = add(x, 2.0) |
| self.assertEqual(called, 1) |
| self.assertEqual(y, x + 2.0) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_manual_schema(self): |
| @torch.library.custom_op( |
| "_torch_testing::add", |
| mutates_args=(), |
| schema="(Tensor x, float y) -> Tensor", |
| ) |
| def add(x, y): |
| x_np = x.numpy(force=True) |
| out_np = x_np + y |
| return torch.from_numpy(out_np).to(x.device) |
| |
| x = torch.randn(3) |
| y = 3.14 |
| z = add(x, y) |
| self.assertEqual(z, x + y) |
| |
| @torch.library.custom_op( |
| "_torch_testing::sin_", |
| mutates_args=["x"], |
| schema="(Tensor(a!) x) -> ()", |
| ) |
| def sin_(x): |
| x_np = x.numpy() |
| np.sin(x_np, out=x_np) |
| |
| x = torch.randn(3) |
| expected = x.sin() |
| sin_(x) |
| self.assertEqual(x, expected) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_kwarg_only_tensors(self): |
| with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): |
| |
| @torch.library.custom_op("_torch_testing::foo", mutates_args=()) |
| def foo(x: Tensor, *, y: int, z: Tensor) -> Tensor: |
| pass |
| |
| with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): |
| |
| @torch.library.custom_op("_torch_testing::foo", mutates_args=()) |
| def foo2(x: Tensor, *, y: int, z: Optional[Tensor]) -> Tensor: |
| pass |
| |
| with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): |
| |
| @torch.library.custom_op("_torch_testing::foo", mutates_args=()) |
| def foo3(x: Tensor, *, y: int, z: List[Tensor]) -> Tensor: |
| pass |
| |
| with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: |
| lib.define("foo(Tensor x, *, Tensor y) -> Tensor") |
| with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): |
| torch.library.register_autograd( |
| "_torch_testing::foo", |
| lambda grad: grad, |
| setup_context=lambda ctx, inputs, keyword_only_inputs, output: None, |
| ) |
| |
| with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): |
| torch.library.register_vmap( |
| "_torch_testing::foo", |
| lambda info, in_dims, x, *, y: (x, 0), |
| ) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_register_autograd_kwargonly_low_level(self): |
| with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: |
| lib.define("foo(Tensor x, *, float y) -> Tensor") |
| called = False |
| |
| def foo_impl(x, *, y): |
| return x * y |
| |
| lib.impl("foo", foo_impl, "CPU") |
| |
| def backward(ctx, grad): |
| nonlocal called |
| called = True |
| return grad * ctx.y |
| |
| def setup_context(ctx, inputs, keyword_only_inputs, output): |
| assert tuple(keyword_only_inputs.keys()) == ("y",) |
| ctx.y = keyword_only_inputs["y"] |
| |
| torch.library.register_autograd( |
| "_torch_testing::foo", backward, setup_context=setup_context, lib=lib |
| ) |
| |
| x = torch.randn(3, requires_grad=True) |
| torch.ops._torch_testing.foo(x, y=3.14).sum().backward() |
| self.assertTrue(called) |
| self.assertEqual(x.grad, torch.tensor([3.14, 3.14, 3.14])) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_register_autograd_defaults(self): |
| with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: |
| lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor") |
| |
| def foo_impl(w, x=2, *, y=3, z): |
| return w * x * y * z |
| |
| lib.impl("foo", foo_impl, "CPU") |
| |
| called = False |
| |
| def backward(ctx, grad): |
| nonlocal called |
| called = True |
| return grad * ctx.c |
| |
| def setup_context(ctx, inputs, keyword_only_inputs, output): |
| assert len(inputs) == 2 |
| assert inputs[1] == 2 |
| assert keyword_only_inputs == {"y": 3, "z": 42} |
| ctx.c = keyword_only_inputs["y"] * keyword_only_inputs["z"] * inputs[1] |
| |
| torch.library.register_autograd( |
| "_torch_testing::foo", backward, setup_context=setup_context, lib=lib |
| ) |
| |
| w = torch.randn(3, requires_grad=True) |
| torch.ops._torch_testing.foo(w, z=42).sum().backward() |
| self.assertTrue(called) |
| self.assertEqual(w.grad, torch.full_like(w, 2 * 3 * 42)) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_manual_schema_error(self): |
| with self.assertRaisesRegex(ValueError, "the op mutates {'x'}"): |
| |
| @torch.library.custom_op( |
| "_torch_testing::sin_", |
| mutates_args=(), |
| schema="(Tensor(a!) x) -> ()", |
| ) |
| def sin_(x): |
| x_np = x.numpy() |
| np.sin(x_np, out=x_np) |
| |
| def test_supports_tensorlist(self): |
| @torch._library.autograd.supports_tensorlist |
| class Stack(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, xs): |
| ctx.num_xs = len(xs) |
| return torch.stack(xs) |
| |
| @staticmethod |
| def backward(ctx, grad): |
| expected = ([True] * ctx.num_xs,) |
| self.assertEqual(ctx.needs_input_grad, expected) |
| return list(grad.unbind(0)) |
| |
| # call two applys, do a backward on the first |
| def t(): |
| return torch.randn([], requires_grad=True) |
| |
| xs0 = [t(), t(), t()] |
| xs1 = [t(), t(), t(), t()] |
| y0 = Stack.apply(xs0) |
| y1 = Stack.apply(xs1) |
| grads = torch.autograd.grad(y0.sum(), xs0) |
| self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)]) |
| |
| # call one apply, do multiple backwards |
| xs = [t(), t(), t()] |
| y = Stack.apply(xs) |
| _ = torch.autograd.grad(y.sum(), xs, retain_graph=True) |
| _ = torch.autograd.grad(y.sum(), xs, retain_graph=True) |
| grads = torch.autograd.grad(y.sum(), xs, retain_graph=True) |
| self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)]) |
| |
| # error: on access forward, backward directly |
| with self.assertRaisesRegex(NotImplementedError, "Function.forward directly"): |
| Stack.forward(None, xs) |
| with self.assertRaisesRegex(NotImplementedError, "Function.backward directly"): |
| Stack.backward(None, xs) |
| |
| # the recursive case |
| @torch._library.autograd.supports_tensorlist |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, xs): |
| if len(xs) > 1: |
| return Foo.apply(xs[1:]) |
| ctx.len_xs = len(xs) |
| return xs[0].sin() |
| |
| @staticmethod |
| def backward(ctx, grad): |
| result = [None] * ctx.len_xs |
| result[-1] = grad.cos() |
| return result |
| |
| # should work |
| result = Foo.apply(xs) |
| expected = xs[-1].sin() |
| self.assertEqual(result, expected) |
| |
| # recursive on backward |
| @torch._library.autograd.supports_tensorlist |
| class Bar(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, xs): |
| return [xs[i] + i for i in range(len(xs))] |
| |
| @staticmethod |
| def backward(ctx, grads): |
| f1 = Bar.apply(grads[:2]) |
| f2 = Bar.apply(grads[2:]) |
| return f1 + f2 |
| |
| xs = [torch.tensor(0.0, requires_grad=True) for _ in range(5)] |
| ys = Bar.apply(xs) |
| sum(ys).backward() |
| result = [xi.grad for xi in xs] |
| self.assertEqual(result, torch.tensor([1.0, 2, 1, 2, 3]).unbind(0)) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_default_values(self): |
| defaults = [] |
| |
| @torch.library.custom_op("_torch_testing::f", mutates_args=()) |
| def f( |
| x: Tensor, |
| a: Optional[int] = None, |
| b: float = 3.14, |
| c: bool = True, |
| d: int = 3, |
| e: str = "foo", |
| f: torch.dtype = torch.float, |
| g: torch.dtype = torch.float32, |
| h: torch.dtype = torch.int, |
| i: torch.device = torch.device("cpu:0"), |
| j: torch.device = "cpu", |
| ) -> Tensor: |
| defaults.extend([a, b, c, d, e, f, g, h, i, j]) |
| return x.clone() |
| |
| x = torch.randn(3) |
| f(x) |
| self.assertEqual( |
| defaults, |
| [ |
| None, |
| 3.14, |
| True, |
| 3, |
| "foo", |
| torch.float, |
| torch.float32, |
| torch.int, |
| torch.device("cpu:0"), |
| "cpu", |
| ], |
| ) |
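| # The parsed schema stores dtype defaults as ScalarType enum values and normalizes device strings to torch.device. |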
| default_values = [ |
| arg.default_value |
| for arg in torch.ops._torch_testing.f.default._schema.arguments |
| ] |
| # enum values taken from c10/core/ScalarType.h |
| type_enum = { |
| "float": 6, |
| "int": 3, |
| } |
| self.assertEqual( |
| default_values, |
| [ |
| None, |
| None, |
| 3.14, |
| True, |
| 3, |
| "foo", |
| type_enum["float"], |
| type_enum["float"], |
| type_enum["int"], |
| torch.device("cpu:0"), |
| torch.device("cpu"), |
| ], |
| ) |
| |
| def test_mutated_error(self): |
| with self.assertRaisesRegex( |
| ValueError, r".*{'y'} in mutates_args were not found" |
| ): |
| |
| @torch.library.custom_op( |
| "_torch_testing::numpy_sin_inplace", |
| mutates_args={"y"}, |
| device_types="cpu", |
| ) |
| def numpy_sin_inplace(x: Tensor) -> None: |
| x_np = x.numpy() |
| np.sin(x_np, out=x_np) |
| |
| def test_mutated(self): |
| @torch.library.custom_op( |
| "_torch_testing::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu" |
| ) |
| def numpy_sin_inplace(x: Tensor) -> None: |
| x_np = x.numpy() |
| np.sin(x_np, out=x_np) |
| |
| x = torch.randn(3) |
| version = x._version |
| expected = x.sin() |
| numpy_sin_inplace(x) |
| self.assertEqual(x, expected) |
| self.assertGreater(x._version, version) |
| |
| @torch.library.custom_op("_torch_testing::f", mutates_args={"y", "z", "w"}) |
| def f( |
| x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]] |
| ) -> None: |
| return |
| |
| x = torch.randn(3) |
| y = torch.randn(3) |
| z = [torch.randn(3), torch.randn(3)] |
| w = [torch.randn(3), None, torch.randn(3)] |
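| # x is not in mutates_args, so its version counter must stay the same; every tensor in y, z, and w must have a bumped version. |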
| initial_versions = pytree.tree_map_only( |
| torch.Tensor, lambda x: x._version, (x, y, z, w) |
| ) |
| f(x, y, z, w) |
| new_versions = pytree.tree_map_only( |
| torch.Tensor, lambda x: x._version, (x, y, z, w) |
| ) |
| |
| self.assertEqual(initial_versions[0], new_versions[0]) |
| initial_versions, _ = pytree.tree_flatten(initial_versions[1:]) |
| new_versions, _ = pytree.tree_flatten(new_versions[1:]) |
| for prev, after in zip(initial_versions, new_versions): |
| if prev is None and after is None: |
| continue |
| self.assertGreater(after, prev) |
| |
| def test_mutated_unknown(self): |
| @torch.library.custom_op( |
| "_torch_testing::f", mutates_args="unknown", device_types="cpu" |
| ) |
| def f(x: Tensor) -> None: |
| x_np = x.numpy() |
| np.sin(x_np, out=x_np) |
| |
| x = torch.randn(3) |
| version = x._version |
| expected = x.sin() |
| f(x) |
| self.assertEqual(x, expected) |
| self.assertGreater(x._version, version) |
| |
| @torch.library.custom_op("_torch_testing::f2", mutates_args="unknown") |
| def f2( |
| x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]] |
| ) -> None: |
| return |
| |
| x = torch.randn(3) |
| y = torch.randn(3) |
| z = [torch.randn(3), torch.randn(3)] |
| w = [torch.randn(3), None, torch.randn(3)] |
| initial_versions = pytree.tree_map_only( |
| torch.Tensor, lambda x: x._version, (x, y, z, w) |
| ) |
| f2(x, y, z, w) |
| new_versions = pytree.tree_map_only( |
| torch.Tensor, lambda x: x._version, (x, y, z, w) |
| ) |
| |
| initial_versions, _ = pytree.tree_flatten(initial_versions) |
| new_versions, _ = pytree.tree_flatten(new_versions) |
| for prev, after in zip(initial_versions, new_versions): |
| if prev is None and after is None: |
| continue |
| self.assertGreater(after, prev) |
| |
| with self.assertRaisesRegex(ValueError, "string"): |
| |
| @torch.library.custom_op("_torch_testing::f3", mutates_args="x") |
| def f3(x: Tensor) -> None: |
| return |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_torch_dispatch_rule_subclass(self): |
| from torch.testing._internal.two_tensor import TwoTensor |
| |
| @torch.library.custom_op("mylib::foo", mutates_args={}) |
| def f(x: torch.Tensor) -> torch.Tensor: |
| return x.sin() |
| |
| x = torch.randn(3) |
| y = torch.randn(3) |
| z = TwoTensor(x, y) |
| |
| with torch.library._scoped_library("mylib", "FRAGMENT") as m: |
| called = 0 |
| |
| def TwoTensor_foo(cls, func, types, args, kwargs): |
| nonlocal called |
| assert cls is TwoTensor |
| called += 1 |
| return x.sin() |
| |
| m._register_torch_dispatch_rule("foo", TwoTensor, TwoTensor_foo) |
| |
| out = f(z) |
| out2 = z.cos() |
| |
| self.assertEqual(called, 1) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_torch_dispatch_rule_mode(self): |
| from torch.testing._internal.two_tensor import TwoTensorMode |
| |
| @torch.library.custom_op("mylib::foo", mutates_args={}) |
| def f(x: torch.Tensor) -> torch.Tensor: |
| return x.sin() |
| |
| x = torch.randn(3) |
| |
| with torch.library._scoped_library("mylib", "FRAGMENT") as m: |
| called = 0 |
| |
| def TwoTensor_foo(mode, func, types, args, kwargs): |
| nonlocal called |
| called += 1 |
| return x.sin() |
| |
| m._register_torch_dispatch_rule("foo", TwoTensorMode, TwoTensor_foo) |
| |
| with TwoTensorMode(): |
| out = f(x) |
| out2 = x.cos() |
| |
| self.assertEqual(called, 1) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| @parametrize("idx", [0, 1, 2, 3, 4, 5]) |
| def test_library_register_fake_source(self, idx): |
| opname = f"source{idx}" |
| op = getattr(torch.ops._torch_testing, opname).default |
| entry = torch._library.simple_registry.singleton.find(op._name) |
| source = entry.fake_impl.kernel.source |
| assert source is not None |
| self.assertTrue("custom_op_db.py" in source) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_fake(self): |
| for mode in ["function", "qualname", "opoverload"]: |
| |
| @torch.library.custom_op("_torch_testing::add", mutates_args=()) |
| def add(x: Tensor, y: float) -> Tensor: |
| x_np = x.cpu().numpy() |
| out_np = x_np + y |
| return torch.from_numpy(out_np).to(x.device) |
| |
| called = False |
| |
| if mode == "function": |
| dec = torch.library.register_fake(add) |
| self.assertIsNotNone(dec) |
| elif mode == "qualname": |
| dec = torch.library.register_fake("_torch_testing::add") |
| self.assertIsNotNone(dec) |
| elif mode == "opoverload": |
| dec = torch.library.register_fake(torch.ops._torch_testing.add.default) |
| self.assertIsNotNone(dec) |
| else: |
| raise AssertionError("should not get here") |
| |
| @dec |
| def _(x, y): |
| nonlocal called |
| called = True |
| return torch.empty_like(x) |
| |
| with torch._subclasses.fake_tensor.FakeTensorMode(): |
| x = torch.randn(3) |
| y = 3.14 |
| z = add(x, y) |
| self.assertEqual(z.shape, x.shape) |
| self.assertTrue(called) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_torch_dispatch(self): |
| for mode in ["function", "qualname", "opoverload"]: |
| |
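| # A pass-through dispatch mode; the rule registered below should intercept the custom op when this mode is active. |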
| class MyMode(torch.utils._python_dispatch.TorchDispatchMode): |
| def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
| return func(*args, **kwargs) |
| |
| @torch.library.custom_op("_torch_testing::add", mutates_args=()) |
| def add(x: Tensor, y: float) -> Tensor: |
| x_np = x.cpu().numpy() |
| out_np = x_np + y |
| return torch.from_numpy(out_np).to(x.device) |
| |
| called = False |
| |
| if mode == "function": |
| dec = torch.library.register_torch_dispatch(add, MyMode) |
| self.assertIsNotNone(dec) |
| elif mode == "qualname": |
| dec = torch.library.register_torch_dispatch( |
| "_torch_testing::add", MyMode |
| ) |
| self.assertIsNotNone(dec) |
| elif mode == "opoverload": |
| dec = torch.library.register_torch_dispatch( |
| torch.ops._torch_testing.add.default, MyMode |
| ) |
| self.assertIsNotNone(dec) |
| else: |
| raise AssertionError("should not get here") |
| |
| @dec |
| def _(mode, func, types, args, kwargs): |
| nonlocal called |
| called = True |
| return func(*args, **kwargs) |
| |
| with MyMode(): |
| x = torch.randn(3) |
| y = 3.14 |
| z = add(x, y) |
| self.assertEqual(z.shape, x.shape) |
| self.assertTrue(called) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_torch_dispatch_low_level(self): |
| modes = ["qualname", "opoverload"] |
| calls = ["decorator", "function"] |
| device_types_options = [("cpu", "cuda"), "cpu", None] |
| |
| for mode, call, device_types in itertools.product( |
| modes, calls, device_types_options |
| ): |
| with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: |
| lib.define("add10(Tensor x, float y) -> Tensor") |
| |
| if mode == "qualname": |
| op = "_torch_testing::add10" |
| else: |
| assert mode == "opoverload" |
| op = torch.ops._torch_testing.add10.default |
| |
| called = False |
| |
| class MyMode(torch.utils._python_dispatch.TorchDispatchMode): |
| def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
| return func(*args, **kwargs) |
| |
| if call == "decorator": |
| |
| @torch.library.register_torch_dispatch(op, MyMode, lib=lib) |
| def _(mode, func, types, args, kwargs): |
| x, y = args |
| nonlocal called |
| called = True |
| return x + y |
| |
| else: |
| assert call == "function" |
| |
| def add_stuff(mode, func, types, args, kwargs): |
| x, y = args |
| nonlocal called |
| called = True |
| return x + y |
| |
| torch.library.register_torch_dispatch( |
| op, MyMode, add_stuff, lib=lib |
| ) |
| |
| x = torch.randn(3) |
| y = 3.14 |
| with MyMode(): |
| z = torch.ops._torch_testing.add10.default(x, y) |
| self.assertEqual(z, x + y) |
| self.assertTrue(called) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_kernel(self): |
| modes = ["function", "qualname", "opoverload"] |
| calls = ["decorator", "function"] |
| device_types_options = ["cpu", None] |
| |
| for mode, call, device_types in itertools.product( |
| modes, calls, device_types_options |
| ): |
| |
| @torch.library.custom_op( |
| "_torch_testing::add", mutates_args=(), device_types="cuda" |
| ) |
| def add(x: Tensor, y: float) -> Tensor: |
| x_np = x.cpu().numpy() |
| out_np = x_np + y |
| return torch.from_numpy(out_np).to(x.device) |
| |
| if mode == "function": |
| op = add |
| elif mode == "qualname": |
| op = "_torch_testing::add" |
| else: |
| assert mode == "opoverload" |
| op = torch.ops._torch_testing.add.default |
| |
| called = False |
| |
| if call == "decorator": |
| |
| @torch.library.register_kernel(op, device_types) |
| def _(x, y): |
| nonlocal called |
| called = True |
| x_np = x.numpy() |
| out_np = x_np + y |
| return torch.from_numpy(out_np) |
| |
| else: |
| assert call == "function" |
| |
| def add_cpu(x, y): |
| nonlocal called |
| called = True |
| x_np = x.numpy() |
| out_np = x_np + y |
| return torch.from_numpy(out_np) |
| |
| torch.library.register_kernel(op, device_types, add_cpu) |
| |
| x = torch.randn(3) |
| y = 3.14 |
| z = add(x, y) |
| self.assertEqual(z, x + y) |
| self.assertTrue(called) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_kernel_low_level(self): |
| modes = ["qualname", "opoverload"] |
| calls = ["decorator", "function"] |
| device_types_options = [("cpu", "cuda"), "cpu", None] |
| |
| for mode, call, device_types in itertools.product( |
| modes, calls, device_types_options |
| ): |
| with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: |
| lib.define("add9(Tensor x, float y) -> Tensor") |
| |
| if mode == "qualname": |
| op = "_torch_testing::add9" |
| else: |
| assert mode == "opoverload" |
| op = torch.ops._torch_testing.add9.default |
| |
| called = False |
| |
| if call == "decorator": |
| |
| @torch.library.register_kernel(op, device_types, lib=lib) |
| def _(x, y): |
| nonlocal called |
| called = True |
| x_np = x.numpy() |
| out_np = x_np + y |
| return torch.from_numpy(out_np) |
| |
| else: |
| assert call == "function" |
| |
| def add_cpu(x, y): |
| nonlocal called |
| called = True |
| x_np = x.numpy() |
| out_np = x_np + y |
| return torch.from_numpy(out_np) |
| |
| torch.library.register_kernel(op, device_types, add_cpu, lib=lib) |
| |
| x = torch.randn(3) |
| y = 3.14 |
| z = torch.ops._torch_testing.add9.default(x, y) |
| self.assertEqual(z, x + y) |
| self.assertTrue(called) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_autograd(self): |
| for mode in ["function", "qualname", "opoverload"]: |
| |
| @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) |
| def numpy_sin(x: Tensor) -> Tensor: |
| x_np = x.cpu().numpy() |
| y_np = np.sin(x_np) |
| return torch.from_numpy(y_np).to(device=x.device) |
| |
| def setup_context(ctx, inputs, output) -> None: |
| (x,) = inputs |
| ctx.save_for_backward(x) |
| |
| called = False |
| |
| def backward(ctx, grad): |
| nonlocal called |
| called = True |
| (x,) = ctx.saved_tensors |
| return grad * x.cos() |
| |
| if mode == "function": |
| torch.library.register_autograd( |
| numpy_sin, backward, setup_context=setup_context |
| ) |
| elif mode == "qualname": |
| torch.library.register_autograd( |
| "mylib::numpy_sin", backward, setup_context=setup_context |
| ) |
| elif mode == "opoverload": |
| torch.library.register_autograd( |
| torch.ops.mylib.numpy_sin.default, |
| backward, |
| setup_context=setup_context, |
| ) |
| |
| x = torch.randn(3, requires_grad=True) |
| y = numpy_sin(x) |
| (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) |
| self.assertTrue(called) |
| self.assertEqual(grad_x, x.cos()) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_autograd_low_level(self): |
| for mode in ["qualname", "opoverload"]: |
| with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: |
| lib.define("sin5(Tensor x) -> Tensor") |
| |
| def numpy_sin(x: Tensor) -> Tensor: |
| x_np = x.cpu().detach().numpy() |
| y_np = np.sin(x_np) |
| return torch.from_numpy(y_np).to(device=x.device) |
| |
| def setup_context(ctx, inputs, output) -> None: |
| (x,) = inputs |
| ctx.save_for_backward(x) |
| |
| called = False |
| |
| def backward(ctx, grad): |
| nonlocal called |
| called = True |
| (x,) = ctx.saved_tensors |
| return grad * x.cos() |
| |
| lib.impl("sin5", numpy_sin, "CPU") |
| |
| if mode == "qualname": |
| torch.library.register_autograd( |
| "_torch_testing::sin5", |
| backward, |
| setup_context=setup_context, |
| lib=lib, |
| ) |
| elif mode == "opoverload": |
| torch.library.register_autograd( |
| torch.ops._torch_testing.sin5.default, |
| backward, |
| setup_context=setup_context, |
| lib=lib, |
| ) |
| x = torch.randn(3, requires_grad=True) |
| y = torch.ops._torch_testing.sin5(x) |
| (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) |
| self.assertTrue(called) |
| self.assertEqual(grad_x, x.cos()) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_fake(self): |
| @torch.library.custom_op("_torch_testing::add", mutates_args=()) |
| def add(x: Tensor, y: float) -> Tensor: |
| x_np = x.cpu().numpy() |
| out_np = x_np + y |
| return torch.from_numpy(out_np).to(x.device) |
| |
| x = torch.randn(3) |
| y = 3.14 |
| z = add(x, y) |
| self.assertEqual(z, x + y) |
| |
| try: |
| with torch._subclasses.fake_tensor.FakeTensorMode(): |
| x = torch.randn(3) |
| add(x, y) |
| raise AssertionError("should not be hit") |
| except RuntimeError as e: |
| abstract_impl_error_msg = str(e) |
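| # Normalize the object address and split sentences onto separate lines so the expected-inline comparison is stable. |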
| abstract_impl_error_msg = re.sub( |
| r"0x.*>\)>", "0xDEADBEEF>)>", abstract_impl_error_msg |
| ).replace(". ", ".\n") |
| self.assertExpectedInline( |
| abstract_impl_error_msg, |
| """\ |
| There was no fake impl registered for <CustomOpDef(_torch_testing::add)>. |
| This is necessary for torch.compile/export/fx tracing to work. |
| Please use `add.register_fake` to add an fake impl.""", |
| ) |
| |
| if not IS_WINDOWS: |
| |
| @torch.compile(backend="eager") |
| def f(x, y): |
| return add(x, y) |
| |
| x = torch.randn(3) |
| with self.assertRaisesRegex(RuntimeError, "no fake impl"): |
| f(x, y) |
| |
| abstract_called = False |
| |
| @add.register_fake |
| def _(x, y): |
| nonlocal abstract_called |
| abstract_called = True |
| return torch.empty_like(x) |
| |
| with torch._subclasses.fake_tensor.FakeTensorMode(): |
| x = torch.randn(3) |
| z = add(x, y) |
| self.assertEqual(z.shape, x.shape) |
| self.assertTrue(abstract_called) |
| |
| @skipIfTorchDynamo("recursive dynamo") |
| @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows") |
| def test_compile(self): |
| called_impl = False |
| called_abstract = False |
| |
| @torch.library.custom_op("_torch_testing::linear", mutates_args=()) |
| def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: |
| nonlocal called_impl |
| called_impl = True |
| x_np = x.numpy() |
| w_np = weight.numpy() |
| b_np = bias.numpy() |
| out_np = np.add(x_np @ w_np.T, b_np) |
| return torch.from_numpy(out_np) |
| |
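| # Fake impl: validates shapes and devices, then returns an empty tensor of the output shape for tracing. |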
| @custom_linear.register_fake |
| def _(x, weight, bias): |
| nonlocal called_abstract |
| called_abstract = True |
| assert x.dim() == 2 |
| assert weight.dim() == 2 |
| assert bias.dim() == 1 |
| assert x.shape[1] == weight.shape[1] |
| assert weight.shape[0] == bias.shape[0] |
| assert x.device == weight.device |
| return x.new_empty(x.size(0), weight.size(0)) |
| |
| x = torch.randn(2, 2) |
| weight = torch.randn(2, 2) |
| bias = torch.randn(2) |
| out = torch.compile(custom_linear, backend="eager", fullgraph=True)( |
| x, weight, bias |
| ) |
| self.assertEqual(out, torch.nn.functional.linear(x, weight, bias)) |
| self.assertTrue(called_impl) |
| self.assertTrue(called_abstract) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_register_autograd_error_cases(self): |
| @torch.library.custom_op("_torch_testing::g", mutates_args=()) |
| def g(x: Tensor) -> Tensor: |
| return x.sin() |
| |
| x = torch.randn(3, requires_grad=True) |
| y = g(x) |
| with self.assertRaisesRegex(RuntimeError, "no autograd formula"): |
| y.sum().backward() |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_replacement(self): |
| @torch.library.custom_op("_torch_testing::f", mutates_args=()) |
| def f(x: Tensor) -> Tensor: |
| return x.sin() |
| |
| x = torch.randn(3) |
| y = f(x) |
| self.assertEqual(y, x.sin()) |
| |
| @torch.library.custom_op("_torch_testing::f", mutates_args=()) |
| def f(x: Tensor) -> Tensor: |
| return x.cos() |
| |
| y = f(x) |
| self.assertEqual(y, x.cos()) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| @unittest.skipIf(not TEST_CUDA, "requires CUDA") |
| def test_split_device(self): |
| cpu_call_count = 0 |
| cuda_call_count = 0 |
| |
| @torch.library.custom_op( |
| "_torch_testing::f", mutates_args=(), device_types="cpu" |
| ) |
| def f(x: Tensor) -> Tensor: |
| nonlocal cpu_call_count |
| cpu_call_count += 1 |
| x_np = x.numpy() |
| out_np = np.sin(x_np) |
| return torch.from_numpy(out_np) |
| |
| @f.register_kernel("cuda") |
| def _(x: Tensor) -> Tensor: |
| nonlocal cuda_call_count |
| cuda_call_count += 1 |
| x_np = x.cpu().numpy() |
| out_np = np.sin(x_np) |
| return torch.from_numpy(out_np).to(x.device) |
| |
| x = torch.randn(3) |
| y = f(x) |
| self.assertEqual(y, x.sin()) |
| self.assertEqual(cpu_call_count, 1) |
| self.assertEqual(cuda_call_count, 0) |
| |
| x = x.cuda() |
| y = f(x) |
| self.assertEqual(y, x.sin()) |
| self.assertEqual(cpu_call_count, 1) |
| self.assertEqual(cuda_call_count, 1) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| @unittest.skipIf(not TEST_CUDA, "requires CUDA") |
| def test_multi_types(self): |
| @torch.library.custom_op( |
| "_torch_testing::f", mutates_args=(), device_types=("cpu", "cuda") |
| ) |
| def f(x: Tensor) -> Tensor: |
| x_np = x.cpu().numpy() |
| out_np = np.sin(x_np) |
| return torch.from_numpy(out_np).to(x.device) |
| |
| x = torch.randn(3) |
| y = f(x) |
| self.assertEqual(y, x.sin()) |
| x = x.cuda() |
| y = f(x) |
| self.assertEqual(y, x.sin()) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_overloading(self): |
| called_f = 0 |
| called_f1 = 0 |
| |
| @torch.library.custom_op("_torch_testing::f", mutates_args=()) |
| def f(x: Tensor) -> Tensor: |
| nonlocal called_f |
| called_f += 1 |
| return x.clone() |
| |
| x = torch.randn(2, 3) |
| torch.ops._torch_testing.f(x) |
| self.assertEqual(called_f, 1) |
| |
| @torch.library.custom_op("_torch_testing::f.overload", mutates_args=()) |
| def f1(x: Tensor, y: Tensor) -> Tensor: |
| nonlocal called_f1 |
| called_f1 += 1 |
| return x.clone() |
| |
| torch.ops._torch_testing.f(x, x) |
| self.assertEqual(called_f1, 1) |
| |
| def test_disallows_output_aliasing(self): |
| @torch.library.custom_op("_torch_testing::f", mutates_args=()) |
| def f(x: Tensor) -> Tensor: |
| return x.view(-1) |
| |
| x = torch.randn(3) |
| with self.assertRaisesRegex(RuntimeError, "may not alias"): |
| f(x) |
| |
| @torch.library.custom_op("_torch_testing::f", mutates_args=()) |
| def f(x: Tensor) -> Tensor: |
| return x |
| |
| x = torch.randn(3) |
| with self.assertRaisesRegex(RuntimeError, "may not alias"): |
| f(x) |
| |
| @torch.library.custom_op( |
| "_torch_testing::f", mutates_args={"x"}, device_types="cpu" |
| ) |
| def numpy_sin_inplace(x: Tensor) -> Tensor: |
| x_np = x.numpy() |
| np.sin(x_np, out=x_np) |
| return x |
| |
| x = torch.randn(3) |
| with self.assertRaisesRegex(RuntimeError, "may not alias"): |
| numpy_sin_inplace(x) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_factory_function(self): |
| @torch.library.custom_op( |
| "_torch_testing::f", mutates_args={}, device_types="cpu" |
| ) |
| def f(device: torch.device) -> Tensor: |
| return torch.ones(3) |
| |
| result = f(device="cpu") |
| self.assertEqual(result.device, torch.device("cpu")) |
| self.assertEqual(result, torch.ones(3)) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, "f does not have a kernel registered for cuda" |
| ): |
| f("cuda") |
| |
| with self.assertRaisesRegex( |
| ValueError, |
| "Functions without tensor inputs are required to have a `device: torch.device` argument", |
| ): |
| |
| @torch.library.custom_op( |
| "_torch_testing::f2", mutates_args={}, device_types="cpu" |
| ) |
| def f2() -> Tensor: |
| return torch.ones(3) |
| |
| @torch.library.custom_op("_torch_testing::f3", mutates_args={}) |
| def f3() -> Tensor: |
| raise NotImplementedError("NYI") |
| |
| with self.assertRaisesRegex( |
| ValueError, |
| "Functions without tensor inputs are required to have a `device: torch.device` argument", |
| ): |
| |
| @f3.register_kernel("cpu") |
| def _(): |
| return torch.zeros(3) |
| |
| @torch.library.custom_op("_torch_testing::f4", mutates_args={}) |
| def f4(device: torch.device) -> Tensor: |
| raise NotImplementedError("NYI") |
| |
| @f4.register_kernel("cpu") |
| def _(device: torch.device): |
| return torch.zeros(3) |
| |
| result = f4(device="cpu") |
| self.assertEqual(result.device, torch.device("cpu")) |
| self.assertEqual(result, torch.zeros(3)) |
| |
| def test_library_schema_infer(self): |
| def foo_impl(x: torch.Tensor) -> torch.Tensor: |
| return x.sin() |
| |
| schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={}) |
| self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor") |
| |
| schema = torch.library.infer_schema(foo_impl, mutates_args={}) |
| self.assertExpectedInline(schema, "(Tensor x) -> Tensor") |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_set_kernel_enabled(self): |
| x = torch.ones(1) |
| |
| @torch.library.custom_op("mylib::f", mutates_args=()) |
| def f(x: Tensor) -> Tensor: |
| return x + 1 |
| |
| self.assertEqual(f(x), x + 1) |
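| # No kernel is registered for "gpu", so toggling it only logs a message instead of changing behavior. |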
| with self.assertLogs("torch._library.custom_ops") as captured: |
| with f.set_kernel_enabled("gpu", enabled=False): |
| self.assertEqual(f(x), x + 1) |
| self.assertIn( |
| "no kernel was registered for this device type", captured.output[0] |
| ) |
| |
| @f.register_kernel("cpu") |
| def _(x): |
| return x + 2 |
| |
| self.assertEqual(f(x), x + 2) |
| |
| with self.assertLogs("torch._library.custom_ops") as captured: |
| with f.set_kernel_enabled("cpu", enabled=True): |
| self.assertEqual(f(x), x + 2) |
| self.assertIn("already enabled", captured.output[0]) |
| |
| with f.set_kernel_enabled("cpu", enabled=False): |
| self.assertEqual(f(x), x + 1) |
| |
| with self.assertLogs("torch._library.custom_ops") as captured: |
| with f.set_kernel_enabled("cpu", enabled=False): |
| self.assertEqual(f(x), x + 1) |
| self.assertIn("already disabled", captured.output[0]) |
| |
| self.assertEqual(f(x), x + 1) |
| |
| with f.set_kernel_enabled("cpu", enabled=True): |
| self.assertEqual(f(x), x + 2) |
| |
| with f.set_kernel_enabled("cpu", enabled=False): |
| self.assertEqual(f(x), x + 1) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_register_vmap_kwargonly_low_level(self): |
| with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: |
| lib.define("foo(Tensor x, *, float y) -> Tensor") |
| called = False |
| |
| def foo_impl(x, *, y): |
| return x * y |
| |
| lib.impl("foo", foo_impl, "CPU") |
| |
| def vmap(info, in_dims, x, *, y): |
| nonlocal called |
| called = True |
| return x * y, 0 |
| |
| torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib) |
| |
| x = torch.ones(3) |
| result = torch.vmap(torch.ops._torch_testing.foo)(x, y=3.14) |
| self.assertTrue(called) |
| self.assertEqual(result, torch.tensor([3.14, 3.14, 3.14])) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_register_vmap_defaults(self): |
| with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: |
| lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor") |
| |
| def foo_impl(w, x=2, *, y=3, z): |
| return w * x * y * z |
| |
| lib.impl("foo", foo_impl, "CPU") |
| |
| called = False |
| |
| def vmap(info, in_dims, w, x=2, *, y=3, z): |
| nonlocal called |
| called = True |
| return w * x * y * z, 0 |
| |
| torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib) |
| |
| w = torch.ones(3) |
| result = torch.vmap(torch.ops._torch_testing.foo)(w, z=42) |
| self.assertTrue(called) |
| self.assertEqual(result, w * 2 * 3 * 42) |
| |
| def test_layout_constraint_tags(self): |
| needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order |
| flexible_layout = torch._C.Tag.flexible_layout |
| # (tags, the result of the tag inference) |
| tests = [ |
| ({needs_fixed_stride_order}, needs_fixed_stride_order), |
| ({flexible_layout}, flexible_layout), |
| # If no tags are provided, then the following is the default |
| (set(), flexible_layout), |
| # If multiple tags are provided, then we use the most constrained tag. |
| ({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order), |
| ] |
| from torch._inductor.lowering import get_layout_constraint_tag |
| |
| for tags, expected in tests: |
| with torch.library._scoped_library("mylib", "FRAGMENT") as m: |
| m.define("foobar(Tensor x) -> Tensor", tags=tags) |
| result = get_layout_constraint_tag(torch.ops.mylib.foobar.default) |
| self.assertEqual(result, expected) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_vmap(self): |
| for mode in ["function", "qualname", "opoverload", "c_opdef"]: |
| |
| @torch.library.custom_op("mylib::f", mutates_args=()) |
| def f(x: Tensor, y: Tensor) -> Tensor: |
| return x * y |
| |
| called = False |
| |
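| # vmap rule: move each input's batch dim to the end (or add one), multiply, then move the batch dim to the front and report it as dim 0. |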
| def fvmap(info, in_dims, x, y): |
| nonlocal called |
| called = True |
| x_bdim, y_bdim = in_dims |
| x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) |
| y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) |
| result = x * y |
| result = result.movedim(-1, 0) |
| return result, 0 |
| |
| if mode == "function": |
| torch.library.register_vmap(f, fvmap) |
| elif mode == "qualname": |
| torch.library.register_vmap("mylib::f", fvmap) |
| elif mode == "opoverload": |
| torch.library.register_vmap(torch.ops.mylib.f.default, fvmap) |
| elif mode == "c_opdef": |
| f.register_vmap(fvmap) |
| |
| x = torch.randn(2, 2) |
| y = torch.randn(2, 2) |
| |
| result = torch.vmap(f)(x, y) |
| self.assertTrue(called) |
| self.assertEqual(result, x * y) |
| |
| called = False |
| result = torch.vmap(f, out_dims=1)(x, y) |
| self.assertEqual(result, (x * y).T) |
| self.assertTrue(called) |
| |
| called = False |
| result = torch.vmap(f, in_dims=1)(x, y) |
| self.assertEqual(result, (x * y).T) |
| self.assertTrue(called) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_vmap_library_decorator(self): |
| @torch.library.custom_op("mylib::f", mutates_args=()) |
| def f(x: Tensor, y: Tensor) -> Tensor: |
| return x * y |
| |
| called = False |
| |
| @torch.library.register_vmap("mylib::f") |
| def fvmap(info, in_dims, x, y): |
| nonlocal called |
| called = True |
| x_bdim, y_bdim = in_dims |
| x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) |
| y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) |
| result = x * y |
| result = result.movedim(-1, 0) |
| return result, 0 |
| |
| x = torch.randn(2, 2) |
| y = torch.randn(2, 2) |
| |
| result = torch.vmap(f)(x, y) |
| self.assertTrue(called) |
| self.assertEqual(result, x * y) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_vmap_op_decorator(self): |
| @torch.library.custom_op("mylib::f", mutates_args=()) |
| def f(x: Tensor, y: Tensor) -> Tensor: |
| return x * y |
| |
| called = False |
| |
| @f.register_vmap |
| def fvmap(info, in_dims, x, y): |
| nonlocal called |
| called = True |
| x_bdim, y_bdim = in_dims |
| x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) |
| y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) |
| result = x * y |
| result = result.movedim(-1, 0) |
| return result, 0 |
| |
| x = torch.randn(2, 2) |
| y = torch.randn(2, 2) |
| |
| result = torch.vmap(f)(x, y) |
| self.assertTrue(called) |
| self.assertEqual(result, x * y) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_vmap_register_multiple_times(self): |
| @torch.library.custom_op("mylib::f", mutates_args=()) |
| def f(x: Tensor, y: Tensor) -> Tensor: |
| return x * y |
| |
| called = False |
| |
| @f.register_vmap |
| def fvmap(info, in_dims, x, y): |
| nonlocal called |
| called = True |
| x_bdim, y_bdim = in_dims |
| x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) |
| y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) |
| result = x * y |
| result = result.movedim(-1, 0) |
| return result, 0 |
| |
| x = torch.randn(2, 2) |
| y = torch.randn(2, 2) |
| |
| result = torch.vmap(f)(x, y) |
| self.assertTrue(called) |
| self.assertEqual(result, x * y) |
| called = False |
| |
| @f.register_vmap |
| def fvmap2(info, in_dims, x, y): |
| nonlocal called |
| called = True |
| x_bdim, y_bdim = in_dims |
| x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) |
| y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) |
| result = x + y |
| result = result.movedim(-1, 0) |
| return result, 0 |
| |
| result = torch.vmap(f)(x, y) |
| self.assertTrue(called) |
| self.assertEqual(result, x + y) |
| |
| @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") |
| def test_library_register_vmap_register_multiple_times_2(self): |
| @torch.library.custom_op("mylib::f", mutates_args=()) |
| def f(x: Tensor, y: Tensor) -> Tensor: |
| return x * y |
| |
| called = False |
| |
| @torch.library.register_vmap("mylib::f") |
| def fvmap(info, in_dims, x, y): |
| nonlocal called |
| called = True |
| x_bdim, y_bdim = in_dims |
| x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) |
| y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) |
| result = x * y |
| result = result.movedim(-1, 0) |
| return result, 0 |
| |
| x = torch.randn(2, 2) |
| y = torch.randn(2, 2) |
| |
| result = torch.vmap(f)(x, y) |
| self.assertTrue(called) |
| self.assertEqual(result, x * y) |
| called = False |
| |
| @torch.library.register_vmap("mylib::f") |
| def fvmap2(info, in_dims, x, y): |
| nonlocal called |
| called = True |
| x_bdim, y_bdim = in_dims |
| x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) |
| y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) |
| result = x + y |
| result = result.movedim(-1, 0) |
| return result, 0 |
| |
| result = torch.vmap(f)(x, y) |
| self.assertTrue(called) |
| self.assertEqual(result, x + y) |
| |
| |
| class MiniOpTestOther(CustomOpTestCaseBase): |
| test_ns = "mini_op_test" |
| |
| def test_nonzero_again(self): |
| x = torch.tensor([0, 1, 2, 0, 0]) |
| y = torch.ops.aten.nonzero.default(x) |
| self.assertEqual(y, torch.tensor([[1], [2]])) |
| |
| |
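| # Generate opcheck test variants for MiniOpTest over aten and mini_op_test ops; known failures are listed in minioptest_failures_dict.json. |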
| optests.generate_opcheck_tests( |
| MiniOpTest, |
| ["aten", "mini_op_test"], |
| get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"), |
| additional_decorators={ |
| "test_pt2_compliant_tag_mini_op_test_no_abstract": [unittest.expectedFailure] |
| }, |
| test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS, |
| ) |
| |
| optests.generate_opcheck_tests( |
| MiniOpTestOther, |
| ["aten", "mini_op_test"], |
| get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"), |
| test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS, |
| ) |
| |
| |
| class TestGenerateOpcheckTests(CustomOpTestCaseBase): |
| def test_MiniOpTest(self): |
| for orig_test in ["test_mm", "test_nonzero"]: |
| for test_util in optests.generate_tests.DEFAULT_TEST_UTILS: |
| expected_test = f"{test_util}__{orig_test}" |
| self.assertTrue(hasattr(MiniOpTest, expected_test), msg=expected_test) |
| |
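| # generate_repro renders a standalone script that reruns a single opcheck |
| # test util against an op; with save_data=True it also saves (args, kwargs) |
| # to a .pt file that the generated script reloads via torch.load. |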
| def test_generate_repro_save_data(self): |
| from torch.testing._internal.optests.generate_tests import generate_repro |
| |
| args = (torch.ones(2, 2),) |
| kwargs = {"mat2": torch.zeros(2, 2)} |
| actual = generate_repro( |
| "test_schema", |
| torch.ops.aten.sin.default, |
| args, |
| kwargs, |
| save_data=True, |
| dry_run=True, |
| ) |
| actual = re.sub(r"torch.load\(\".*\.pt\"\)", 'torch.load("repro.pt")', actual) |
| self.assertExpectedInline( |
| actual, |
| """\ |
| # ========================================================= |
| # BEGIN REPRO SCRIPT |
| # ========================================================= |
| import torch |
| from torch.testing._internal.optests import opcheck |
| |
| # Make sure you have loaded the library that contains the op |
| # via an import or torch.ops.load_library(...) |
| op = torch.ops.aten.sin.default |
| |
| args, kwargs = torch.load("repro.pt") |
| opcheck(op, args, kwargs, test_utils="test_schema") |
| # ========================================================= |
| # END REPRO SCRIPT |
| # ========================================================= |
| """, |
| ) |
| |
| def test_generate_repro_no_save_data(self): |
| from torch.testing._internal.optests.generate_tests import generate_repro |
| |
| args = (torch.ones(2, 2),) |
| kwargs = {"mat2": torch.zeros(2, 2)} |
| actual = generate_repro( |
| "test_schema", |
| torch.ops.aten.sin.default, |
| args, |
| kwargs, |
| save_data=False, |
| dry_run=True, |
| ) |
| self.assertExpectedInline( |
| actual, |
| """\ |
| # ========================================================= |
| # BEGIN REPRO SCRIPT |
| # ========================================================= |
| import torch |
| from torch.testing._internal.optests import opcheck |
| |
| # Make sure you have loaded the library that contains the op |
| # via an import or torch.ops.load_library(...) |
| op = torch.ops.aten.sin.default |
| |
| # If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1 |
| # we will fill them in same (args, kwargs) as in your test |
| args = () # args to the operator |
| kwargs = {} # kwargs to the operator |
| opcheck(op, args, kwargs, test_utils="test_schema") |
| # ========================================================= |
| # END REPRO SCRIPT |
| # ========================================================= |
| """, |
| ) |
| |
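| # A failures dict maps "namespace::op_name" to generated-test names of the |
| # form "TestClass.test_util__original_test", each with a status such as |
| # "xfail" or "skip". Validation rejects unknown statuses, names that do not |
| # begin with a known test util, and tests missing from the TestCase. |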
| def test_failures_dict_validation(self): |
| from torch.testing._internal.optests.generate_tests import ( |
| FailuresDict, |
| validate_failures_dict_structure, |
| ) |
| |
| failures = { |
| "mini_op_test::incorrect_schema": { |
| "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error": { |
| "comment": "", |
| "status": "success", |
| } |
| } |
| } |
| with self.assertRaisesRegex(RuntimeError, "got status=success"): |
| validate_failures_dict_structure( |
| FailuresDict("", failures), |
| torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS, |
| MiniOpTest, |
| ) |
| |
| failures = { |
| "mini_op_test::incorrect_schema": { |
| "MiniOpTest.test_aot_dispatch__test_delayed_error": { |
| "comment": "", |
| "status": "xfail", |
| }, |
| } |
| } |
| with self.assertRaisesRegex(RuntimeError, "should begin with one of"): |
| validate_failures_dict_structure( |
| FailuresDict("", failures), |
| torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS, |
| MiniOpTest, |
| ) |
| |
| failures = { |
| "mini_op_test::incorrect_schema": { |
| "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error_nopenopenope": { |
| "comment": "", |
| "status": "xfail", |
| }, |
| } |
| } |
| with self.assertRaisesRegex(RuntimeError, "does not exist on the TestCase"): |
| validate_failures_dict_structure( |
| FailuresDict("", failures), |
| torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS, |
| MiniOpTest, |
| ) |
| |
| def test_dont_generate_decorator(self): |
| self.assertTrue(hasattr(MiniOpTest, "test_dont_generate")) |
| self.assertFalse(hasattr(MiniOpTest, "test_schema__test_dont_generate")) |
| |
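| # torch.library.opcheck expects an OpOverload (or CustomOpDef) rather than a |
| # plain callable like torch.sin, and returns a dict mapping each requested |
| # test util to "SUCCESS"; test_utils may be a single util name or any subset |
| # of the available utils. |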
| def test_opcheck(self): |
| x = torch.randn(3, requires_grad=True) |
| with self.assertRaisesRegex(ValueError, "OpOverload"): |
| torch.library.opcheck(torch.sin, (x,)) |
| with self.assertRaisesRegex(ValueError, "test_utils to be subset of"): |
| torch.library.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah") |
| result = torch.library.opcheck(torch.ops.aten.sin.default, (x,)) |
| |
| self.assertEqual( |
| result, |
| { |
| "test_schema": "SUCCESS", |
| "test_autograd_registration": "SUCCESS", |
| "test_faketensor": "SUCCESS", |
| "test_aot_dispatch_dynamic": "SUCCESS", |
| }, |
| ) |
| |
| result = torch.library.opcheck( |
| torch.ops.aten.sin.default, (x,), test_utils="test_schema" |
| ) |
| self.assertEqual(result, {"test_schema": "SUCCESS"}) |
| |
| result = torch.library.opcheck( |
| torch.ops.aten.sin.default, |
| (x,), |
| test_utils=["test_schema", "test_faketensor"], |
| ) |
| self.assertEqual( |
| result, |
| { |
| "test_schema": "SUCCESS", |
| "test_faketensor": "SUCCESS", |
| }, |
| ) |
| |
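| # opcheck also accepts the CustomOpDef produced by @torch.library.custom_op |
| # (here, numpy_cube from the custom_op_db test suite). |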
| def test_opcheck_customopdef(self): |
| sample_inputs = [ |
| (torch.randn(3),), |
| (torch.randn(3, requires_grad=True),), |
| ] |
| if torch.cuda.is_available(): |
| sample_inputs.extend( |
| [ |
| (torch.randn(3, device="cuda"),), |
| (torch.randn(3, device="cuda", requires_grad=True),), |
| ] |
| ) |
| for args in sample_inputs: |
| torch.library.opcheck(custom_op_db.numpy_cube, args) |
| |
| def test_is_inside_opcheck_mode(self): |
| self.assertFalse(optests.is_inside_opcheck_mode()) |
| with optests.generate_tests.OpCheckMode( |
| ["foo"], "bar", lambda x: x, None, "baz", "brr" |
| ): |
| self.assertTrue(optests.is_inside_opcheck_mode()) |
| |
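| # With raise_exception=False, opcheck reports a failing test util by putting |
| # the caught exception into the result dict instead of raising it. |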
| def test_opcheck_bad_op(self): |
| op = op_with_incorrect_schema(self, "foo") |
| x = torch.randn(3) |
| with self.assertRaisesRegex(Exception, "is not defined to alias output"): |
| torch.library.opcheck(op, (x,)) |
| |
| result = torch.library.opcheck(op, (x,), raise_exception=False) |
| self.assertTrue(isinstance(result["test_schema"], RuntimeError)) |
| del result["test_schema"] |
| self.assertEqual( |
| result, |
| { |
| "test_autograd_registration": "SUCCESS", |
| "test_faketensor": "SUCCESS", |
| "test_aot_dispatch_dynamic": "SUCCESS", |
| }, |
| ) |
| |
| def test_opcheck_does_not_require_extra_deps(self): |
| # torch.testing._internal.common_utils pulls in a lot of additional |
| # test-time dependencies. Since opcheck is public API, it should be |
| # usable with only PyTorch's install-time dependencies; the subprocess |
| # below asserts that expecttest and common_utils are never imported. |
| cmd = [ |
| sys.executable, |
| "-c", |
| "import torch; import sys; \ |
| x = torch.randn(3, requires_grad=True); \ |
| torch.library.opcheck(torch.ops.aten.sin.default, (x,)); \ |
| assert 'expecttest' not in sys.modules; \ |
| assert 'torch.testing._internal.common_utils' not in sys.modules", |
| ] |
| subprocess.check_output(cmd, shell=False) |
| |
| |
| class TestTypeConversion(TestCase): |
| """In infer_schema(), we try to suggest a correct type when the type annotation is wrong.""" |
| |
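| # tuple_to_list suggests the List annotation corresponding to a Tuple |
| # annotation, e.g. Tuple[int, int] -> List[int], |
| # Tuple[int, Optional[int]] -> List[Optional[int]], and |
| # Tuple[int, float] -> List[Union[int, float]]. |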
| def setUp(self): |
| self.supported_base_types = [ |
| int, |
| float, |
| bool, |
| str, |
| torch.device, |
| torch.Tensor, |
| torch.dtype, |
| torch.types.Number, |
| ] |
| |
| def test_simple_tuple(self): |
| self.assertEqual(List, tuple_to_list(Tuple)) |
| |
| def test_supported_types(self): |
| for t in self.supported_base_types: |
| result_type = tuple_to_list(Tuple[t, t, t]) |
| self.assertEqual(result_type, List[t]) |
| |
| result_type = tuple_to_list(Tuple[t]) |
| self.assertEqual(result_type, List[t]) |
| |
| def test_optional(self): |
| for t in self.supported_base_types: |
| result_type = tuple_to_list(Tuple[t, Optional[t]]) |
| self.assertEqual(result_type, List[Optional[t]]) |
| |
| result_type = tuple_to_list(Tuple[t, t, Optional[t]]) |
| self.assertEqual(result_type, List[Optional[t]]) |
| |
| result_type = tuple_to_list(Tuple[t, ...]) |
| self.assertEqual(result_type, List[t]) |
| |
| def test_mixed_types(self): |
| result_type = tuple_to_list(Tuple[int, float]) |
| self.assertEqual(result_type, List[typing.Union[int, float]]) |
| |
| result_type = tuple_to_list(Tuple[int, float, str]) |
| self.assertEqual(result_type, List[typing.Union[int, float, str]]) |
| |
| |
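| # instantiate_device_type_tests clones TestCustomOpTesting into per-device |
| # classes (e.g. TestCustomOpTestingCPU / TestCustomOpTestingCUDA), and |
| # instantiate_parametrized_tests expands @parametrize'd methods into |
| # concrete test methods. |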
| only_for = ("cpu", "cuda") |
| instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for) |
| instantiate_parametrized_tests(TestCustomOp) |
| instantiate_parametrized_tests(TestCustomOpAPI) |
| |
| if __name__ == "__main__": |
| run_tests() |