blob: 2bc7101c555740b4effb67bd06c2ed3968022adb [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import functools
import itertools
import unittest
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.testing import normalize_gm
from torch._higher_order_ops.wrap import wrap
from torch.fx.experimental.symbolic_shapes import (
DimDynamic,
ShapeEnv,
StatelessSymbolicContext,
)
from torch.nested._internal.nested_tensor import (
jagged_from_list,
jagged_from_tensor_and_lengths,
nested_view_from_values_offsets,
NestedTensor,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
subtest,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.two_tensor import TwoTensor
def traceable_subclass(c):
    """Return a dynamo config patch that registers ``c`` as a traceable subclass.

    The returned patch overrides ``traceable_tensor_subclasses`` with the
    single-element set ``{c}`` while it is active (used as a ``with`` target).
    """
    subclasses = {c}
    return torch._dynamo.config.patch("traceable_tensor_subclasses", subclasses)
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
compile_full_eager = torch.compile(backend="eager", fullgraph=True)
class BaseTorchFunction(torch.Tensor):
    """Tensor subclass whose ``__torch_function__`` defers to ``torch.Tensor``'s."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        return super().__torch_function__(func, types, args, kwargs)
class MockSubclass(torch.Tensor):
    """Tensor subclass whose ``__torch_function__`` calls ``func`` directly."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        call_kwargs = kwargs if kwargs is not None else {}
        return func(*args, **call_kwargs)
class AttrSubclass(torch.Tensor):
    """Tensor subclass carrying plain class attributes ``x`` and ``size`` (both 10).

    Note that the ``size`` class attribute shadows ``torch.Tensor.size`` here.
    """

    x: int = 10
    size: int = 10

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return func(*args, **({} if kwargs is None else kwargs))
class DummyNDim(torch.Tensor):
    """Tensor subclass that reports ``ndim == 10`` via ``__torch_function__``."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Only the ``ndim`` property getter is intercepted; all else is default.
        is_ndim_getter = func == torch.Tensor.ndim.__get__
        if not is_ndim_getter:
            return super().__torch_function__(func, types, args, kwargs)
        return 10
class WrapperSubclass:
    """Non-Tensor holder of a tensor that unwraps itself in ``__torch_function__``."""

    def __init__(self, tensor):
        self.tensor = tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        def unwrap(tree):
            # Replace every WrapperSubclass leaf with the tensor it holds.
            return pytree.tree_map_only(WrapperSubclass, lambda w: w.tensor, tree)

        return func(*unwrap(args), **unwrap(kwargs))
class SigmoidToExpSubclass(torch.Tensor):
    """Tensor subclass that replaces ``Tensor.sigmoid`` with ``Tensor.exp``."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        target = torch.Tensor.exp if func == torch.Tensor.sigmoid else func
        return super().__torch_function__(target, types, args, kwargs)
# Wrapper subclass with two inner tensors: data and scale
# data has same shape as outer, and scale has single dim size
class ScaledTensor(torch.Tensor):
def __new__(
cls,
data: torch.Tensor,
scale: torch.Tensor,
*,
constant: int = 0,
):
return torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
strides=data.stride(),
storage_offset=data.storage_offset(),
dtype=data.dtype,
layout=data.layout,
requires_grad=data.requires_grad,
device=data.device,
)
def __init__(self, data: torch.Tensor, scale: torch.Tensor, constant: int = 0):
self._data = data
self._scale = scale
self._constant = constant
def __tensor_flatten__(self):
ctx = {"_constant": self._constant}
return ["_data", "_scale"], ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
assert len(inner_tensors) == 2
return ScaledTensor(
inner_tensors["_data"],
inner_tensors["_scale"],
constant=metadata["_constant"],
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
scaled_tensor = args[0]
out = func(scaled_tensor._data, *args[1:], **kwargs)
return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant)
def __repr__(self):
return f"{self._data.__repr__()}\n{self._scale.__repr__()}"
def func(a):
    """Return the elementwise sine of ``a`` via the ``Tensor.sin`` method."""
    result = a.sin()
    return result
class EagerRecordGraphAndInputs:
    """Dynamo backend that records each graph and its example inputs.

    Each call appends to ``graphs`` / ``example_inputs`` and returns the graph
    module unchanged, so it runs eagerly.
    """

    def __init__(self):
        self.graphs = []
        self.example_inputs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        self.graphs += [gm]
        self.example_inputs += [example_inputs]
        return gm
# Subclasses registered as traceable for every test in SubclassTests
# (installed through torch._dynamo.config.patch in SubclassTests.setUpClass).
GLOBAL_TEST_SUBCLASSES = {
    MockSubclass,
    DummyNDim,
    SigmoidToExpSubclass,
    BaseTorchFunction,
}
# Returns True if the function recompiles between inputs1 and inputs2 with the
# specified dynamic setting.
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
compile_count = [0]
def counter(gm, example_inputs):
compile_count[0] += 1
return gm
compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
compiled_f(*inputs1)
compiled_f(*inputs2)
return compile_count[0] > 1
class SubclassTests(torch._dynamo.test_case.TestCase):
    """Dynamo tests for tensor subclasses.

    Covers ``__torch_function__`` subclasses, ``__torch_dispatch__`` wrapper
    subclasses, torch-function enable/disable state, and interaction with
    functionalization. ``GLOBAL_TEST_SUBCLASSES`` is registered as traceable
    for every test in the class.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Register the shared subclasses for the whole class; the patch is
        # undone when the exit stack is closed in tearDownClass.
        cls._exit_stack.enter_context(
            torch._dynamo.config.patch(
                "traceable_tensor_subclasses", GLOBAL_TEST_SUBCLASSES
            )
        )

    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()
        # Chain to the base class so any class-level cleanup it performs also
        # runs (closing an already-closed ExitStack again is a no-op).
        super().tearDownClass()

    def test_no_call_to_new(self):
        """Tracing must not invoke a subclass's ``__new__`` (which raises here)."""

        class BadNewTorchFunction(torch.Tensor):
            def __new__(cls, *args, **kwargs):
                raise RuntimeError("Oops!")

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

        with torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {BadNewTorchFunction}
        ):

            @torch.compile(backend="eager", fullgraph=True)
            def fn(x):
                return torch.add(x, 1)

            input = torch.ones(2, 2).as_subclass(BadNewTorchFunction)
            res = fn(input)
            self.assertIsInstance(res, BadNewTorchFunction)

    def test_base_torch_function_tracing(self):
        """Compiled and eager results match for a pass-through subclass."""

        def fn(x):
            return torch.add(x, 1)

        input = torch.ones(2, 2).as_subclass(BaseTorchFunction)
        out = fn(input)
        out_opt = compile_full_eager(fn)(input)
        self.assertIsInstance(out, BaseTorchFunction)
        self.assertEqual(out, out_opt)

    def test_torch_function_state_graph_break(self):
        """The torch-function-disabled state is preserved across a graph break."""

        @torch.compile(backend="eager")
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch._dynamo.graph_break()
                return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)

        input = torch.ones(2, 2)
        res, _ = fn(input)
        self.assertFalse(res)

    def test_torch_function_state_nested(self):
        """Leaving a nested disable block restores the outer (disabled) state."""

        @torch.compile(backend="eager")
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                with torch._C.DisableTorchFunctionSubclass():
                    x = x + 1
                # Should reset to the outer state (disabled) after exiting ctx manager
                return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)

        input = torch.ones(2, 2)
        res, _ = fn(input)
        self.assertFalse(res)

    def test_torch_function_state_tracing(self):
        """A DisableTorchFunctionSubclass block traces under fullgraph."""

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch.add(x, 1.0)

        input = torch.ones(2, 2)

        res = fn(input)

    def test_torch_function_state_guards(self):
        """Calling with torch function disabled, then enabled, compiles two frames."""
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            torch.add(x, 1.0)

        input = torch.ones(2, 2)

        with torch._C.DisableTorchFunctionSubclass():
            res = fn(input)

        res = fn(input)

        self.assertEqual(cnt.frame_count, 2)

    def test_return_subclass(self):
        """A subclass constructed inside the compiled region is returned as such."""

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return MockSubclass(torch.add(x, 1.0))

        input = torch.ones(2, 2)

        res = fn(input)
        self.assertIsInstance(res, MockSubclass)

    def test_return_as_subclass(self):
        """``as_subclass`` inside the compiled region returns the subclass."""

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return torch.add(x, 1.0).as_subclass(MockSubclass)

        input = torch.ones(2, 2)

        res = fn(input)
        self.assertIsInstance(res, MockSubclass)

    def test_return_local_subclass(self):
        """Like test_return_subclass, for a subclass defined inside the test."""

        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return func(*args, **kwargs)

        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):

            @torch.compile(backend="eager", fullgraph=True)
            def fn(x):
                return LocalSubclass(torch.add(x, 1.0))

            input = torch.ones(2, 2)

            res = fn(input)
            self.assertIsInstance(res, LocalSubclass)

    @parametrize(
        "comparison",
        [
            subtest(isinstance, "isinstance"),
            subtest(lambda instance, type_: type(instance) == type_, "equality"),
            subtest(lambda instance, type_: type(instance) is type_, "identity"),
        ],
    )
    @parametrize(
        "input_type",
        [
            subtest(torch.Tensor, "tensor"),
            subtest(DummyNDim, "subclass"),
        ],
    )
    def test_type_check(self, comparison, input_type):
        """isinstance, type()==, and type() is checks on subclasses match eager."""
        with torch._dynamo.config.patch("traceable_tensor_subclasses", {DummyNDim}):

            def fn(x):
                if comparison(x, DummyNDim):
                    return torch.ones(1, 1)
                else:
                    return torch.zeros(2, 2)

            input = torch.ones(2, 2).as_subclass(input_type)
            exp_res = fn(input)
            act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input)
            self.assertEqual(exp_res, act_res)

    def test_torch_function_call_on_method(self):
        """A __torch_function__ rewrite of Tensor.sigmoid is honored when compiled."""
        x = torch.ones(2, 2)
        y = torch.ones(2, 2)
        z = torch.ones(2, 2)
        wrapped = x.as_subclass(SigmoidToExpSubclass)
        wrapped2 = y.as_subclass(SigmoidToExpSubclass)

        def fn(w):
            return w.sigmoid()

        fn_opt = compile_full_eager(fn)

        res_exp = fn(wrapped)
        res_act = fn_opt(wrapped2)
        res_exp2 = z.exp()

        self.assertEqual(res_exp, res_act)
        self.assertEqual(res_exp, res_exp2)

    def test_user_overidden_method_unsupported(self):
        """Calling a method the subclass overrides raises Unsupported."""

        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

            def sigmoid(self):
                return None

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            x.sigmoid()

        msg = (
            "Accessing overridden method/attribute sigmoid on a tensor"
            " subclass with a __torch_function__ override is not supported"
        )
        with torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {LocalSubclass}
        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)

    def test_user_overidden_attr_unsupported(self):
        """Reading a class attribute that shadows a Tensor attribute raises Unsupported."""

        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

            ndim = 10

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return x.ndim

        msg = (
            "Accessing overridden method/attribute ndim on a tensor"
            " subclass with a __torch_function__ override is not supported"
        )
        with torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {LocalSubclass}
        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)

    def test_user_overidden_property_unsupported(self):
        """Reading a property that shadows a Tensor attribute raises Unsupported."""

        class LocalSubclass(torch.Tensor):
            def __init__(self):
                self._ndim = 10

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

            @property
            def ndim(self):
                return self._ndim

            @ndim.setter
            def ndim(self, value):
                self._ndim = value

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return x.ndim

        msg = (
            "Accessing overridden method/attribute ndim on a tensor"
            " subclass with a __torch_function__ override is not supported"
        )
        with torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {LocalSubclass}
        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)

    def test_overridden_method_guarding(self):
        """Repeat calls don't recompile; patching the method afterwards is observed."""

        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

        @torch.compile(backend="eager")
        def fn(x):
            return x.sigmoid()

        with torch._dynamo.config.patch(
            error_on_recompile=True, traceable_tensor_subclasses={LocalSubclass}
        ):
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)
            fn(x)
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)

        with torch._dynamo.config.patch(
            traceable_tensor_subclasses={LocalSubclass}
        ), self.assertRaisesRegex(
            TypeError,
            "'bool' object is not callable",
        ):
            LocalSubclass.sigmoid = False
            fn(x)

    def test_torch_function_call_on_attr(self):
        """__torch_function__ interception of the ndim getter is honored when compiled."""
        x = torch.ones(2, 2)
        wrapped = x.as_subclass(DummyNDim)

        def fn(w):
            return w.ndim + torch.ones(2)

        fn_opt = compile_full_eager(fn)

        res_exp = fn(wrapped)
        res_act = fn_opt(wrapped)

        self.assertEqual(res_exp, res_act)
        self.assertEqual(res_exp, torch.ones(2) + 10)

    def test_torch_function_wrapper_class(self):
        """A non-Tensor object implementing __torch_function__ works as an input."""
        x = torch.ones(2, 2)
        wrapped = WrapperSubclass(x)

        def fn(w):
            return torch.add(w, 1.0)

        fn_opt = compile_full_eager(fn)

        res_exp = fn(wrapped)
        res_act = fn_opt(wrapped)
        self.assertEqual(res_exp, res_act)

    def test_torch_function_wrapper_class_with_kwargs(self):
        """Same as test_torch_function_wrapper_class, with keyword arguments."""
        x = torch.ones(2, 2)
        wrapped = WrapperSubclass(x)

        def fn(w):
            return torch.add(w, 1.0, alpha=2.0)

        fn_opt = compile_full_eager(fn)

        res_exp = fn(wrapped)
        res_act = fn_opt(wrapped)
        self.assertEqual(res_exp, res_act)

    def test_tensor_subclass_custom_attr(self):
        """A plain class attribute on a subclass is readable in a compiled function."""

        class AttrSubclass(torch.Tensor):
            x: int = 10

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}

                return super().__torch_function__(func, types, args, kwargs)

        # fn is deliberately left uncompiled so res_exp is a true eager reference.
        def fn(x):
            return x.x + torch.ones(2, 2)

        with traceable_subclass(AttrSubclass):
            input = torch.ones(2, 2).as_subclass(AttrSubclass)
            fn_opt = compile_full_eager(fn)

            res_exp = fn(input)
            res_act = fn_opt(input)

        self.assertEqual(res_exp, res_act)

    def test_compile_with_fake_tensor_dynamic_dim(self):
        """Fake inputs with DYNAMIC, DUCK, or STATIC dims each compile exactly once."""
        x = torch.randn([3, 4])

        def f(x):
            return torch.sin(x)

        def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count):
            torch._dynamo.reset()
            cnt = torch._dynamo.testing.CompileCounter()

            opt_f = torch.compile(f, backend=cnt, fullgraph=True)

            x1 = torch.rand_like(x)
            f(x)
            f(torch.randn([4, 3]))
            shape_env = ShapeEnv()
            with torch._subclasses.fake_tensor.FakeTensorMode(
                shape_env=shape_env
            ) as fake_mode:
                x_fake = fake_mode.from_tensor(
                    x,
                    symbolic_context=StatelessSymbolicContext(
                        dynamic_sizes=[dim_dynamic for i in range(x.dim())]
                    ),
                )
                x1_fake = fake_mode.from_tensor(
                    x1,
                    symbolic_context=StatelessSymbolicContext(
                        dynamic_sizes=[dim_dynamic for i in range(x.dim())]
                    ),
                )
                opt_f(x_fake)
                opt_f(x1_fake)

            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)

        test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1)
        test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1)
        test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1)

    def test_compile_with_fake_tensor_automatic_dynamic(self):
        """Symbolic fake inputs compile once; static ones recompile as dims change."""

        def f(x):
            return torch.sin(x)

        def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count):
            torch._dynamo.reset()
            cnt = torch._dynamo.testing.CompileCounter()
            opt_f = torch.compile(f, backend=cnt, fullgraph=True)

            shape_env = ShapeEnv()
            with torch._subclasses.fake_tensor.FakeTensorMode(
                shape_env=shape_env
            ) as fake_mode:
                for inp in inps:
                    fake_inp = fake_mode.from_tensor(
                        inp,
                        # Size the context from the tensor being converted, not
                        # from a variable closed over from the enclosing test.
                        symbolic_context=StatelessSymbolicContext(
                            dynamic_sizes=[dim_dynamic for _ in range(inp.dim())]
                        ),
                    )
                    opt_f(fake_inp)
            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)

        x = torch.randn([3, 4])
        y = torch.randn([4, 5])
        z = torch.randn([5, 6])
        a = torch.randn([3, 5])
        b = torch.randn([4, 4])
        # When inputs' DimDynamic is DYNAMIC or DUCK, the inputs
        # to opt_f will be tensors with SymInt sizes. Dynamo will treat input
        # as dynamic automatically and will only compile once
        for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK]:
            test_automatic_dynamic(f, [x, y, z], dim_dynamic, 1, 1)
            test_automatic_dynamic(f, [x, a, z], dim_dynamic, 1, 1)
            test_automatic_dynamic(f, [x, b, z], dim_dynamic, 1, 1)

        for dim_dynamic in [DimDynamic.STATIC]:
            # Recompiles once: dims 0 and 1 both become dynamic together.
            test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2)
            # Recompiles twice: first dim 1 becomes dynamic, then dim 0.
            test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3)
            # Recompiles twice: first dim 0 becomes dynamic, then dim 1.
            test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3)

    def test_compile_with_functionalization(self):
        """In-place ops compile to the same graph with and without functionalization."""
        x = torch.randn([3, 4])
        x_clone = x.clone()
        x_clone2 = x.clone()
        backend = EagerRecordGraphAndInputs()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return x.add_(1.0) + torch.nn.functional.relu_(x)

        f_out = f(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 3)
        self.assertEqual(len(backend.graphs), 1)
        self.assertEqual(len(backend.example_inputs), 1)

        expected = """\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
add_ = l_x_.add_(1.0)
relu_ = torch.relu_(l_x_); l_x_ = None
add = add_ + relu_; add_ = relu_ = None
return (add,)
"""
        actual = normalize_gm(backend.graphs[0].print_readable(print_output=False))
        self.assertEqual(actual, expected)

        ff = torch.func.functionalize(f)
        ff_out = ff(x_clone)

        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 6)
        self.assertEqual(len(backend.graphs), 2)
        self.assertEqual(len(backend.example_inputs), 2)
        actual = normalize_gm(backend.graphs[1].print_readable(print_output=False))
        self.assertEqual(actual, expected)
        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))

        # Cannot re-use the version from AOTAutograd, since that uses python functional tensors.
        def to_fun(x):
            x_functional = torch._to_functional_tensor(x)
            torch._mirror_autograd_meta_to(x, x_functional)
            return x_functional

        def aot_f_wrapper(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                torch._enable_functionalization(reapply_views=False)
                try:
                    func_args = pytree.tree_map(to_fun, args)
                    func_kwargs = pytree.tree_map(to_fun, kwargs)
                    return func(*func_args, **func_kwargs)
                finally:
                    torch._disable_functionalization()

            return wrapper

        aot_ff = aot_f_wrapper(f)
        aot_ff_out = aot_ff(x_clone2)

        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 9)
        self.assertEqual(len(backend.graphs), 3)
        self.assertEqual(len(backend.example_inputs), 3)
        actual = normalize_gm(backend.graphs[2].print_readable(print_output=False))
        self.assertEqual(actual, expected)
        # Check the inputs of the third (C++-functionalized) compile, not the second.
        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[2][0]))

        self.assertEqual(f_out, ff_out)
        self.assertEqual(f_out, aot_ff_out)

        try:
            torch._enable_functionalization(reapply_views=False)
            xf = pytree.tree_map(to_fun, x)
            x_view = xf.t()
            with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"):
                f(x_view)
        finally:
            torch._disable_functionalization()

    def test_compile_higher_order_with_functionalization(self):
        """A wrap() body with an in-place op yields the same graph under functionalization."""
        backend = EagerRecordGraphAndInputs()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return wrap(lambda x: x.add_(1.0), x)

        def check_count_and_graph(
            exp_frame_count, exp_op_count, exp_n_graph, exp_graph
        ):
            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)
            self.assertEqual(len(backend.graphs), exp_n_graph)
            actual = normalize_gm(
                backend.graphs[exp_n_graph - 1].print_readable(print_output=False)
            )
            self.assertExpectedInline(actual, exp_graph)

        t = torch.randn([3, 4])
        t_clone = t.clone()
        t_clone2 = t.clone()
        f(t)

        expected_graph = """\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
getitem = wrap[0]; wrap = None
return (getitem,)
class GraphModule(torch.nn.Module):
def forward(self, l_x_):
add_ = l_x_.add_(1.0); l_x_ = None
return (add_,)
"""
        check_count_and_graph(1, 2, 1, expected_graph)

        ff = torch.func.functionalize(f)
        ff_out = ff(t_clone)
        # frame count and op count are incremented due to re-compilation
        check_count_and_graph(2, 4, 2, expected_graph)

        try:
            x = torch._to_functional_tensor(t_clone2)
            torch._mirror_autograd_meta_to(t_clone2, x)
            torch._enable_functionalization(reapply_views=False)
            aot_f_out = f(x)
        finally:
            torch._disable_functionalization()

        # frame count and op count are incremented due to re-compilation
        check_count_and_graph(3, 6, 3, expected_graph)

    def test_has_torch_function(self):
        """has_torch_function_unary/variadic match eager for duck-typed and subclass inputs."""

        class MyTensor:
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}

                if func is torch.max:
                    return torch.tensor(123)
                return func(*args, **kwargs)

        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return func(*args, **kwargs)

        def fn(x):
            return torch.overrides.has_torch_function_unary(
                x
            ), torch.overrides.has_torch_function_variadic(x)

        for test_class in [MyTensor, LocalSubclass]:
            x = test_class()
            ref0 = fn(x)
            ref1 = fn(4)
            # torch.compile is the supported replacement for torch._dynamo.optimize.
            opt_fn = torch.compile(fn, backend="eager")
            res0 = opt_fn(x)
            res1 = opt_fn(4)
            self.assertEqual(ref0, res0)
            self.assertEqual(ref1, res1)

    def test_wrapper_subclass_guards_on_inner_tensor(self):
        """Symbols and guards are created for a wrapper subclass's inner tensor."""
        # Holds an inner tensor, that has a distinct shape from the outer wrapper tensor.
        # Also adds additional guards on the inner tensor's sizes.
        # When the first input to an op has x.shape[0] > 3, we insert an extra add node.
        class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, inner):
                # Double the outer-most dimension
                outer_shape = (inner.shape[0] * 2,) + inner.shape[1:]
                # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
                # Calling the overload that has kwargs causes us to go down the first overload path,
                # which will **always** specialize sizes.
                # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
                return torch.Tensor._make_wrapper_subclass(
                    cls,
                    outer_shape,
                    inner.stride(),
                    None,
                    None,
                    inner.dtype,
                    inner.layout,
                    inner.device,
                    False,
                    inner.requires_grad,
                )

            def __init__(self, inner):
                self.inner_elem = inner

            def __tensor_flatten__(self):
                return ["inner_elem"], None

            @staticmethod
            def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride):
                return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"])

            def __repr__(self):
                return f"DoubleSizeMaybeAddGeThreeTensor({repr(self.inner_elem)})"

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}

                args_inner = torch.utils._pytree.tree_map_only(
                    DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args
                )
                out_inner = func(*args_inner, **kwargs)

                # Add guards on the inner tensor's sizes
                if args_inner[0].shape[0] > 3:
                    out_inner += 2

                return DoubleSizeMaybeAddGeThreeTensor(out_inner)

        curr_var_to_val = None
        curr_var_to_sources = None
        guards = None

        def backend(gm, args):
            context = torch._guards.TracingContext.get()

            # Grab info on sources and guards from the shapeenv
            nonlocal curr_var_to_val
            nonlocal curr_var_to_sources
            nonlocal guards

            guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]
            curr_var_to_val = {
                str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items()
            }
            curr_var_to_sources = {
                str(k): v[0].name()
                for k, v in context.fake_mode.shape_env.var_to_sources.items()
            }
            return gm

        @torch.compile(backend=backend)
        def fn(x):
            if x.shape[0] < 10:
                return torch.mul(x, x)
            else:
                return torch.div(x, x)

        inp = torch.ones(4, 4)

        x = DoubleSizeMaybeAddGeThreeTensor(inp)
        torch._dynamo.mark_dynamic(x, 0)
        res = fn(x)
        # During fakeifying, we end up allocating a separate symint
        # for the outer and inner tensor (in this test, s0 is unused).
        expected_var_to_val = {
            "s0": 8,
            "s1": 4,
        }
        expected_var_to_sources = {
            "s0": "L['x'].size()[0]",
            "s1": "L['x'].inner_elem.size()[0]",
        }
        self.assertEqual(curr_var_to_val, expected_var_to_val)
        self.assertEqual(curr_var_to_sources, expected_var_to_sources)
        self.assertExpectedInline(
            "\n".join(guards),
            """\
Eq(2*s1, s0)
2*s1 < 10
s1 > 3""",
        )

    def test_wrapper_subclass_with_same_sized_inner_tensor(self):
        """Recompile behavior when the inner data has the same shape as the wrapper."""
        # shouldn't recompile for different sizes when dynamic=True
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7))
        self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True))

        # should recompile for different data size when dynamic=False
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))

        # avoid recompile using manual mark_dynamic() for different data size
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
        # NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size
        torch._dynamo.mark_dynamic(sub1, 0)
        torch._dynamo.mark_dynamic(sub1, 1)
        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
        self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))

    def test_wrapper_subclass_with_differently_sized_inner_tensor(self):
        """Changing an inner tensor whose size differs from the wrapper recompiles."""
        # should recompile for different scale size when dynamic=False
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
        sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))

        # still recompiles using manual mark_dynamic() on outer for different scale size
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
        # NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size
        torch._dynamo.mark_dynamic(sub1, 0)
        torch._dynamo.mark_dynamic(sub1, 1)
        sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))

    def test_torch_dispatch_subclass_guard_recompile(self):
        """The same subclass type doesn't recompile; a plain tensor does."""
        x = torch.ones(2, 2)
        x_two = TwoTensor(x.clone(), x.clone())

        def fn(w):
            return torch.add(w, 1.0)

        fn_opt = torch.compile(backend="eager")(fn)

        ref = fn(x_two)
        res = fn_opt(x_two)
        self.assertEqual(ref, res)

        # ensure no recompilation on same input type
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn_opt(TwoTensor(x + 1, x + 2))

        # recompile!
        ref = fn(x)
        res = fn_opt(x)
        self.assertEqual(ref, res)

    def test_torch_function_subclass_survives_into_aot_autograd(self):
        # If you have a tensor subclass that relies on dispatch into the same op
        # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(),
        # the torch function-ness will survive into AOTAutograd. Today, NestedTensor
        # actually relies on this behavior! Because that torch function logic
        # runs during AOTAutograd, this test tests that there is no logic below
        # that relies on torch function that gets unexpectedly disabled after we
        # redispatch from the subclass's torch function.
        class SubTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, t):
                return torch.Tensor._make_wrapper_subclass(
                    cls,
                    t.shape,
                    t.stride(),
                    t.storage_offset(),
                    torch.contiguous_format,
                    t.dtype,
                    torch.strided,
                    t.device,
                    False,
                    t.requires_grad,
                    "sizes",
                    False,
                    False,
                    None,
                )

            def __init__(self, t):
                super().__init__()
                self._t = t

            def __tensor_flatten__(self):
                return ["_t"], {}

            @staticmethod
            def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
                t = inner_tensors["_t"]
                return SubTensor(t)

            def __repr__(self):
                return f"SubTensor({self._t})"

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}

                with torch._C.DisableTorchFunctionSubclass():
                    return func(*args, **kwargs)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                kwargs = {} if kwargs is None else kwargs
                new_args = pytree.tree_map_only(SubTensor, lambda s: s._t, args)
                output = func(*new_args, **kwargs)
                output = pytree.tree_map_only(
                    torch.Tensor, lambda t: SubTensor(t), output
                )
                return output

        @torch.compile(dynamic=True)
        def f(x):
            return x.unflatten(-1, [2, 5])

        s = SubTensor(torch.randn(3, 10))
        f(s)

    def test_recompile_with_symbool_inputs(self):
        """SymBool inputs guard on their value and recompile when it changes."""

        def f(pred: bool):
            if pred:
                return torch.ones([3, 4])
            else:
                return torch.ones([4, 3])

        def test_recompilation(
            f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards
        ):
            torch._dynamo.reset()
            shape_env = ShapeEnv()
            backend = torch._dynamo.testing.EagerAndRecordGraphs()
            cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
            f_cond = torch.compile(f, backend=cnt, fullgraph=True)
            with torch._subclasses.fake_tensor.FakeTensorMode(
                shape_env=shape_env
            ) as fake_mode:
                fake_inp = fake_mode.from_tensor(
                    x,
                    symbolic_context=StatelessSymbolicContext(
                        dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())]
                    ),
                )
                for i, size in enumerate(sizes):
                    pred = fake_inp.size(0) == size
                    f_cond(pred)
                    actual = normalize_gm(
                        backend.graphs[exp_frame_count[i] - 1].print_readable(
                            print_output=False
                        )
                    )
                    actual_guard_str = [str(guard.expr) for guard in shape_env.guards]
                    self.assertExpectedInline(actual, exp_graphs[i])
                    self.assertEqual(cnt.frame_count, exp_frame_count[i])
                    self.assertEqual(actual_guard_str, exp_shape_env_guards[i])

        true_graph = """\
class GraphModule(torch.nn.Module):
def forward(self):
ones = torch.ones([3, 4])
return (ones,)
"""
        false_graph = """\
class GraphModule(torch.nn.Module):
def forward(self):
ones = torch.ones([4, 3])
return (ones,)
"""
        test_recompilation(
            f,
            torch.randn([3, 4]),
            [3, 3, 4, 5],
            exp_graphs=[true_graph, true_graph, false_graph, false_graph],
            exp_frame_count=[1, 1, 2, 2],
            exp_shape_env_guards=[
                [],
                # s0 is specialized and guarded in outer shape_env when dynamo checks the guards
                ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"],
                [
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                    "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
                ],
                [
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                    "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
                ],
            ],
        )

        test_recompilation(
            f,
            torch.randn([3, 4]),
            [4, 5, 3, 3],
            exp_graphs=[false_graph, false_graph, true_graph, true_graph],
            exp_frame_count=[1, 1, 2, 2],
            exp_shape_env_guards=[
                [],
                # s0 is specialized and guarded in outer shape_env when dynamo checks the guards
                ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"],
                [
                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                ],
                [
                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                ],
            ],
        )

    def test_wrapper_subclass_dynamo_attribute_access_on_intermediate(self):
        """Inner attributes of a subclass produced inside the graph are accessible."""

        def f(x_subclass):
            # Use the argument (not a closed-over tensor) so the intermediate is
            # derived from the compiled function's own input.
            tmp_subclass = torch.add(x_subclass, 1)
            return torch.mul(tmp_subclass._scale, tmp_subclass._constant)

        x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2)
        out_ref = f(x)
        out_test = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
        self.assertEqual(out_ref, out_test)

    def test_support_bases(self):
        """``__bases__`` of a class with a custom metaclass is readable when compiled."""
        import abc

        import torch.fx._symbolic_trace

        class Meta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
            def __new__(cls, name, bases, dct):
                x = super().__new__(cls, name, bases, dct)
                x.attr = 100
                return x

        class Multistreamable(abc.ABC):  # noqa: B024
            pass

        class Foo(Multistreamable, metaclass=Meta):
            pass

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            typ = type(Foo())
            typ.__bases__
            return typ.__bases__

        self.assertEqual(f(torch.randn(1)), (Multistreamable,))

    @parametrize("dynamic", [False, True])
    def test_subclass_views(self, dynamic):
        """Views of a TwoTensor passed as inputs compile and match eager."""

        def _get_views(t):
            # Note that any closed-over SymInts will be symbolicized during fake-ification.
            yield t.narrow(dim=-1, start=3, length=8)
            yield t.split(5, -1)
            yield t.split_with_sizes([9, 6], -1)
            yield t.unsqueeze(-1).expand(4, 15, 10)
            yield t.select(-1, 6)
            yield t[2:3, 5:9]

        def f(x):
            return x * 2

        compiled_f = torch.compile(
            f, backend="aot_eager", fullgraph=True, dynamic=dynamic
        )

        # Take a view of a subclass to pass as input.
        t = TwoTensor(torch.randn(4, 15), torch.randn(4, 15))
        for view in _get_views(t):
            out_ref = f(view)
            out_test = compiled_f(view)
            self.assertEqual(out_ref, out_test)
