blob: 7c938e302118c383027d2b9f6c5eb258272d992c [file] [log] [blame] [edit]
# Owner(s): ["module: __torch_dispatch__"]
import logging
import sys
import tempfile
import unittest
from copy import deepcopy
import torch
import torch._dynamo
from torch import SymInt
from torch._C import DispatchKey, DispatchKeySet
from torch._custom_op.functional import register_functional_op
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.cuda.jiterator import _create_jit_fn
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.library import _scoped_library, fallthrough_kernel, impl, Library
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
first_sample,
IS_WINDOWS,
run_tests,
TEST_WITH_ROCM,
TestCase,
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.logging_tensor import (
capture_logs,
capture_logs_with_logging_tensor_mode,
log_input,
LoggingTensor,
LoggingTensorMode,
LoggingTensorReentrant,
)
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils import _pytree as pytree
from torch.utils._mode_utils import all_same_mode, no_dispatch
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
_get_current_dispatch_mode_stack,
is_in_torch_dispatch_mode,
TorchDispatchMode,
)
from torch.utils._pytree import tree_map, tree_map_only
# used as DataLoader collate_fn below; named here to avoid trying to pickle a lambda
def _identity(x):
return x
class TestDispatcherPythonBindings(TestCase):
def test_call_boxed(self) -> None:
sin = torch._C._dispatch_find_schema_or_throw("aten::sin", "")
x = torch.randn(3)
y = torch._C._dispatch_call_boxed(sin, x)
self.assertEqual(y, x.sin())
class TestPythonRegistration(TestCase):
test_ns = "_test_python_registration"
def tearDown(self):
if hasattr(torch.ops, self.test_ns):
del torch.ops._test_python_registration
def test_fallback(self) -> None:
test_key = "TESTING_ONLY_GenericMode"
test_keyset = torch._C.DispatchKeySet(test_key)
include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
exclude_to_set = torch._C._dispatch_tls_local_exclude_set()
with _scoped_library("_", "IMPL") as my_lib:
expected_op = None
expected_args = None
expected_kwargs = None
# Use this out shape to make sure the result from our fallback
# is what is returned to the user
out_shape = None
def my_fallback(op, *args, **kwargs):
# Disable our handler during checks and generating the output
with torch._C._ForceDispatchKeyGuard(
include_to_set, exclude_to_set | test_keyset
):
self.assertIs(op, expected_op)
self.assertEqual(args, expected_args)
self.assertEqual(kwargs, expected_kwargs)
# Return something specific
return torch.empty(out_shape)
my_lib.fallback(my_fallback, test_key)
a, b = torch.rand(2), torch.rand(2)
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
# Check a factory function
expected_op = torch.ops.aten.empty.memory_format
expected_args = ((2, 2),)
# Extra kwargs to bypass issues with default args in factory functions
expected_kwargs = {
"dtype": torch.float64,
"pin_memory": False,
"device": torch.device("cpu"),
}
out_shape = (3,)
out = torch.empty(*expected_args, **expected_kwargs)
self.assertEqual(out.size(), out_shape)
# Check a regular function
expected_op = torch.ops.aten.add.Tensor
expected_args = (a, b)
expected_kwargs = {}
out_shape = (4,)
out = a + b
self.assertEqual(out.size(), out_shape)
def test_fallback_keyset(self) -> None:
test_key_first = "TESTING_ONLY_GenericMode"
test_key_second = "TESTING_ONLY_GenericWrapper"
test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
test_key_second
)
include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
exclude_to_set = torch._C._dispatch_tls_local_exclude_set()
with _scoped_library("_", "IMPL") as my_lib:
first_called = False
second_called = False
def first_fallback(keyset, op, *args, **kwargs):
nonlocal first_called
if second_called:
# Recursive call
first_called = True
with torch._C._ForceDispatchKeyGuard(
include_to_set, exclude_to_set | test_keyset
):
return op(*args, **kwargs)
else:
# Redispatch down
keyset = keyset.remove(test_key_first)
return op.redispatch(keyset, *args, **kwargs)
def second_fallback(op, *args, **kwargs):
nonlocal second_called
# Set to avoid infinite recursion
second_called = True
# New dispatcher call should hit the first callback again
self.assertFalse(first_called)
a, b = args
# Make a substraction here instead of add !
c = a - b
self.assertTrue(first_called)
return c
my_lib.fallback(first_fallback, test_key_first, with_keyset=True)
my_lib.fallback(second_fallback, test_key_second)
a, b = torch.rand(2), torch.rand(2)
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
c = a + b
self.assertEqual(c, a - b)
self.assertTrue(first_called)
self.assertTrue(second_called)
def test_fallback_fallthrough(self) -> None:
test_key_first = "TESTING_ONLY_GenericMode"
test_key_second = "TESTING_ONLY_GenericWrapper"
test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
test_key_second
)
include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
exclude_to_set = torch._C._dispatch_tls_local_exclude_set()
with _scoped_library("_", "IMPL") as my_lib:
is_called = False
def my_fallback(op, *args, **kwargs):
nonlocal is_called
is_called = True
with torch._C._ForceDispatchKeyGuard(
include_to_set, exclude_to_set | test_keyset
):
return op(*args, **kwargs)
my_lib.fallback(torch.library.fallthrough_kernel, test_key_first)
my_lib.fallback(my_fallback, test_key_second)
a, b = torch.rand(2), torch.rand(2)
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
c = a + b
self.assertEqual(c, a + b)
self.assertTrue(is_called)
def test_override_aten_ops_with_multiple_libraries(self) -> None:
x = torch.tensor([1, 2])
with _scoped_library("aten", "IMPL") as my_lib2:
with _scoped_library("aten", "IMPL") as my_lib1:
# Example 1
def my_neg(*args, **kwargs):
return args[0]._neg_view()
# Now we are secretly making the operator a view op so autograd needs to know how
# to handle it
my_lib1.impl("neg", my_neg, "AutogradCPU")
self.assertTrue(torch.neg(x).is_neg())
# RuntimeError: impl("aten::neg", ...):
# Explicitly provided namespace (aten) in operator name does not match ...
with self.assertRaisesRegex(
RuntimeError, "operator name does not match namespace"
):
with _scoped_library("foo", "DEF") as my_lib3:
my_lib3.define("neg(Tensor self) -> Tensor")
my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
# Example 2
def my_mul(*args, **kwargs):
return torch.zeros_like(args[0])
# torch.ops.aten.mul.Tensor
my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")
y = torch._efficientzerotensor(2)
self.assertFalse(torch.mul(x, y)._is_zerotensor())
# Assert that a user can't override the behavior of a (ns, op, dispatch_key)
# combination if someone overridden the behavior for the same before them
with self.assertRaisesRegex(
RuntimeError, "already a kernel registered from python"
):
my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")
# Validate that lib2 is not affected by removing lib1
self.assertFalse(torch.mul(x, y)._is_zerotensor())
# Validate that the old behavior is restored for neg and mul
self.assertFalse(torch.neg(x).is_neg())
self.assertTrue(torch.mul(x, y)._is_zerotensor())
def test_error_if_fn_not_callable(self):
with self.assertRaisesRegex(
TypeError, "Input function is required to be a callable"
):
with _scoped_library("aten", "IMPL") as my_lib:
my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU")
def test_finalizer(self):
impls_refcnt = sys.getrefcount(torch.library._impls)
lib = Library(self.test_ns, "FRAGMENT") # noqa: TOR901
lib.define("foo123(Tensor x) -> Tensor")
# 1 for `lib`, 1 for sys.getrefcount
self.assertEqual(sys.getrefcount(lib), 2)
# We gained an additional reference that gets cleared when the finalizer runs
self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt + 1)
# 1 for `lib`
# 1 for the finalizer
# 1 for sys.getrefcount
self.assertEqual(sys.getrefcount(lib._op_impls), 3)
def foo123(x):
pass
lib.impl(f"{self.test_ns}::foo123", foo123, "CPU")
key = f"{self.test_ns}/foo123/CPU"
self.assertTrue(key in torch.library._impls)
saved_op_impls = lib._op_impls
# del will definitely work if the following passes
self.assertEqual(sys.getrefcount(lib), 2)
del lib
# 1 for saved_op_impls
# 1 for sys.getrefcount
# This function should be the last user of lib._op_impls:
# - lib should not have a reference anymore (it was del'ed)
# - lib's finalizer should not have a reference anymore
self.assertEqual(sys.getrefcount(saved_op_impls), 2)
self.assertTrue(key not in torch.library._impls)
# lib's finalizer should not have a reference anymore
self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt)
def test_override_cpu_sum(self) -> None:
# Example 1
run = [False]
def my_sum(*args, **kwargs):
run[0] = True
return args[0].clone()
with _scoped_library("aten", "IMPL") as my_lib1:
my_lib1.impl("aten::sum", my_sum, "CPU")
x = torch.tensor([1, 2])
self.assertEqual(torch.sum(x), x)
self.assertTrue(run[0])
# Validate that the old behavior is restored for sum
self.assertEqual(torch.sum(x), torch.tensor(3))
def test_override_cuda_with_jiterator(self) -> None:
def override_where_cuda() -> None:
# Example 1: Invert the behavior of where's condition input
not_where_code_string = """
template <typename T> T inverted_where(bool cond, T a, T b){
return !cond ? a : b;
}
"""
jitted_where = _create_jit_fn(not_where_code_string)
CALLED = [False]
def inverted_where(*args, **kwargs):
CALLED[0] = True
return jitted_where(*args, **kwargs)
# overriding where's cuda kernel with Jiterator generated kernel
with _scoped_library("aten", "IMPL") as my_lib:
my_lib.impl("aten::where.self", inverted_where, "CUDA")
device = "cuda"
cond = torch.tensor(
[True, True, False], device=device, dtype=torch.bool
)
x = torch.tensor([1, 2, 3], device=device)
y = torch.tensor([-1, -2, -3], device=device)
self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3]))
self.assertTrue(CALLED[0])
# behavior restored after deregistration
self.assertEqual(torch.where(cond, x, y), torch.tensor([1, 2, -3]))
def override_gelu_cuda() -> None:
# Example 2: Use relu to approximate gelu for faster compute
fastest_gelu_code_string = """
template <typename T> T fast_gelu(T a){
return a > 0 ? a : 0;
}
"""
jitted_gelu = _create_jit_fn(fastest_gelu_code_string)
CALLED = [False]
def fast_gelu(*args, **kwargs):
CALLED[0] = True
return jitted_gelu(*args, **kwargs)
# overriding gelu's cuda kernel with Jiterator generated relu kernel
with _scoped_library("aten", "IMPL") as my_lib:
my_lib.impl("aten::gelu", fast_gelu, "CUDA")
x = torch.rand([3, 3], device="cuda", dtype=torch.float)
self.assertEqual(
torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
)
self.assertTrue(CALLED[0])
# behavior restored after deregistration
self.assertNotEqual(
torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
)
def override_exp_cuda() -> None:
# Example 3: Preventing exp from exploding for float16
clipped_exp_code_string = """
template <typename T> T clipped_exp(T a){
return a > T(10.0) ? T(22026.4657948) : exp(a);
}
"""
jitted_exp = _create_jit_fn(clipped_exp_code_string)
CALLED = [False]
def clipped_exp(*args, **kwargs):
CALLED[0] = True
return jitted_exp(*args, **kwargs)
# overriding exp's cuda kernel with clipped_exp kernel
with _scoped_library("aten", "IMPL") as my_lib:
my_lib.impl("aten::exp", clipped_exp, "CUDA")
x = torch.tensor([0.0, 100.0], device="cuda", dtype=torch.float16)
self.assertEqual(
torch.exp(x),
torch.tensor([1.0, 22026.4657948], dtype=torch.float16),
)
self.assertTrue(CALLED[0])
# behavior restored after deregistration
self.assertEqual(
torch.exp(x), torch.tensor([1.0, torch.inf], dtype=torch.float16)
)
def override_add_cuda() -> None:
# Example 4: simulate a hardware bug, where the adder is always off by 1
buggy_add_code_string = """
template <typename T> T buggy_add(T a, T b){
return a + b + T(1);
}
"""
jitted_add = _create_jit_fn(buggy_add_code_string)
CALLED = [False]
def buggy_add(*args, **kwargs):
CALLED[0] = True
return jitted_add(*args, **kwargs)
with _scoped_library("aten", "IMPL") as my_lib:
my_lib.impl("aten::add.Tensor", buggy_add, "CUDA")
x_cpu = torch.rand([3, 3], device="cpu")
y_cpu = torch.rand([3], device="cpu")
x_cuda = x_cpu.cuda()
y_cuda = y_cpu.cuda()
self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1)
self.assertTrue(CALLED[0])
# behavior restored after deregistration
self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu)
if torch.cuda.is_available() and not TEST_WITH_ROCM:
override_where_cuda()
override_gelu_cuda()
override_exp_cuda()
override_add_cuda()
def test_extend_library_with_dispatch_key_arg(self):
def my_sum(*args, **kwargs):
return args[0].clone()
with _scoped_library("aten", "IMPL", dispatch_key="CPU") as my_lib1:
# RuntimeError: Explicitly provided dispatch key (Conjugate) is
# inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
with self.assertRaisesRegex(
RuntimeError, "inconsistent with the dispatch key"
):
my_lib1.impl("sum", my_sum, "Conjugate")
my_lib1.impl("aten::sum", my_sum)
x = torch.tensor([1, 2])
self.assertEqual(torch.sum(x), x)
def test_create_new_library(self) -> None:
with _scoped_library(self.test_ns, "DEF") as my_lib1:
my_lib1.define("sum(Tensor self) -> Tensor")
# Example 1
@torch.library.impl(my_lib1, "sum", "CPU")
def my_sum(*args, **kwargs):
return args[0].clone()
x = torch.tensor([1, 2])
op = getattr(torch.ops, self.test_ns).sum
self.assertEqual(op(x), x)
with _scoped_library(self.test_ns, "IMPL") as my_lib2:
# Example 2
@torch.library.impl(my_lib2, op.default, "ZeroTensor")
def my_sum_zt(*args, **kwargs):
if args[0]._is_zerotensor():
return torch._efficientzerotensor(args[0].shape)
else:
return args[0].clone()
y = torch._efficientzerotensor(3)
self.assertTrue(op(y)._is_zerotensor())
self.assertEqual(op(x), x)
def test_create_new_library_fragment_no_existing(self):
with _scoped_library(self.test_ns, "FRAGMENT") as my_lib:
my_lib.define("sum2(Tensor self) -> Tensor")
@torch.library.impl(my_lib, "sum2", "CPU")
def my_sum(*args, **kwargs):
return args[0]
x = torch.tensor([1, 2])
self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)
def test_create_new_library_fragment_with_existing(self):
with _scoped_library(self.test_ns, "DEF") as my_lib1:
# Create a fragment
with _scoped_library(self.test_ns, "FRAGMENT") as my_lib2:
my_lib2.define("sum4(Tensor self) -> Tensor")
@torch.library.impl(my_lib2, "sum4", "CPU")
def my_sum4(*args, **kwargs):
return args[0]
x = torch.tensor([1, 2])
self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)
# Create another fragment
with _scoped_library(self.test_ns, "FRAGMENT") as my_lib3:
my_lib3.define("sum3(Tensor self) -> Tensor")
@torch.library.impl(my_lib3, "sum3", "CPU")
def my_sum3(*args, **kwargs):
return args[0]
x = torch.tensor([1, 2])
self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)
@unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
def test_alias_analysis(self):
def test_helper(alias_analysis=""):
my_lib1 = Library(self.test_ns, "DEF") # noqa: TOR901
called = [0]
@torch.library.define(
my_lib1, "_op() -> None", alias_analysis=alias_analysis
)
def _op(*args, **kwargs):
called[0] += 1
@torch.jit.script
def _test():
torch.ops._test_python_registration._op()
assert "_test_python_registration::_op" in str(_test.graph)
with self.assertRaises(AssertionError):
test_helper("") # alias_analysis="FROM_SCHEMA"
test_helper("CONSERVATIVE")
def test_error_for_unsupported_ns_or_kind(self) -> None:
with self.assertRaisesRegex(ValueError, "Unsupported kind"):
my_lib1 = Library("myns", "BLA") # noqa: TOR901
for kind in ("DEF", "FRAGMENT"):
with self.assertRaisesRegex(ValueError, "reserved namespace"):
my_lib1 = Library("prim", kind) # noqa: TOR901
def test_returning_symint(self) -> None:
shape_env = ShapeEnv()
fake_tensor_mode = FakeTensorMode(shape_env=shape_env)
ft = fake_tensor_mode.from_tensor(torch.rand(2, 3))
s0, s1 = ft.shape
with _scoped_library(self.test_ns, "DEF") as tlib:
tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")
@impl(tlib, "sqsum", "CompositeExplicitAutograd")
def sqsum(a: SymInt, b: SymInt):
return a * a + b * b
out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
out_val = shape_env.evaluate_expr(out.node.expr)
self.assertEqual(out_val, 13)
def test_register_functional_op_error_cases(self):
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
register_functional_op(lib, "abs", torch.ops.aten.abs_)
with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
register_functional_op(lib, "abs", torch.ops.aten.abs.out)
schemas = [
"foo(Tensor x, Tensor(a!)[] y) -> ()",
"foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)",
"foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))",
]
for schema in schemas:
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define(schema)
with self.assertRaisesRegex(RuntimeError, "NYI"):
register_functional_op(
lib,
"foo_functional",
getattr(torch.ops, self.test_ns).foo.default,
)
def _check_is_functional_variant(self, mutable_op, functional_op, args):
# functional op should not mutate
cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
functional_result = functional_op(*cloned_args)
self.assertEqual(cloned_args, args)
# check functional_result includes mutable_result
mutable_result = mutable_op(*cloned_args)
if mutable_result is None:
flat_mutable_result = []
else:
flat_mutable_result = pytree.tree_leaves(mutable_result)
flat_functional_result = pytree.tree_leaves(functional_result)
assert len(flat_functional_result) > len(flat_mutable_result)
self.assertEqual(
flat_functional_result[: len(flat_mutable_result)], flat_mutable_result
)
# check rest of functional_result is the mutated args
mutated_args = [
maybe_mutated_arg
for maybe_mutated_arg, arg in zip(cloned_args, args)
if not (
maybe_mutated_arg is not None
and arg is not None
and torch.allclose(maybe_mutated_arg, arg)
)
]
self.assertEqual(
flat_functional_result[len(flat_mutable_result) :], mutated_args
)
# check that functionalization kernel was indeed registered
def fn(*args):
cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
mutable_op(*cloned_args)
return cloned_args
gm = make_fx(torch.func.functionalize(fn))(*args)
has_functional_op = False
for node in gm.graph.nodes:
self.assertFalse(node.target is mutable_op)
if node.target is functional_op:
has_functional_op = True
self.assertTrue(has_functional_op)
def test_register_functional_op_no_returns(self):
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define("foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()")
def foo_impl(x, y, z, w):
y.fill_(3.14)
w.fill_(2.71)
lib.impl("foo", foo_impl, "CPU")
register_functional_op(
lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
)
x = torch.randn([])
y = torch.randn([])
z = torch.randn([])
w = torch.randn([])
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default,
(x, y, z, w),
)
def test_register_functional_op_with_optional(self):
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define(
"foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()"
)
def foo_impl(x, y, z, w):
y.fill_(3.14)
z.fill_(2.71)
if w is not None:
w.fill_(1.618)
lib.impl("foo", foo_impl, "CPU")
register_functional_op(
lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
)
x = torch.randn([])
y = torch.randn([])
z = torch.randn([])
w = torch.randn([])
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default,
(x, y, z, w),
)
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default,
(x, y, z, None),
)
def test_register_functional_op_one_return(self):
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define(
"foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor"
)
def foo_impl(x, y, z, w):
y.fill_(3.14)
w.fill_(2.71)
z.fill_(0.99)
return x.clone()
lib.impl("foo", foo_impl, "CPU")
register_functional_op(
lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
)
x = torch.randn([])
y = torch.randn([])
z = torch.randn([])
w = torch.randn([])
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default,
(x, y, z, w),
)
def test_register_functional_op_multiple_returns(self):
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define(
"foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)"
)
def foo_impl(x, y, z, w):
y.fill_(3.14)
w.fill_(2.71)
return x.clone(), z.clone()
lib.impl("foo", foo_impl, "CPU")
register_functional_op(
lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
)
x = torch.randn([])
y = torch.randn([])
z = torch.randn([])
w = torch.randn([])
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default,
(x, y, z, w),
)
def test_register_fallthrough(self):
with _scoped_library("aten", "IMPL") as my_lib:
my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")
a = torch.randn(2, 3, device="cpu", dtype=torch.float32)
b = torch.randn(3, 2, device="cpu", dtype=torch.float32)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
# dtype for mm should be float32 since we registered a fallthrough
self.assertEqual(torch.mm(a, b).dtype, torch.float32)
# ops that don't have a fallthrough registered should not be affected
self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
# default behavior should have been restored
self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16)
class TestPythonDispatch(TestCase):
def test_basic(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
log_input("x", x)
y = x * x
saved_x = y.grad_fn._saved_self
grad_y = LoggingTensor(torch.tensor([1.0]))
log_input("grad_y", grad_y)
(g,) = torch.autograd.grad((y,), (x,), (grad_y,))
self.assertEqual(g.elem, torch.tensor([6.0]))
with torch.no_grad():
self.assertEqual(saved_x, x)
self.assertEqual(saved_x._version, x._version)
x.add_(2)
self.assertEqual(saved_x, x)
# TODO: figure out why broken
# self.assertEqual(saved_x._version, x._version)
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0)
$2: f32[1] = input('grad_y')
$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)""",
)
def test_out(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.ones(1))
y = LoggingTensor(torch.zeros(1))
log_input("x", x)
log_input("y", y)
torch.abs(x, out=y)
self.assertEqual(y.elem, torch.ones(1))
# TODO: arguably this shouldn't pass and we should complain
# that out isn't a kwarg
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1] = input('y')
$2: f32[1] = torch._ops.aten.abs.out($0, out=$1)""",
)
def test_kwarg_only(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.ones(1))
y = LoggingTensor(torch.ones(1, 1))
z = LoggingTensor(torch.ones(1))
log_input("x", x)
log_input("y", y)
log_input("z", z)
torch.addmv(x, y, z)
torch.addmv(x, y, z, beta=1)
torch.addmv(x, y, z, beta=2)
torch.addmv(x, y, z, alpha=2)
torch.addmv(x, y, z, beta=2, alpha=2)
# The expectation is that beta/alpha don't show up when they're
# defaulted. This is even if the user explicitly specified it.
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1, 1] = input('y')
$2: f32[1] = input('z')
$3: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$4: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$5: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
$6: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
$7: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)""",
)
def test_kwarg_only_and_positional_default(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.ones(1))
log_input("x", x)
torch.ops.aten._foobar(x)
torch.ops.aten._foobar(x, False)
torch.ops.aten._foobar(x, arg3=False)
torch.ops.aten._foobar(x, False, arg3=False)
# What we are testing here is that we omit arg2
# if it is defaulted, even if a kwarg is set
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten._foobar.default($0)
$2: f32[1] = torch._ops.aten._foobar.default($0, False)
$3: f32[1] = torch._ops.aten._foobar.default($0, arg3=False)
$4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)""",
)
def test_produce_real_type(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.ones(2, 2))
log_input("x", x)
x.to(dtype=torch.double) # non-optional dtype
torch.cumprod(x, 0, dtype=torch.double) # optional dtype
x[:, 1].contiguous(
memory_format=torch.contiguous_format
) # optional memory format
# There doesn't appear to be any layout signatures which are
# triggerable using tensor subclasses (need to use a mode)
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[2, 2] = input('x')
$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)""",
)
def test_optional_tensor_list(self) -> None:
def weird(xs):
print("woof")
return torch.empty(())
with _scoped_library("my_lib", "DEF") as my_lib:
my_lib.define("weird(Tensor?[] self) -> Tensor")
my_lib.impl("weird", weird, "CPU")
with capture_logs() as logs:
x = LoggingTensor(torch.ones(2, 2))
log_input("x", x)
torch.ops.my_lib.weird.default([None, x])
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[2, 2] = input('x')
$1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
)
def test_list_ret(self) -> None:
# test all sequence types are permissible returns
for list_type in (list, tuple):
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if func.overloadpacket == torch.ops.aten.split:
with no_dispatch():
return list_type(torch.split(*args))
else:
raise AssertionError(f"unrecognized func: {func}")
self.assertEqual(
torch.split(A(torch.tensor([0, 1])), 2),
torch.split(torch.tensor([0, 1]), 2),
)
def test_invalid_ret(self) -> None:
# test invalid return gets reasonable error message
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return "arf"
# Wobbles depending on NDEBUG mode of pybind11
self.assertRaisesRegex(
RuntimeError,
"Unable to cast",
lambda: A(torch.zeros(1)).neg(),
)
self.assertRaisesRegex(
RuntimeError,
"Unable to cast",
lambda: A(torch.zeros(1)).detach(),
)
def test_detach_appears_twice_when_called_once(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
log_input("x", x)
x.detach()
# FIXME: We actually want this to emit a single detach. However,
# it currently emits two, for reasons unclear to us. Leaving
# this test here to make sure we don't regress even further (it
# would be bad if calling .detach() once emits 3+ detaches).
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.detach.default($0)
$2: f32[1] = torch._ops.aten.detach.default($1)""",
)
def test_storage(self) -> None:
# For now, just make sure it doesn't crash. Ideally, we should
# return some virtual storage that is safe to work with
x = LoggingTensor(torch.ones(1))
storage = x.untyped_storage()
self.assertRaises(RuntimeError, lambda: storage.data_ptr())
def test_make_wrapper_subclass_noalloc(self) -> None:
# This is ludicrously big (8TB) and this should pass because wrapper
# subclasses don't allocate
torch.Tensor._make_wrapper_subclass(LoggingTensor, (1000000000000,))
def test_version(self) -> None:
x = LoggingTensor(torch.ones(1))
prev_vc = x._version
x.detach().add_(2)
cur_vc = x._version
self.assertNotEqual(prev_vc, cur_vc)
x.data.add_(2)
self.assertEqual(cur_vc, x._version)
def test_subclass_priority(self) -> None:
class ErrorA(RuntimeError):
pass
class ErrorB(RuntimeError):
pass
# The big tests for code coverage are test_precedence_semantics in
# test_overrides.py; this is just to make sure it is wired up at all
# correctly for __torch_dispatch__
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
raise ErrorA
class B(A):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
raise ErrorB
self.assertRaises(
ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1)))
)
self.assertRaises(
ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1)))
)
self.assertRaises(
ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1)))
)
self.assertRaises(
ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1)))
)
def test_format(self) -> None:
x = LoggingTensor(torch.ones(1))
s1 = str(x)
s2 = repr(x)
s3 = f"{x}"
self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
self.assertEqual(s1, s2)
self.assertEqual(s1, s3)
def test_custom_autograd(self) -> None:
escape = [None]
class Square(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
y = x**2
ctx.save_for_backward(x)
return y
@staticmethod
def backward(ctx, grad_output):
assert isinstance(grad_output, LoggingTensor)
(x,) = ctx.saved_tensors
assert isinstance(x, LoggingTensor)
escape[0] = x
return grad_output * 2 * x
with capture_logs() as logs:
x = LoggingTensor(torch.ones(1), requires_grad=True)
log_input("x", x)
x.grad = LoggingTensor(torch.zeros(1))
log_input("x.grad", x.grad)
y = Square.apply(x)
grad_output = LoggingTensor(torch.ones(1))
log_input("grad_output", grad_output)
y.backward(grad_output)
with torch.no_grad():
self.assertEqual(escape[0], x)
self.assertEqual(escape[0]._version, x._version)
# TODO: figure out why x.requires_grad = False doesn't
# trigger an error for LoggingTensor
x.add_(2)
self.assertEqual(escape[0], x)
# TODO: figure out why this is broken
# self.assertEqual(escape[0]._version, x._version)
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1] = input('x.grad')
$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3: f32[1] = input('grad_output')
$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2)
$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0)
$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)""",
)
def test_subclass_creation(self):
# Make sure these statements runs without error
# In particular checking that when internal detach returns
# subclasses, these are cleanly overwritten.
class Foo(torch.Tensor):
pass
err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
with self.assertRaisesRegex(RuntimeError, err_msg):
a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
with self.assertRaisesRegex(RuntimeError, err_msg):
b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
with self.assertRaisesRegex(RuntimeError, err_msg):
Foo(LoggingTensor(torch.rand(2)))
with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"):
torch.Tensor._make_wrapper_subclass(Foo, (2, 2))
def test_new_ones(self) -> None:
class MyTensor(torch.Tensor):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return MyTensor(3)
self.assertEqual(type(MyTensor(2).new_ones(3)), MyTensor)
def test_like(self) -> None:
class MyTensor(torch.Tensor):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return MyTensor(3)
for f in ["empty", "ones", "rand", "randn", "zeros"]:
f_name = f + "_like"
self.assertEqual(type(getattr(torch, f_name)(MyTensor(2))), MyTensor)
self.assertEqual(type(torch.full_like(MyTensor(2), 1.0)), MyTensor)
self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)
def test_make_fx_with_subclass(self) -> None:
def f(x, y):
# Returns (TwoTensor, Tensor)
return x * y, y + y
x_a = torch.zeros(4)
x_b = torch.zeros(4)
y = torch.ones(4)
# make_fx() is not responsible for unwrapping tensor subclass inputs,
# so we do it manually here.
# Why? In general, make_fx(f)(*args) promises that the graph returned has the same calling
# convention as f(*args). Unwrapping tensor subclass inputs can potentially change
# the number of input args to the graph, breaking that assumption
def f_to_trace(x_a, x_b, y):
x = TwoTensor(x_a, x_b)
out1, out2 = f(x, y)
out1_unwrapped_attrs, _ = out1.__tensor_flatten__()
return (*[getattr(out1, attr) for attr in out1_unwrapped_attrs], out2)
fx_g = make_fx(f_to_trace, tracing_mode="fake")(x_a, x_b, y)
self.assertExpectedInline(
fx_g.code,
"""\
def forward(self, x_a_1, x_b_1, y_1):
mul = torch.ops.aten.mul.Tensor(x_a_1, y_1); x_a_1 = None
mul_1 = torch.ops.aten.mul.Tensor(x_b_1, y_1); x_b_1 = None
add = torch.ops.aten.add.Tensor(y_1, y_1); y_1 = None
return (mul, mul_1, add)
""",
)
# See https://github.com/pytorch/pytorch/issues/117794
def test_return_and_correct_aliasing_gives_correct_stride(self):
t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))
x = torch.randn(2, 2)
# slicing should result in the same stride for TwoTensor as a dense tensor would give
self.assertEqual(t[:, 0].stride(), x[:, 0].stride())
def test_make_wrapper_subclass_propagates_metadata(self) -> None:
class WrapperTensor(torch.Tensor):
elem: torch.Tensor
__slots__ = ["elem"]
@staticmethod
def __new__(cls, elem, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
elem.size(),
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
requires_grad=elem.requires_grad,
strides=elem.stride(),
storage_offset=elem.storage_offset(),
)
r.elem = elem
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
raise RuntimeError("NYI")
# non-contiguous strides, non-zero storage offset
x = torch.randn(4, 6).t().diagonal(offset=2)
y = WrapperTensor(x)
self.assertEqual(y.size(), x.size())
self.assertEqual(y.stride(), x.stride())
self.assertEqual(y.storage_offset(), x.storage_offset())
def test_wrapper_subclass_serializes(self) -> None:
with tempfile.TemporaryFile() as f:
# purposefully use int64 to test non-default dtype
x = LoggingTensor(torch.randperm(3))
torch.save(x, f)
f.seek(0)
with torch.serialization.safe_globals([LoggingTensor]):
x_loaded = torch.load(f)
self.assertTrue(type(x_loaded) is type(x))
self.assertEqual(x, x_loaded)
self.assertEqual(x.elem, x_loaded.elem)
self.assertFalse(x is x_loaded)
def test_deepcopy_wrapper_subclass(self) -> None:
# purposefully use int64 to test non-default dtype
x = LoggingTensor(torch.randperm(3))
x_copy = deepcopy(x)
self.assertTrue(type(x_copy) is type(x))
self.assertEqual(x, x_copy)
self.assertEqual(x.elem, x_copy.elem)
self.assertFalse(x is x_copy)
def test_deepcopy_wrapper_subclass_with_clone_returning_different_type(
self,
) -> None:
class MyWrapperTensor(torch.Tensor):
elem: torch.Tensor
__slots__ = ["elem"]
@staticmethod
def __new__(cls, elem, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
elem.size(),
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
requires_grad=elem.requires_grad,
strides=elem.stride(),
storage_offset=elem.storage_offset(),
)
r.elem = elem
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if func.overloadpacket.__name__ == "clone":
# Return a plain tensor from clone().
return args[0].elem.clone()
raise RuntimeError("NYI")
# NB: The default Tensor.__torch_function__ implementation called for deepcopy
# disables __torch_function__ by the time we get to clone(), so there is no need to
# explicitly disable __torch_function__ for this subclass.
x = MyWrapperTensor(torch.randn(3))
with self.assertRaisesRegex(
RuntimeError,
"for which cloning returns another instance of the same subclass",
):
x_copy = deepcopy(x)
def test_deepcopy_non_wrapper_subclass(self) -> None:
# Ensure correct error is thrown for common error cases.
class SubTensorError1(torch.Tensor):
# Default implementation of new_empty() returns a plain tensor.
pass
class SubTensorError2(torch.Tensor):
# new_empty() incorrectly returns a different type (i.e. a plain tensor).
def new_empty(self, shape):
return torch.Tensor(shape)
for error_cls in [SubTensorError1, SubTensorError2]:
x = error_cls(3)
with self.assertRaisesRegex(
RuntimeError,
"for which that function returns another instance of the same subclass",
):
x_copy = deepcopy(x)
# Ensure a correctly implemented new_empty() causes deepcopy() to work.
class SubTensorSuccess(torch.Tensor):
def new_empty(self, shape):
return type(self)(shape)
x = SubTensorSuccess(3)
x_copy = deepcopy(x)
self.assertIs(type(x_copy), type(x))
def test_wrapper_subclass_extra_dispatch_keys(self) -> None:
class ExtraKeysTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem, *args, **kwargs):
# NB: only the non-kwarg overload of _make_wrapper_subclass supports
# extra dispatch keys. We probably want to unify the two APIs
# in the future.
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
elem.size(),
elem.stride(),
elem.storage_offset(),
torch.contiguous_format,
elem.dtype,
elem.layout,
elem.device,
False,
False,
None,
False,
False,
DispatchKeySet(DispatchKey.NestedTensor),
)
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
pass
x = ExtraKeysTensor(torch.randn(3))
self.assertTrue(torch._C._dispatch_keys(x).has(DispatchKey.NestedTensor))
self.assertFalse(
torch._C._dispatch_keys(x).has(DispatchKey.AutogradNestedTensor)
)
def test_wrapper_subclass_multiprocessing_preserves_dtype(self):
# a and b have dtype of int64, which is purposefully different from the default
# assumed by _make_wrapper_subclass().
a = torch.randperm(5)
b = torch.randperm(5)
data = TwoTensor(a, b)
expected_dtype = data.dtype
loader = torch.utils.data.DataLoader(
[data, data],
batch_size=2,
num_workers=2,
collate_fn=_identity,
)
for batch in loader:
self.assertEqual(batch[0].dtype, expected_dtype)
def test_index_put_where_only_index_is_subclass(self) -> None:
called_funcs = []
class MyTensor(torch.Tensor):
elem: torch.Tensor
__slots__ = ["elem"]
@staticmethod
def __new__(cls, elem, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass(
cls,
elem.size(),
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
requires_grad=elem.requires_grad,
)
r.elem = elem
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
called_funcs.append(func)
return MyTensor(torch.tensor(3))
x = torch.randn(3, 3)
idxs = (MyTensor(torch.tensor(0)),)
v = torch.randn(1)
res = x.index_put_(idxs, v)
self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])
def test_torch_dispatch_mode_basic(self) -> None:
with capture_logs(is_mode=True) as logs:
with LoggingTensorMode():
torch.empty([])
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""",
)
def test_torch_dispatch_mode_unrelated_tensors(self) -> None:
x = torch.randn([])
y = torch.randn([])
with capture_logs(is_mode=True) as logs:
with LoggingTensorMode():
x + y
self.assertExpectedInline(
"\n".join(logs), """$2: f32[] = torch._ops.aten.add.Tensor($0, $1)"""
)
def test_nested_push_logging_tensor_mode(self):
x = torch.randn([])
y = torch.randn([])
with capture_logs(is_mode=True) as logs:
with LoggingTensorMode():
with LoggingTensorMode():
torch.empty([])
x + y
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
)
def test_capture_logs_with_torch_dispatch_mode(self):
x = torch.randn([])
y = torch.randn([])
with capture_logs_with_logging_tensor_mode() as logs:
torch.empty([])
x + y
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
)
x = torch.randn([])
y = torch.randn([])
with capture_logs_with_logging_tensor_mode() as logs1:
with capture_logs_with_logging_tensor_mode() as logs2:
torch.empty([])
x + y
self.assertExpectedInline(
"\n".join(logs2),
"""\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
)
self.assertEqual(logs1, logs2)
def test_torch_dispatch_mode_subclass_priority(self) -> None:
class ErrorA(RuntimeError):
pass
class ErrorB(RuntimeError):
pass
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
with AMode():
raise ErrorA
class B(A):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
with BMode():
func(*args, **kwargs)
class AMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise ErrorA
class BMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise ErrorB
a = A(torch.empty(1))
b = B(torch.empty(1))
with self.assertRaises(ErrorA):
a + a
with self.assertRaises(ErrorB):
a + b
# B has precedence over A due to the subclass relationship yet
# modes take precedence over arguments
with self.assertRaises(ErrorA):
with AMode():
b + b
with self.assertRaises(ErrorB):
with BMode():
a + a
with self.assertRaises(ErrorB):
with BMode():
a + b
def test_mode_with_make_subclass(self):
class SubTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
class BasicMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return func(*args, **kwargs)
x = torch.randn(3)
with BasicMode():
y = SubTensor(x)
self.assertIsInstance(y, SubTensor)
def test_torch_dispatch_mode_respects_no_dispatch(self) -> None:
with capture_logs(is_mode=True) as logs1:
with LoggingTensorMode():
torch.ones([2, 3])
with no_dispatch():
torch.ones([2, 3])
with capture_logs(is_mode=True) as logs2:
with LoggingTensorMode():
torch.ones([2, 3])
self.assertEqual(logs1, logs2)
def test_shallow_copy_and_detach(self) -> None:
seen = set()
test_case = self
class TestMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
tree_map_only(
torch.Tensor, lambda t: test_case.assertIn(t, seen), (args, kwargs)
)
if kwargs is None:
kwargs = {}
r = func(*args, **kwargs)
tree_map_only(torch.Tensor, lambda t: seen.add(t), r)
return r
with TestMode():
x = torch.randn(3, requires_grad=True)
loss = (x * x).sum()
loss.backward()
def test_exception_handling(self):
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
class AMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if func.__name__ == "randn.default":
raise RuntimeError
return A(torch.zeros(()))
with AMode():
try:
torch.randn(())
except RuntimeError:
pass
self.assertTrue(isinstance(torch.zeros(()), A))
def test_with_mode_created_separately(self):
class ErrorA(RuntimeError):
pass
class A(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise ErrorA
x = A()
with self.assertRaises(ErrorA):
with x:
torch.empty([])
def test_with_nested_modes(self):
class ErrorA(RuntimeError):
def __init__(self, msg):
super().__init__(msg)
class A(TorchDispatchMode):
def __init__(self, msg):
self.msg = msg
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise ErrorA(self.msg)
with self.assertRaisesRegex(ErrorA, "layer2"):
with A("layer1"):
with A("layer2"):
torch.empty([])
def test_make_subclass_with_modes(self):
class ModeTensor(torch.Tensor):
def __new__(cls, elem, mode):
r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
r.elem = elem
r.mode = mode
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
raise NotImplementedError("Shouldn't be here")
class Mode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
def unwrap(e):
if isinstance(e, ModeTensor):
return e.elem
else:
return e
def wrap(t):
if isinstance(t, torch.Tensor):
return ModeTensor(t, self)
else:
return t
return wrap(func(*tuple(unwrap(a) for a in args), **kwargs))
class BasicMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return func(*args, **kwargs)
x = torch.tensor(4.0)
with Mode():
y = x + x
z = y + y
self.assertIsInstance(y, ModeTensor)
self.assertIsInstance(z, ModeTensor)
with Mode():
with BasicMode(): # we can't nest two modes that call make_subclass because it only accepts vanilla tensors
y = x + x
z = y + y
self.assertIsInstance(y, ModeTensor)
self.assertIsInstance(z, ModeTensor)
assert self.assertRaisesRegex(
RuntimeError,
"subclass Mode but.* associated to a python object of type Mode",
)
def test_notimplemented_mode(self):
sub_count = 0
class PoliteMode(TorchDispatchMode):
def __init__(self) -> None:
self.pre_count = 0
self.post_count = 0
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
self.pre_count += 1
if any(t is not torch.Tensor for t in types):
return NotImplemented
self.post_count += 1
return func(*args, **kwargs)
class SubTensor(torch.Tensor):
def __new__(cls, elem):
r = torch.Tensor._make_wrapper_subclass(cls, elem.shape)
r.elem = elem
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
nonlocal sub_count
sub_count += 1
def unwrap(t):
if isinstance(t, SubTensor):
return t.elem
else:
return t
return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
a = SubTensor(torch.randn(2))
with PoliteMode() as mode:
a.abs()
self.assertEqual(mode.pre_count, 2)
self.assertEqual(mode.post_count, 1)
self.assertEqual(sub_count, 1)
# make sure this doesn't error
with PoliteMode():
with PoliteMode():
a.abs()
def test_nesting_same_mode(self):
# If the pushed mode is the same instance as the current mode, we allow pushing an already active mode.
with capture_logs(is_mode=True) as logs:
with LoggingTensorMode() as reenabled:
with reenabled:
torch.empty([])
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""",
)
def test_error_using_class_method_on_mode(self):
class A(TorchDispatchMode):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return func(args, kwargs)
x = torch.tensor(5.0)
with self.assertRaisesRegex(
RuntimeError, "classmethod is not supported, please make it a plain method"
):
with A():
x + x
def test_get_cur_mode(self):
class A(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
pass
self.assertEqual(_get_current_dispatch_mode(), None)
with A() as mode1:
self.assertEqual(_get_current_dispatch_mode(), mode1)
with mode1:
with A() as mode2:
self.assertEqual(_get_current_dispatch_mode(), mode2)
def test_get_mode_stack(self):
class A(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
pass
self.assertEqual(_get_current_dispatch_mode_stack(), [])
with A() as mode1:
self.assertEqual(_get_current_dispatch_mode_stack(), [mode1])
with mode1:
with A() as mode2:
self.assertEqual(_get_current_dispatch_mode_stack(), [mode1, mode2])
def test_all_same_mode(self):
x = LoggingTensorMode()
y = LoggingTensorMode()
self.assertTrue(all_same_mode([x, x, x]))
self.assertFalse(all_same_mode([x, None]))
self.assertFalse(all_same_mode([x, y]))
def test_mode_detection(self):
class InfraMode(TorchDispatchMode):
@classmethod
def is_infra_mode(cls):
return True
class NonInfraMode(TorchDispatchMode):
pass
with InfraMode():
self.assertTrue(is_in_torch_dispatch_mode())
self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False))
with NonInfraMode():
self.assertTrue(is_in_torch_dispatch_mode())
self.assertTrue(is_in_torch_dispatch_mode(include_infra_modes=False))
with InfraMode():
self.assertTrue(is_in_torch_dispatch_mode())
self.assertTrue(
is_in_torch_dispatch_mode(include_infra_modes=False)
)
self.assertTrue(is_in_torch_dispatch_mode())
self.assertTrue(is_in_torch_dispatch_mode(include_infra_modes=False))
self.assertTrue(is_in_torch_dispatch_mode())
self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False))
self.assertFalse(is_in_torch_dispatch_mode())
self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False))
def test_tolist_numpy_with_torch_dispatch_mode(self) -> None:
x = LoggingTensor(torch.tensor([2.0, 3.0]))
with self.assertRaisesRegex(
RuntimeError, "is not supported for tensor subclasses."
):
x.tolist()
with self.assertRaisesRegex(
RuntimeError, "is not supported for tensor subclasses."
):
x.numpy()
with self.assertRaises(AssertionError):
self.assertEqual(x, None)
def test_record_stream(self) -> None:
class TestMode(TorchDispatchMode):
def __init__(self, testcase):
self.testcase = testcase
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
self.testcase.assertEqual(func.name(), "aten::record_stream")
self.testcase.assertIsInstance(args[0], torch.Tensor)
self.testcase.assertIsInstance(args[1], torch.Stream)
self.testcase.assertEqual(args[1].stream_id, 1)
self.testcase.assertEqual(args[1].device_index, 2)
self.testcase.assertEqual(args[1].device_type, 3)
t = torch.tensor(5.0)
s = torch.Stream(stream_id=1, device_index=2, device_type=3)
with TestMode(self):
t.record_stream(s)
def test_return_stream(self) -> None:
with _scoped_library("test_return_stream", "DEF") as l_def:
l_def.define("return_stream(Tensor self) -> Stream")
with _scoped_library("test_return_stream", "IMPL", "CPU") as l_impl:
l_impl.impl(
"return_stream",
lambda _: torch.Stream(stream_id=0, device_index=1, device_type=2),
)
class TestMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return torch.Stream(stream_id=1, device_index=2, device_type=3)
t = torch.tensor(5.0)
s = torch.ops.test_return_stream.return_stream(t)
self.assertIsInstance(s, torch.Stream)
self.assertEqual(s.stream_id, 0)
self.assertEqual(s.device_index, 1)
self.assertEqual(s.device_type, 2)
with TestMode():
s = torch.ops.test_return_stream.return_stream(t)
self.assertIsInstance(s, torch.Stream)
self.assertEqual(s.stream_id, 1)
self.assertEqual(s.device_index, 2)
self.assertEqual(s.device_type, 3)
def test_subclass_autograd_device_check(self) -> None:
class NonWrapperSubclass(torch.Tensor):
elem: torch.Tensor
__slots__ = ["elem"]
@staticmethod
def __new__(cls, elem, *args, **kwargs):
# Wrong device here!
r = torch.Tensor._make_subclass(
cls, elem.to("meta"), elem.requires_grad
)
# ...the real tensor is held as an element on the tensor.
r.elem = elem
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(e):
return e.elem if isinstance(e, NonWrapperSubclass) else e
def wrap(e):
return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e
rs = tree_map(
wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
)
logging.getLogger("NonWrapperSubclass").info(
f"{func.__module__}.{func.__name__}", # noqa: G004
args,
kwargs,
rs,
)
return rs
x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True))
y = torch.randn(2, requires_grad=True)
z = x * y
self.assertIsInstance(z, NonWrapperSubclass)
z.sum().backward(torch.tensor(1))
self.assertEqual(x.grad, y)
self.assertEqual(y.grad, x)
def test_none_wrapping(self):
# A Tensor subclass that returns None when doing add
# See LoggingTensor above for more details on the subclass
class SubclassWithNone(torch.Tensor):
@staticmethod
def __new__(cls, elem, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass(
cls,
elem.size(),
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
requires_grad=elem.requires_grad,
)
r.elem = elem
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(e):
return e.elem if isinstance(e, SubclassWithNone) else e
def wrap(e):
return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e
rs = tree_map(
wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
)
if func.overloadpacket.__name__ == "add":
return None
else:
return rs
x = SubclassWithNone(torch.rand(2))
# Make sure both run without error
self.assertIsInstance(x * 2, SubclassWithNone)
self.assertIsNone(x + 2)
x.requires_grad_()
out = x.acos().sum()
# The backward of acos does add then rsqrt so here we make sure that the
# undefined Tensor generated by the user code is nicely handled.
# If acos formula changes in the future, this can be replaced by any other
# function that does add then something in the backward in a composite way
with self.assertRaisesRegex(RuntimeError, "but got None"):
out.backward()
def test_storage_can_be_converted_to_python_object(self):
s = torch.Storage()
z = LoggingTensor(torch.empty([]))
z.set_(s)
def test_autograd_in_attr(self):
# We want the wrapped Tensor to require gradients!
true_t = torch.rand(2, requires_grad=True)
t = LoggingTensorReentrant(true_t)
out = t + 2
self.assertFalse(out.requires_grad)
self.assertIsNone(out.grad_fn)
self.assertTrue(out.elem.requires_grad)
self.assertIsNotNone(out.elem.grad_fn)
with self.assertRaisesRegex(RuntimeError, "does not require grad"):
out.sum().backward()
out.elem.sum().backward()
self.assertIsNone(t.grad)
self.assertIsNotNone(t.elem.grad)
def test_dispatch_super_call(self):
called = []
class SubTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
called.append(func)
return super().__torch_dispatch__(func, types, args, kwargs)
x = torch.randn(2)
y = torch.randn(2)
self.assertEqual(SubTensor(x) + SubTensor(y), x + y)
self.assertEqual(called, [torch.ops.aten.add.Tensor])
def test_dispatch_super_call_list_arg(self):
called = []
class SubTensorWithListArg(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
called.append(func)
return super().__torch_dispatch__(func, types, list(args), kwargs)
x = torch.randn(2)
self.assertEqual(SubTensorWithListArg(x).neg(), x.neg())
self.assertEqual(called, [torch.ops.aten.neg.default])
def test_dispatch_super_dont_autograd(self):
called = []
class SubTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
called.append(func)
# This argument still requires grad because it was passed
# through directly...
self.assertTrue(args[0].requires_grad)
r = super().__torch_dispatch__(func, types, args, kwargs)
# But the output better not require grad, because that means
# you did autograd again in torch dispatch (oops)
self.assertFalse(r.requires_grad)
return r
x = SubTensor(torch.randn(2, requires_grad=True))
x.neg()
self.assertEqual(called, [torch.ops.aten.neg.default])
def test_set_data(self):
called = 0
class SubTensor(torch.Tensor):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
nonlocal called
called += 1
return super().__torch_dispatch__(func, types, args, kwargs)
x = SubTensor(torch.empty(2))
x.data
self.assertEqual(called, 1)
x.data = torch.empty(2)
self.assertEqual(called, 1)
x.data
self.assertEqual(called, 2)
self.assertIs(type(x), SubTensor)
x.set_(torch.empty(2))
self.assertEqual(called, 3)
x.data
self.assertEqual(called, 4)
self.assertIs(type(x), SubTensor)
def test_construct_int_tensor(self):
class SubTensor(torch.Tensor):
pass
# should not fail
SubTensor(torch.zeros(2, dtype=torch.int))
def test_multiple_ops_subclass(self):
# This is a Direct Subclass, don't do that!
class MySubclass(torch.Tensor):
@staticmethod
def __new__(cls, elem):
r = torch.Tensor._make_subclass(cls, elem)
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
with no_dispatch():
return func(*args, **kwargs)
x = MySubclass(torch.rand(2, 2, dtype=torch.complex64))
y = x.conj()
# Details of the bug that this tests for:
# Here, y dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU}
# There are a few calls to the dispatcher that are going to happen here:
# - call_exp: User calling exp on y
# - PythonTLSSnapshot: records the TLS on entry and redispatch
# - AutogradCPU: no input requires grad, so does nothing and redispatch
# - Conjugate: no special implementation for exp: use the fallback that
# first clone the Tensor (to materialize the conj) then redispatch
# - call_clone: conjugate fallback calling clone on y
# - PythonTLSSnapshot: records the TLS on entry and redispatch
# - (AutogradCPU: skipped as autograd added itself to the exclude set above)
# - Conjugate: special implementation for clone: just skip this key
# - Python: Reset the TLS based on the snapshot above and call the user implementation (this
# actually calls into the dispatcher again but since we disable both our keys
# before, not detailed here)
# - exit Python: restore the TLS and exit
# - exit Conjugate: nothing was inplace so just exit
# - exit PythonTLSSnapshot: done with this call, reset the saved TLS to empty
# - Python: Reset the TLS again based on the snapshot. <- this used to fail
# - More steps....
y.exp()
@staticmethod
def subclass_helper(cls, data, use_wrapper_subclass, **kwargs):
if use_wrapper_subclass:
kwargs["device"] = data.device
kwargs["dtype"] = data.dtype
kwargs["layout"] = data.layout
kwargs["requires_grad"] = True
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined]
else:
return torch.Tensor._make_subclass(cls, data, True, **kwargs)
def test_is_contiguous_slow_path(self):
data = torch.randn(3, 3)
contiguous_data = data.clone()
not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2))
for use_wrapper_subclass in [True, False]:
class ExampleTensor1(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class ExampleTensor2(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.is_contiguous:
return contiguous_data.is_contiguous()
return NotImplemented
class ExampleTensor3(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.is_contiguous:
return not_contiguous_data.is_contiguous()
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'"
e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.is_contiguous()
with self.assertRaisesRegex(TypeError, err_msg):
e.contiguous()
e = ExampleTensor2(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.is_contiguous(), True)
e.contiguous() # this will just return the original TensorImpl since is_contiguous = True
err_msg = "Multiple dispatch failed for"
e = ExampleTensor3(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.is_contiguous(), False)
with self.assertRaisesRegex(TypeError, err_msg):
e.contiguous()
def test_fancy_strides(self):
calls = []
class ExampleTensor(torch.Tensor):
@staticmethod
def __new__(cls, data):
return TestPythonDispatch.subclass_helper(
cls, data, False, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func in [
torch.ops.aten.is_contiguous.default,
torch.ops.aten.is_contiguous.memory_format,
torch.ops.aten.is_strides_like_format.default,
torch.ops.aten.is_non_overlapping_and_dense.default,
torch.ops.aten.stride.default,
]:
calls.append((func, list(args)[1:]))
return None
with no_dispatch():
return func(*args, **kwargs)
e = ExampleTensor(torch.randn(2, 2))
self.assertFalse(e.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(
calls, [(torch.ops.aten.is_contiguous.memory_format, [torch.channels_last])]
)
calls.clear()
self.assertFalse(
torch.ops.aten.is_strides_like_format.default(e, torch.channels_last)
)
self.assertEqual(
calls,
[(torch.ops.aten.is_strides_like_format.default, [torch.channels_last])],
)
calls.clear()
self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(e))
self.assertEqual(
calls, [(torch.ops.aten.is_non_overlapping_and_dense.default, [])]
)
def test_device_slowpath(self):
for use_wrapper_subclass in [True]:
class ExampleTensor1(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_device=True
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class ExampleTensor2(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_device=True
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.prim.device:
return torch.device("meta")
return NotImplemented
class ExampleTensor3(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_device=True
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.prim.device:
return torch.device("meta")
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.prim.device'"
with self.assertRaisesRegex(TypeError, err_msg):
e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
e.device()
ten = torch.rand([1])
e = ExampleTensor2(torch.randn(3, 3, device="cpu"), use_wrapper_subclass)
self.assertEqual(e.device.type, "meta")
self.assertEqual(ten.type_as(e).device.type, "meta")
e = ExampleTensor3(torch.randn(3, 3, device="cpu"), use_wrapper_subclass)
self.assertEqual(e.device.type, "meta")
self.assertEqual(ten.type_as(e).device.type, "meta")
def test_dim_slowpath(self):
data = torch.randn(3, 3)
for use_wrapper_subclass in [True, False]:
class DimNotImplementedTensor(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class DimImplementedTensor(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.dim:
return data.dim()
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.aten.dim'"
e = DimNotImplementedTensor(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.dim()
t = DimImplementedTensor(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(t.dim(), 2)
def test_maybe_tuple_bug(self):
class T(torch.Tensor):
@classmethod
def __torch_function__(cls, *args, **kwargs):
pass
a = torch.rand(3)
a[[T(), T()]]
def test_standard_is_not_subclass(self):
# https://github.com/pytorch/pytorch/issues/79079
self.assertFalse(torch._C._dispatch_isTensorSubclassLike(torch.empty(0)))
def test_sym_sizes_strides_slow_path(self):
class TestTensor(torch.Tensor):
@staticmethod
def __new__(cls, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls, (0,), dispatch_sizes_strides_policy="sizes"
)
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if func in (
torch.ops.aten.sym_size.default,
torch.ops.aten.sym_stride.default,
):
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import (
DimDynamic,
ShapeEnv,
)
shape_env = ShapeEnv()
si = shape_env.create_symintnode(
shape_env.create_symbol(
123,
source=ConstantSource("abc"),
dynamic_dim=DimDynamic.DUCK,
constraint_dim=None,
),
hint=123,
)
return (si,)
t = TestTensor()
si = t.size()[0]
self.assertIsInstance(si, torch.SymInt)
si = t.stride()[0]
self.assertIsInstance(si, torch.SymInt)
def test_strides_slow_path(self):
for use_wrapper_subclass in [True, False]:
class StridesNotImplemented(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class StridesCustomReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func == torch.ops.aten.sym_stride.default:
return (4, 2)
return NotImplemented
class StridesDefaultReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func == torch.ops.aten.sym_stride.default:
return None
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_stride'"
e = StridesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.stride()
e = StridesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.stride(), (4, 2))
e = StridesDefaultReturn(torch.randn(6, 2), use_wrapper_subclass)
self.assertEqual(e.stride(), (2, 1))
def test_sizes_slow_path(self):
for use_wrapper_subclass in [True, False]:
data = torch.randn(6, 2)
class SizesNotImplemented(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.dim:
return data.dim()
return NotImplemented
class SizesCustomReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.dim:
return data.dim()
if func.overloadpacket == torch.ops.aten.sym_size:
return (5, 3)
return NotImplemented
class SizesDefaultReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.dim:
return data.dim()
if func.overloadpacket == torch.ops.aten.sym_size:
return None
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_size'"
e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.size()
e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.size(), (5, 3))
e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
self.assertEqual(e.size(), (4, 2))
def test_custom_size_policy_dynamic_shapes(self):
data = torch.randn(6, 2)
class CustomSizeDynamicShapesTensor(torch.Tensor):
@staticmethod
def __new__(cls, inner):
return torch.Tensor._make_wrapper_subclass(
# TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
# Calling the overload that has kwargs causes us to go down the first overload path,
# which will **always** specialize sizes.
# We should probably eventually fix this so that the first overload can just handle dynamic shapes.
cls,
inner.size(),
inner.stride(),
None,
None,
inner.dtype,
inner.layout,
inner.device,
False,
inner.requires_grad,
"sizes",
)
def __init__(self, inner):
self.inner = inner
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func == torch.ops.aten.sym_size.default:
return args[0].inner.shape
if func == torch.ops.aten.sym_stride.default:
return args[0].inner.shape
return NotImplemented
x = torch.ones(2, 2)
def trace_fn(x):
x_wrapper = CustomSizeDynamicShapesTensor(x)
return x_wrapper.size(), x_wrapper.stride()
fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x)
self.assertExpectedInline(
fx_g.code.strip(),
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
return ((sym_size_int, sym_size_int_1), (sym_size_int, sym_size_int_1))""",
)
def test_data_ptr_respects_numel_slow_path(self):
data = torch.randn(6, 2)
class NumelDefaultReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.dim:
return data.dim()
if func.overloadpacket == torch.ops.aten.numel:
numel_called[0] = True
return None
return NotImplemented
for use_wrapper_subclass in (False, True):
numel_called = [False]
e = NumelDefaultReturn(torch.randn(2, 2), use_wrapper_subclass)
e.data_ptr()
self.assertTrue(numel_called[0])
def test_layout_slow_path(self):
for use_wrapper_subclass in [True, False]:
data = torch.randn(6, 2)
class LayoutNotImplemented(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_layout=True
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class LayoutCustomReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_layout=True
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.prim.layout:
return torch.sparse_csr
return NotImplemented
class LayoutDefaultReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_layout=True
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.prim.layout:
return data.layout
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.prim.layout'"
e = LayoutNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.layout
e = LayoutCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.layout, torch.sparse_csr)
e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
self.assertEqual(e.layout, torch.strided)
class TestPythonDispatcher(TestCase):
def test_basic(self):
x = torch.randn(2, requires_grad=True)
r = torch._C._EnablePythonDispatcher()
torch.add(x, x)
def test_lstsq(self):
a = torch.randn(4, 3)
b = torch.rand(4, 3)
expected_shape = torch.linalg.lstsq(a, b).solution.shape
r = torch._C._EnablePythonDispatcher()
python_disp_shape = torch.linalg.lstsq(a, b).solution.shape
self.assertEqual(expected_shape, python_disp_shape)
class TestWrapperSubclassAliasing(TestCase):
def _test_wrapper_subclass_aliasing(self, op, args, kwargs):
def to_subclass(t: torch.Tensor):
return TwoTensor(t, t.clone())
result_ref = op(*args, **kwargs)
args_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, args)
kwargs_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, kwargs)
result_test = op(*args_subclass, **kwargs_subclass)
args_ref_flat = pytree.arg_tree_leaves(*args, **kwargs)
args_ref_flat_tensors = [
x for x in args_ref_flat if isinstance(x, torch.Tensor)
]
args_test_flat = pytree.tree_leaves((args_subclass, kwargs_subclass))
args_test_flat_tensors = [
x for x in args_test_flat if isinstance(x, torch.Tensor)
]
result_ref_flat = pytree.tree_leaves(result_ref)
result_ref_flat_tensors = [
x for x in result_ref_flat if isinstance(x, torch.Tensor)
]
result_test_flat = pytree.tree_leaves(result_test)
result_test_flat_tensors = [
x for x in result_test_flat if isinstance(x, torch.Tensor)
]
for o_ref, o_test in zip(result_ref_flat_tensors, result_test_flat_tensors):
for a_ref, a_test in zip(args_ref_flat_tensors, args_test_flat_tensors):
out_is_inpt = o_ref is a_ref
if out_is_inpt:
self.assertTrue(o_test is a_test)
out_aliases_inpt = StorageWeakRef(
o_ref.untyped_storage()
) == StorageWeakRef(a_ref.untyped_storage())
if out_aliases_inpt:
self.assertTrue(
StorageWeakRef(o_test.untyped_storage())
== StorageWeakRef(a_test.untyped_storage())
)
else:
self.assertFalse(
StorageWeakRef(o_test.untyped_storage())
== StorageWeakRef(a_test.untyped_storage())
)
# This tests the correctness of `torch.utils._python_dispatch.return_and_correct_aliasing`,
# a util for wrapper subclasses to promise correct aliasing behavior.
# It's probably overkill to test every OpInfo,
# so I picked a sampling of ops with representative schemas.
@ops(
[
op
for op in op_db
if op.name
in [
"mul", # out-of-place
"cat", # out-of-place (TensorList input)
"index", # out-of-place (Optional TensorList input)
"mul_", # inplace
"view", # view
"t_", # inplace-view
"split", # view (multi-return)
"native_batch_norm", # mutable op (returns outputs and mutates some inputs)
]
],
allowed_dtypes=(torch.float,),
)
def test_wrapper_subclass_aliasing(self, device, dtype, op):
samples = op.sample_inputs(device, dtype)
sample = first_sample(self, samples)
args = (sample.input, *sample.args)
kwargs = sample.kwargs
self._test_wrapper_subclass_aliasing(op, args, kwargs)
@ops(custom_op_db, allowed_dtypes=(torch.float,))
def test_wrapper_subclass_aliasing_custom(self, device, dtype, op):
samples = op.sample_inputs(device, dtype)
sample = first_sample(self, samples)
args = (sample.input, *sample.args)
kwargs = sample.kwargs
self._test_wrapper_subclass_aliasing(op, args, kwargs)
def test_wrapper_subclass_aliasing_conv2d(self, device):
args = (torch.randn(4, 4, 4, 4), torch.randn(4, 4, 4, 4))
kwargs = {}
# conv2d has a default arg 'int[2] strides=0',
# which torchscript expands into 'int[2] strides=[0, 0]'
# Make sure that _return_and_correct_aliasing can handle this case
# (I'm using inference_mode to make sure conv2d doesn't decompose and goes to torch_dispatch)
with torch.inference_mode():
self._test_wrapper_subclass_aliasing(
torch.ops.aten.conv2d.default, args, kwargs
)
def test_wrapper_subclass_aliasing_out_op(self, device):
# Make sure that _return_and_correct_aliasing can handle kwargs w mutable tensors
args = (torch.ones(4), torch.ones(4))
kwargs = {"out": torch.empty(4)}
self._test_wrapper_subclass_aliasing(torch.ops.aten.add.out, args, kwargs)
instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())
if __name__ == "__main__":
run_tests()