# Generate the concrete test cases for the @parametrize-decorated tests of
# SubclassTests (test_type_check, test_subclass_views).
instantiate_parametrized_tests(SubclassTests)
class TestNestedTensor(torch._dynamo.test_case.TestCase):
    """Dynamo tests for jagged-layout NestedTensor inputs.

    Covers:
    - recompilation across NTs with shared vs. fresh offsets;
    - autograd through compiled graphs;
    - NT views passed in as compiled-function inputs.
    """

    def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
        """Build a jagged NestedTensor from float64 component tensors.

        Args:
            nested_size: ``((S0, S1, ...), D)``. Component i has shape
                ``(Si, D)``.
            offsets: ``None``, or offsets returned by an earlier call. Callers
                reuse them to build a second NT with the same ragged
                structure (see the binary recompile tests below).
            requires_grad: forwarded to every component tensor.

        Returns:
            The result of ``jagged_from_list``. Callers unpack it as
            ``(nt, offsets)``.
        """
        D = nested_size[1]
        components = [
            torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64)
            for s in nested_size[0]
        ]
        return jagged_from_list(components, offsets)

    def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
        """Build a jagged NestedTensor with ``jagged_from_tensor_and_lengths``.

        A float64 values tensor of shape
        ``(starts.shape[0], max(starts + lengths), inner_dim)`` is allocated.
        Each component is then described by its ``starts``/``lengths`` entry.
        NOTE(review): presumably "nc" means non-contiguous, and ``starts`` /
        ``lengths`` are 1-D integer tensors of equal length. Confirm.
        """
        max_dim = (starts + lengths).max()
        values_tensor = torch.randn(
            starts.shape[0],
            max_dim.item(),
            inner_dim,
            requires_grad=requires_grad,
            dtype=torch.float64,
        )
        return jagged_from_tensor_and_lengths(values_tensor, starts, lengths)

    def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
        """Assert whether running ``fn`` on ``inputs1`` and then ``inputs2``
        recompiles, as reported by ``_recompiles_for_inputs``."""
        actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2)
        self.assertEqual(actual_recompiles, expected_recompiles)

    def test_unary_does_not_recompile(self):
        nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
        nt2, _ = self._get_jagged_tensor(((3, 4, 5, 6), 4), None)
        self._check_recompiles(lambda nt1: nt1.sin(), (nt1,), (nt2,), False)

    def test_binary_does_not_recompile(self):
        def binary(nt1, nt2):
            if nt1.shape == nt2.shape:
                return nt1 + nt2
            else:
                return nt1.sin()

        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
        # This causes a recompile later on when it realizes the batch and last dim
        # should not always be equal. To avoid that, we use (3, j0, 5) here.
        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
        nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None)
        nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets)
        self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False)

    def test_binary_recompiles(self):
        def binary(nt1, nt2):
            if nt1.shape == nt2.shape:
                return nt1 + nt2
            else:
                return nt1.sin()

        # Binary recompiles because singleton ints no longer match
        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
        nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
        self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)

    # TODO: cannot parametrize this test class with device for some reason
    def _test_autograd(self, backend):
        """Check that gradients through a compiled NT graph match eager.

        ``backend`` selects the torch.compile backend.
        """
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)
        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        # TODO: Switch to public API when it exists
        nt2, _ = jagged_from_list([a, b, c], nt.offsets())

        def fn1(nt1, nt2):
            return (nt1 + nt2).sin().cos()

        compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True)
        out = compiled_f(nt, nt2)
        out_buffer = out.values()
        ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))

        out_ref = fn1(nt, nt2)
        out_buffer_ref = out_ref.values()
        ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c))

        self.assertTrue(torch.allclose(ga, ga_ref))
        self.assertTrue(torch.allclose(gb, gb_ref))
        self.assertTrue(torch.allclose(gc, gc_ref))

    def test_basic_autograd(self):
        self._test_autograd("aot_eager")

    @requires_cuda
    def test_basic_autograd_inductor(self):
        self._test_autograd("inductor")

    def test_subclass_with_mutation_in_graph(self):
        # In this graph, we have an in-graph mutation, i.e. a mutation that is allowed
        # to remain in the graph. Normally this is allowed, but it's not allowed if
        # the graph handles subclasses at all.
        # Whether the mutation is allowed or not allowed in the graph alters the number
        # of outputs from the forward graph. Previously, a bug in this handling meant
        # that sometimes the expected number and actual number of outputs from the
        # joint graph did not match, causing assertion failures.
        def fn(x, y):
            z = x.sin()
            y.sin_()
            return z.cos(), y.cos()

        fn_c = torch.compile(fn, backend="inductor")

        values = [torch.rand((i, 8), requires_grad=True) for i in range(1, 6)]
        values_copy = [x.detach().clone().requires_grad_(True) for x in values]

        nt, offsets = jagged_from_list(values, None)
        nt_copy, offsets = jagged_from_list(values_copy, offsets)
        y = torch.rand((4, 8))
        y_copy = y.clone()

        ret = fn_c(nt, y)[0]
        ref = fn(nt_copy, y_copy)[0]

        self.assertEqual(ret.values(), ref.values())

        ret.values().sum().backward()
        ref.values().sum().backward()
        for ref_v, res_v in zip(values_copy, values):
            self.assertEqual(ref_v.grad, res_v.grad)

    def test_unbind(self):
        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
        # This causes a recompile later on when it realizes the batch and last dim
        # should not always be equal. To avoid that, we use (3, j0, 5) here.
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None)
        nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None)

        def fn(x):
            return x.unbind()

        compiled_f = torch.compile(fn, fullgraph=True, backend="eager", dynamic=True)
        out = compiled_f(nt)

        out_ref = fn(nt)

        # correctness
        self.assertEqual(len(out), len(out_ref))
        for x, x_ref in zip(out, out_ref):
            self.assertTrue(torch.allclose(x, x_ref))

        # We specialize on the length of offsets, e.g. (1) we recompile if the
        # length of the offsets is different. (2) we don't recompile if the
        # length of the offsets is the same, even if the size of the constituent
        # tensors are different.
        self._check_recompiles(fn, (nt,), (nt2,), False)
        self._check_recompiles(fn, (nt,), (nt3,), True)

    def _get_views(self):
        """Yield a series of views (NT and dense) to feed ``_input_view_test``."""
        # Test all cases with both an NT base and a dense base
        # Subclass -> Subclass
        # Dense -> Subclass
        for base_is_nt in [False, True]:
            # There are three cases to consider here based on the logic in
            # meta_utils.py
            #
            # (1) basic case:
            # the view is not a leaf and has the same requires_grad as its base
            x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
            x = x.clone() if base_is_nt else x
            self.assertEqual(x.is_leaf, False)
            yield x.unsqueeze(-1)

            # (2) leaf view case:
            # the view is a leaf (with requires_grad True or False), and its
            # base has requires_grad True or False
            for requires_grad_1, requires_grad_2 in itertools.product(
                [True, False], repeat=2
            ):
                x, _ = self._get_jagged_tensor(
                    ((2, 3, 4), 3), None, requires_grad=requires_grad_1
                )
                x = x.clone() if base_is_nt else x
                with torch.no_grad():
                    x_view = x.unsqueeze(-1)
                    # NOTE(review): an earlier note here said "this doesn't
                    # quite work" without detail; presumably it refers to
                    # requires_grad_() on a view created under no_grad(). Confirm.
                    x_view.requires_grad_(requires_grad_2)
                yield x_view

            # (3) obscure case:
            # the view is not a leaf (implies requires_grad True), and its
            # base has requires_grad False
            x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
            x = x.clone() if base_is_nt else x
            # intermediate leaf view
            with torch.no_grad():
                x_view = x.unsqueeze(-1)
            x_view.requires_grad_(True)
            x_view_view = x_view.unsqueeze(-1)
            yield x_view_view

        # Subclass -> Dense
        x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
        yield x.values()

        # Dense -> Subclass -> Dense -> Subclass
        values = torch.randn(10, 5)
        offsets = torch.tensor([0, 3, 6, 10])
        yield nested_view_from_values_offsets(
            nested_view_from_values_offsets(values, offsets).values(), offsets
        )

    def _input_view_test(self, nt_view):
        """Compile ``x.sin()`` on ``nt_view``.

        Checks two things:
        - output metadata and values match eager;
        - the shape guards dynamo records match the expected strings.
        """

        def fn(x):
            return x.sin()

        out_ref = fn(nt_view)
        torch._dynamo.reset()
        compile_fn = torch.compile(
            fn, fullgraph=True, backend="aot_eager", dynamic=True
        )
        out = compile_fn(nt_view)

        # Check metadata and values are correct
        self.assertTrue(out.size() == out_ref.size())
        self.assertTrue(out.stride() == out_ref.stride())
        if out.is_nested:
            self.assertTrue(torch.allclose(out.values(), out_ref.values()))
        else:
            self.assertTrue(torch.allclose(out, out_ref))

        # Check that no upper/lower bound guards are incurred
        def backend(gm, args):
            context = torch._guards.TracingContext.get()
            guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]

            # varies based on the type of view
            guard_str = "\n".join(guards)
            if isinstance(nt_view._base, NestedTensor):
                self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
            else:
                self.assertExpectedInline(guard_str, """""")
            return gm

        torch._dynamo.reset()
        compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
        # Called only so that ``backend`` above runs its guard assertions.
        compile_fn(nt_view)

    def test_inputs_to_compiled_fn_are_views(self):
        for nt_view in self._get_views():
            self._input_view_test(nt_view)

    # NJT1 -> Dense -> NJT2 -> Dense view
    # During view replay, the Dense -> NJT2 part will construct an intermediate,
    # symbolically-sized NJT that is immediately deconstructed to return the final dense
    # view. To construct this intermediate properly, we need the associated nested int
    # to be symbolic. This view is expected to fail compilation until symbolic nested ints
    # are cached onto fake offsets to solve this problem.
    @unittest.expectedFailure
    def test_subclass_dense_subclass_dense_view(self):
        x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
        offsets2 = x.offsets().clone().detach()
        nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
        self._input_view_test(nt_view)
if __name__ == "__main__":
    # When this file is executed directly, run it with dynamo's test runner.
    from torch._dynamo.test_case import run_tests

    run_tests()