| # Owner(s): ["module: dynamo"] |
| import functools |
| import itertools |
| import unittest |
| |
| import torch |
| |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| import torch._functorch.config |
| import torch.utils._pytree as pytree |
| import torch.utils.checkpoint |
| from torch._dynamo.testing import normalize_gm |
| from torch._higher_order_ops.wrap import wrap |
| |
| from torch.fx.experimental.symbolic_shapes import ( |
| DimDynamic, |
| ShapeEnv, |
| StatelessSymbolicContext, |
| ) |
| from torch.nested._internal.nested_tensor import ( |
| jagged_from_list, |
| jagged_from_tensor_and_lengths, |
| nested_view_from_values_offsets, |
| NestedTensor, |
| ) |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| parametrize, |
| subtest, |
| ) |
| from torch.testing._internal.inductor_utils import HAS_CUDA |
| from torch.testing._internal.two_tensor import TwoTensor |
| |
| |
| def traceable_subclass(c): |
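    # Returns a Dynamo config patch (a context manager) registering c as a traceable tensor subclass.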
| return torch._dynamo.config.patch("traceable_tensor_subclasses", {c}) |
| |
| |
| requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") |
| |
| compile_full_eager = torch.compile(backend="eager", fullgraph=True) |
| |
| |
| class BaseTorchFunction(torch.Tensor): |
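    # Tensor subclass that simply defers to the default __torch_function__ implementation.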
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| return super().__torch_function__(func, types, args, kwargs) |
| |
| |
| class MockSubclass(torch.Tensor): |
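    # Tensor subclass whose __torch_function__ calls the requested func directly on the given args.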
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| return func(*args, **kwargs) |
| |
| |
| class AttrSubclass(torch.Tensor): |
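    # Tensor subclass with extra class-level attributes; note that "size" shadows torch.Tensor.size.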
| x: int = 10 |
| size: int = 10 |
| |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| |
| return func(*args, **kwargs) |
| |
| |
| class DummyNDim(torch.Tensor): |
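    # Tensor subclass that overrides .ndim access to return 10 via __torch_function__.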
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| |
| if func == torch.Tensor.ndim.__get__: |
| return 10 |
| |
| return super().__torch_function__(func, types, args, kwargs) |
| |
| |
| class WrapperSubclass: |
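    # Plain (non-Tensor) wrapper class; __torch_function__ unwraps WrapperSubclass args/kwargs before calling func.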
| def __init__(self, tensor): |
| self.tensor = tensor |
| |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| |
| args = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, args) |
| kwargs = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, kwargs) |
| |
| return func(*args, **kwargs) |
| |
| |
| class SigmoidToExpSubclass(torch.Tensor): |
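    # Tensor subclass that redirects Tensor.sigmoid calls to Tensor.exp via __torch_function__.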
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| |
| if func == torch.Tensor.sigmoid: |
| return super().__torch_function__(torch.Tensor.exp, types, args, kwargs) |
| |
| return super().__torch_function__(func, types, args, kwargs) |
| |
| |
# Wrapper subclass with two inner tensors: data and scale.
# data has the same shape as the outer wrapper, and scale is a 1-dim tensor.
| class ScaledTensor(torch.Tensor): |
| def __new__( |
| cls, |
| data: torch.Tensor, |
| scale: torch.Tensor, |
| *, |
| constant: int = 0, |
| ): |
| return torch.Tensor._make_wrapper_subclass( |
| cls, |
| data.size(), |
| strides=data.stride(), |
| storage_offset=data.storage_offset(), |
| dtype=data.dtype, |
| layout=data.layout, |
| requires_grad=data.requires_grad, |
| device=data.device, |
| ) |
| |
| def __init__(self, data: torch.Tensor, scale: torch.Tensor, constant: int = 0): |
| self._data = data |
| self._scale = scale |
| self._constant = constant |
| |
| def __tensor_flatten__(self): |
| ctx = {"_constant": self._constant} |
| return ["_data", "_scale"], ctx |
| |
| @staticmethod |
| def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): |
| assert len(inner_tensors) == 2 |
| return ScaledTensor( |
| inner_tensors["_data"], |
| inner_tensors["_scale"], |
| constant=metadata["_constant"], |
| ) |
| |
| @classmethod |
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        if kwargs is None:
            kwargs = {}
        scaled_tensor = args[0]
        out = func(scaled_tensor._data, *args[1:], **kwargs)
| return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant) |
| |
| def __repr__(self): |
| return f"{self._data.__repr__()}\n{self._scale.__repr__()}" |
| |
| |
| def func(a): |
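    # Simple module-level op used by the ScaledTensor recompilation tests below.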
| return a.sin() |
| |
| |
| class EagerRecordGraphAndInputs: |
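    # Eager backend that records each compiled GraphModule and its example inputs for later inspection.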
| def __init__(self): |
| self.graphs = [] |
| self.example_inputs = [] |
| |
| def __call__(self, gm: torch.fx.GraphModule, example_inputs): |
| self.graphs.append(gm) |
| self.example_inputs.append(example_inputs) |
| return gm |
| |
| |
| GLOBAL_TEST_SUBCLASSES = { |
| MockSubclass, |
| DummyNDim, |
| SigmoidToExpSubclass, |
| BaseTorchFunction, |
| } |
| |
| |
| # Returns True if the function recompiles between inputs1 and inputs2 with the |
| # specified dynamic setting. |
| def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True): |
| compile_count = [0] |
| |
| def counter(gm, example_inputs): |
| compile_count[0] += 1 |
| return gm |
| |
| compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic) |
| compiled_f(*inputs1) |
| compiled_f(*inputs2) |
| return compile_count[0] > 1 |
| |
| |
| class SubclassTests(torch._dynamo.test_case.TestCase): |
| @classmethod |
| def setUpClass(cls): |
| super().setUpClass() |
| cls._exit_stack.enter_context( |
| torch._dynamo.config.patch( |
| "traceable_tensor_subclasses", GLOBAL_TEST_SUBCLASSES |
| ) |
| ) |
| |
| @classmethod |
| def tearDownClass(cls): |
| cls._exit_stack.close() |
| |
| def test_no_call_to_new(self): |
| class BadNewTorchFunction(torch.Tensor): |
| def __new__(cls, *args, **kwargs): |
| raise RuntimeError("Oops!") |
| |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| return super().__torch_function__(func, types, args, kwargs) |
| |
| with torch._dynamo.config.patch( |
| "traceable_tensor_subclasses", {BadNewTorchFunction} |
| ): |
| |
| @torch.compile(backend="eager", fullgraph=True) |
| def fn(x): |
| return torch.add(x, 1) |
| |
| input = torch.ones(2, 2).as_subclass(BadNewTorchFunction) |
| |
| res = fn(input) |
| self.assertIsInstance(res, BadNewTorchFunction) |
| |
| def test_base_torch_function_tracing(self): |
| def fn(x): |
| return torch.add(x, 1) |
| |
| input = torch.ones(2, 2).as_subclass(BaseTorchFunction) |
| out = fn(input) |
| out_opt = compile_full_eager(fn)(input) |
| self.assertIsInstance(out, BaseTorchFunction) |
| self.assertEqual(out, out_opt) |
| |
| def test_torch_function_state_graph_break(self): |
| @torch.compile(backend="eager") |
| def fn(x): |
| with torch._C.DisableTorchFunctionSubclass(): |
| torch._dynamo.graph_break() |
| return torch._C._is_torch_function_enabled(), torch.add(x, 1.0) |
| |
| input = torch.ones(2, 2) |
| res, _ = fn(input) |
| self.assertFalse(res) |
| |
| def test_torch_function_state_nested(self): |
| @torch.compile(backend="eager") |
| def fn(x): |
| with torch._C.DisableTorchFunctionSubclass(): |
| with torch._C.DisableTorchFunctionSubclass(): |
| x = x + 1 |
| # Should reset to the outer state (disabled) after exiting ctx manager |
| return torch._C._is_torch_function_enabled(), torch.add(x, 1.0) |
| |
| input = torch.ones(2, 2) |
| res, _ = fn(input) |
| self.assertFalse(res) |
| |
| def test_torch_function_state_tracing(self): |
| @torch.compile(backend="eager", fullgraph=True) |
| def fn(x): |
| with torch._C.DisableTorchFunctionSubclass(): |
| torch.add(x, 1.0) |
| |
| input = torch.ones(2, 2) |
| |
| res = fn(input) |
| |
| def test_torch_function_state_guards(self): |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch.compile(backend=cnt, fullgraph=True) |
| def fn(x): |
| torch.add(x, 1.0) |
| |
| input = torch.ones(2, 2) |
| |
| with torch._C.DisableTorchFunctionSubclass(): |
| res = fn(input) |
| |
| res = fn(input) |
| |
| self.assertEqual(cnt.frame_count, 2) |
| |
| def test_return_subclass(self): |
| @torch.compile(backend="eager", fullgraph=True) |
| def fn(x): |
| return MockSubclass(torch.add(x, 1.0)) |
| |
| input = torch.ones(2, 2) |
| |
| res = fn(input) |
| self.assertIsInstance(res, MockSubclass) |
| |
| def test_return_as_subclass(self): |
| @torch.compile(backend="eager", fullgraph=True) |
| def fn(x): |
| return torch.add(x, 1.0).as_subclass(MockSubclass) |
| |
| input = torch.ones(2, 2) |
| |
| res = fn(input) |
| self.assertIsInstance(res, MockSubclass) |
| |
| def test_return_local_subclass(self): |
| class LocalSubclass(torch.Tensor): |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| return func(*args, **kwargs) |
| |
| with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): |
| |
| @torch.compile(backend="eager", fullgraph=True) |
| def fn(x): |
| return LocalSubclass(torch.add(x, 1.0)) |
| |
| input = torch.ones(2, 2) |
| |
| res = fn(input) |
| self.assertIsInstance(res, LocalSubclass) |
| |
| @parametrize( |
| "comparison", |
| [ |
| subtest(isinstance, "isinstance"), |
| subtest(lambda instance, type_: type(instance) == type_, "equality"), |
| subtest(lambda instance, type_: type(instance) is type_, "identity"), |
| ], |
| ) |
| @parametrize( |
| "input_type", |
| [ |
| subtest(torch.Tensor, "tensor"), |
| subtest(DummyNDim, "subclass"), |
| ], |
| ) |
| def test_type_check(self, comparison, input_type): |
| with torch._dynamo.config.patch("traceable_tensor_subclasses", {DummyNDim}): |
| |
| def fn(x): |
| if comparison(x, DummyNDim): |
| return torch.ones(1, 1) |
| else: |
| return torch.zeros(2, 2) |
| |
| input = torch.ones(2, 2).as_subclass(input_type) |
| exp_res = fn(input) |
| act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input) |
| self.assertEqual(exp_res, act_res) |
| |
| def test_torch_function_call_on_method(self): |
| x = torch.ones(2, 2) |
| y = torch.ones(2, 2) |
| z = torch.ones(2, 2) |
| wrapped = x.as_subclass(SigmoidToExpSubclass) |
| wrapped2 = y.as_subclass(SigmoidToExpSubclass) |
| |
| def fn(w): |
| return w.sigmoid() |
| |
| fn_opt = compile_full_eager(fn) |
| |
| res_exp = fn(wrapped) |
| res_act = fn_opt(wrapped2) |
| res_exp2 = z.exp() |
| |
| self.assertEqual(res_exp, res_act) |
| self.assertEqual(res_exp, res_exp2) |
| |
    def test_user_overridden_method_unsupported(self):
| class LocalSubclass(torch.Tensor): |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| return super().__torch_function__(func, types, args, kwargs) |
| |
| def sigmoid(self): |
| return None |
| |
| @torch.compile(backend="eager", fullgraph=True) |
| def fn(x): |
| x.sigmoid() |
| |
| msg = ( |
| "Accessing overridden method/attribute sigmoid on a tensor" |
| " subclass with a __torch_function__ override is not supported" |
| ) |
| with torch._dynamo.config.patch( |
| "traceable_tensor_subclasses", {LocalSubclass} |
| ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): |
| x = torch.ones(2, 2).as_subclass(LocalSubclass) |
| fn(x) |
| |
    def test_user_overridden_attr_unsupported(self):
| class LocalSubclass(torch.Tensor): |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| return super().__torch_function__(func, types, args, kwargs) |
| |
| ndim = 10 |
| |
| @torch.compile(backend="eager", fullgraph=True) |
| def fn(x): |
| return x.ndim |
| |
| msg = ( |
| "Accessing overridden method/attribute ndim on a tensor" |
| " subclass with a __torch_function__ override is not supported" |
| ) |
| with torch._dynamo.config.patch( |
| "traceable_tensor_subclasses", {LocalSubclass} |
| ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): |
| x = torch.ones(2, 2).as_subclass(LocalSubclass) |
| fn(x) |
| |
    def test_user_overridden_property_unsupported(self):
| class LocalSubclass(torch.Tensor): |
| def __init__(self): |
| self._ndim = 10 |
| |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| return super().__torch_function__(func, types, args, kwargs) |
| |
| @property |
| def ndim(self): |
| return self._ndim |
| |
| @ndim.setter |
| def ndim(self, value): |
| self._ndim = value |
| |
| @torch.compile(backend="eager", fullgraph=True) |
| def fn(x): |
| return x.ndim |
| |
| msg = ( |
| "Accessing overridden method/attribute ndim on a tensor" |
| " subclass with a __torch_function__ override is not supported" |
| ) |
| with torch._dynamo.config.patch( |
| "traceable_tensor_subclasses", {LocalSubclass} |
| ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): |
| x = torch.ones(2, 2).as_subclass(LocalSubclass) |
| fn(x) |
| |
| def test_overridden_method_guarding(self): |
| class LocalSubclass(torch.Tensor): |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| return super().__torch_function__(func, types, args, kwargs) |
| |
| @torch.compile(backend="eager") |
| def fn(x): |
| return x.sigmoid() |
| |
| with torch._dynamo.config.patch( |
| error_on_recompile=True, traceable_tensor_subclasses={LocalSubclass} |
| ): |
| x = torch.ones(2, 2).as_subclass(LocalSubclass) |
| fn(x) |
| fn(x) |
| x = torch.ones(2, 2).as_subclass(LocalSubclass) |
| fn(x) |
| |
| with torch._dynamo.config.patch( |
| traceable_tensor_subclasses={LocalSubclass} |
| ), self.assertRaisesRegex( |
| TypeError, |
| "'bool' object is not callable", |
| ): |
| LocalSubclass.sigmoid = False |
| fn(x) |
| |
| def test_torch_function_call_on_attr(self): |
| x = torch.ones(2, 2) |
| wrapped = x.as_subclass(DummyNDim) |
| |
| def fn(w): |
| return w.ndim + torch.ones(2) |
| |
| fn_opt = compile_full_eager(fn) |
| |
| res_exp = fn(wrapped) |
| res_act = fn_opt(wrapped) |
| |
| self.assertEqual(res_exp, res_act) |
| self.assertEqual(res_exp, torch.ones(2) + 10) |
| |
| def test_torch_function_wrapper_class(self): |
| x = torch.ones(2, 2) |
| wrapped = WrapperSubclass(x) |
| |
| def fn(w): |
| return torch.add(w, 1.0) |
| |
| fn_opt = compile_full_eager(fn) |
| |
| res_exp = fn(wrapped) |
| res_act = fn_opt(wrapped) |
| self.assertEqual(res_exp, res_act) |
| |
| def test_torch_function_wrapper_class_with_kwargs(self): |
| x = torch.ones(2, 2) |
| wrapped = WrapperSubclass(x) |
| |
| def fn(w): |
| return torch.add(w, 1.0, alpha=2.0) |
| |
| fn_opt = compile_full_eager(fn) |
| |
| res_exp = fn(wrapped) |
| res_act = fn_opt(wrapped) |
| self.assertEqual(res_exp, res_act) |
| |
| def test_tensor_subclass_custom_attr(self): |
| class AttrSubclass(torch.Tensor): |
| x: int = 10 |
| |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| |
| return super().__torch_function__(func, types, args, kwargs) |
| |
| @torch.compile(backend="eager", fullgraph=True) |
| def fn(x): |
| return x.x + torch.ones(2, 2) |
| |
| with traceable_subclass(AttrSubclass): |
| input = torch.ones(2, 2).as_subclass(AttrSubclass) |
| fn_opt = compile_full_eager(fn) |
| |
| res_exp = fn(input) |
| res_act = fn_opt(input) |
| self.assertEqual(res_exp, res_act) |
| |
| def test_compile_with_fake_tensor_dynamic_dim(self): |
| x = torch.randn([3, 4]) |
| |
| def f(x): |
| return torch.sin(x) |
| |
| def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count): |
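            # Fakeify x (and a same-shaped x1) with the given DimDynamic policy, run the
            # compiled f on both, and check the resulting frame and op counts.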
| torch._dynamo.reset() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| opt_f = torch.compile(f, backend=cnt, fullgraph=True) |
| |
| x1 = torch.rand_like(x) |
| f(x) |
| f(torch.randn([4, 3])) |
| shape_env = ShapeEnv() |
| with torch._subclasses.fake_tensor.FakeTensorMode( |
| shape_env=shape_env |
| ) as fake_mode: |
| x_fake = fake_mode.from_tensor( |
| x, |
| symbolic_context=StatelessSymbolicContext( |
| dynamic_sizes=[dim_dynamic for i in range(x.dim())] |
| ), |
| ) |
| x1_fake = fake_mode.from_tensor( |
| x1, |
| symbolic_context=StatelessSymbolicContext( |
| dynamic_sizes=[dim_dynamic for i in range(x.dim())] |
| ), |
| ) |
| opt_f(x_fake) |
| opt_f(x1_fake) |
| |
| self.assertEqual(cnt.frame_count, exp_frame_count) |
| self.assertEqual(cnt.op_count, exp_op_count) |
| |
| test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1) |
| test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1) |
| test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1) |
| |
| def test_compile_with_fake_tensor_automatic_dynamic(self): |
| def f(x): |
| return torch.sin(x) |
| |
| def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count): |
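            # Run the compiled f on fake versions of each input, all created with the given
            # DimDynamic policy, and check the resulting frame and op counts.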
| torch._dynamo.reset() |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_f = torch.compile(f, backend=cnt, fullgraph=True) |
| |
| shape_env = ShapeEnv() |
| with torch._subclasses.fake_tensor.FakeTensorMode( |
| shape_env=shape_env |
| ) as fake_mode: |
| for inp in inps: |
| fake_inp = fake_mode.from_tensor( |
| inp, |
| symbolic_context=StatelessSymbolicContext( |
                            dynamic_sizes=[dim_dynamic for i in range(inp.dim())]
| ), |
| ) |
| opt_f(fake_inp) |
| self.assertEqual(cnt.frame_count, exp_frame_count) |
| self.assertEqual(cnt.op_count, exp_op_count) |
| |
| x = torch.randn([3, 4]) |
| y = torch.randn([4, 5]) |
| z = torch.randn([5, 6]) |
| a = torch.randn([3, 5]) |
| b = torch.randn([4, 4]) |
        # When the inputs' DimDynamic is DYNAMIC or DUCK, the inputs to opt_f
        # will be tensors with SymInt sizes. Dynamo will treat them as dynamic
        # automatically and will only compile once.
| for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK]: |
| test_automatic_dynamic(f, [x, y, z], dim_dynamic, 1, 1) |
| test_automatic_dynamic(f, [x, a, z], dim_dynamic, 1, 1) |
| test_automatic_dynamic(f, [x, b, z], dim_dynamic, 1, 1) |
| |
| for dim_dynamic in [DimDynamic.STATIC]: |
            # Recompiles once: dims 0 and 1 both become dynamic.
| test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2) |
            # Recompiles twice: first dim 1 becomes dynamic, then dim 0 becomes dynamic.
| test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3) |
            # Recompiles twice: first dim 0 becomes dynamic, then dim 1 becomes dynamic.
| test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3) |
| |
| def test_compile_with_functionalization(self): |
| x = torch.randn([3, 4]) |
| x_clone = x.clone() |
| x_clone2 = x.clone() |
| backend = EagerRecordGraphAndInputs() |
| cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) |
| |
| @torch.compile(backend=cnt, fullgraph=True) |
| def f(x): |
| return x.add_(1.0) + torch.nn.functional.relu_(x) |
| |
| f_out = f(x) |
| self.assertEqual(cnt.frame_count, 1) |
| self.assertEqual(cnt.op_count, 3) |
| self.assertEqual(len(backend.graphs), 1) |
| self.assertEqual(len(backend.example_inputs), 1) |
| |
| expected = """\ |
| class GraphModule(torch.nn.Module): |
| def forward(self, L_x_ : torch.Tensor): |
| l_x_ = L_x_ |
| |
| add_ = l_x_.add_(1.0) |
| relu_ = torch.relu_(l_x_); l_x_ = None |
| add = add_ + relu_; add_ = relu_ = None |
| return (add,) |
| """ |
| actual = normalize_gm(backend.graphs[0].print_readable(print_output=False)) |
| self.assertEqual(actual, expected) |
| |
| ff = torch.func.functionalize(f) |
| ff_out = ff(x_clone) |
| |
| self.assertEqual(cnt.frame_count, 2) |
| self.assertEqual(cnt.op_count, 6) |
| self.assertEqual(len(backend.graphs), 2) |
| self.assertEqual(len(backend.example_inputs), 2) |
| actual = normalize_gm(backend.graphs[1].print_readable(print_output=False)) |
| self.assertEqual(actual, expected) |
| self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) |
| |
| # Cannot re-use the version from AOTAutograd, since that uses python functional tensors. |
| def to_fun(x): |
| x_functional = torch._to_functional_tensor(x) |
| torch._mirror_autograd_meta_to(x, x_functional) |
| return x_functional |
| |
| def aot_f_wrapper(func): |
| @functools.wraps(func) |
| def wrapper(*args, **kwargs): |
| torch._enable_functionalization(reapply_views=False) |
| try: |
| func_args = pytree.tree_map(to_fun, args) |
| func_kwargs = pytree.tree_map(to_fun, kwargs) |
| return func(*func_args, **func_kwargs) |
| finally: |
| torch._disable_functionalization() |
| |
| return wrapper |
| |
| aot_ff = aot_f_wrapper(f) |
| aot_ff_out = aot_ff(x_clone2) |
| |
| self.assertEqual(cnt.frame_count, 3) |
| self.assertEqual(cnt.op_count, 9) |
| self.assertEqual(len(backend.graphs), 3) |
| self.assertEqual(len(backend.example_inputs), 3) |
| actual = normalize_gm(backend.graphs[2].print_readable(print_output=False)) |
| self.assertEqual(actual, expected) |
| self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) |
| |
| self.assertEqual(f_out, ff_out) |
| self.assertEqual(f_out, aot_ff_out) |
| |
| try: |
| torch._enable_functionalization(reapply_views=False) |
| xf = pytree.tree_map(to_fun, x) |
| x_view = xf.t() |
| with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"): |
| f(x_view) |
| finally: |
| torch._disable_functionalization() |
| |
| def test_compile_higher_order_with_functionalization(self): |
| backend = EagerRecordGraphAndInputs() |
| cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) |
| |
| @torch.compile(backend=cnt, fullgraph=True) |
| def f(x): |
| return wrap(lambda x: x.add_(1.0), x) |
| |
| def check_count_and_graph( |
| exp_frame_count, exp_op_count, exp_n_graph, exp_graph |
| ): |
| self.assertEqual(cnt.frame_count, exp_frame_count) |
| self.assertEqual(cnt.op_count, exp_op_count) |
| self.assertEqual(len(backend.graphs), exp_n_graph) |
| actual = normalize_gm( |
| backend.graphs[exp_n_graph - 1].print_readable(print_output=False) |
| ) |
| self.assertExpectedInline(actual, exp_graph) |
| |
| t = torch.randn([3, 4]) |
| t_clone = t.clone() |
| t_clone2 = t.clone() |
| f(t) |
| |
| expected_graph = """\ |
| class GraphModule(torch.nn.Module): |
| def forward(self, L_x_ : torch.Tensor): |
| l_x_ = L_x_ |
| |
| wrap_body_0 = self.wrap_body_0 |
| wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None |
| getitem = wrap[0]; wrap = None |
| return (getitem,) |
| |
| class GraphModule(torch.nn.Module): |
| def forward(self, l_x_): |
| add_ = l_x_.add_(1.0); l_x_ = None |
| return (add_,) |
| """ |
| check_count_and_graph(1, 2, 1, expected_graph) |
| |
| ff = torch.func.functionalize(f) |
| ff_out = ff(t_clone) |
| # frame count and op count are incremented due to re-compilation |
| check_count_and_graph(2, 4, 2, expected_graph) |
| |
| try: |
| x = torch._to_functional_tensor(t_clone2) |
| torch._mirror_autograd_meta_to(t_clone2, x) |
| torch._enable_functionalization(reapply_views=False) |
| aot_f_out = f(x) |
| finally: |
| torch._disable_functionalization() |
| |
| # frame count and op count are incremented due to re-compilation |
| check_count_and_graph(3, 6, 3, expected_graph) |
| |
| def test_has_torch_function(self): |
| class MyTensor: |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| |
| if func is torch.max: |
| return torch.tensor(123) |
| return func(*args, **kwargs) |
| |
| class LocalSubclass(torch.Tensor): |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| return func(*args, **kwargs) |
| |
| def fn(x): |
| return torch.overrides.has_torch_function_unary( |
| x |
| ), torch.overrides.has_torch_function_variadic(x) |
| |
| for test_class in [MyTensor, LocalSubclass]: |
| x = test_class() |
| ref0 = fn(x) |
| ref1 = fn(4) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res0 = opt_fn(x) |
| res1 = opt_fn(4) |
| self.assertEqual(ref0, res0) |
| self.assertEqual(ref1, res1) |
| |
| def test_wrapper_subclass_guards_on_inner_tensor(self): |
        # Holds an inner tensor that has a distinct shape from the outer wrapper tensor.
        # Also adds additional guards on the inner tensor's sizes.
        # When the first input to an op has an inner shape[0] > 3, we insert an extra add node.
| class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor): |
| @staticmethod |
| def __new__(cls, inner): |
| # Double the outer-most dimension |
| outer_shape = (inner.shape[0] * 2,) + inner.shape[1:] |
| return torch.Tensor._make_wrapper_subclass( |
| # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. |
| # Calling the overload that has kwargs causes us to go down the first overload path, |
| # which will **always** specialize sizes. |
| # We should probably eventually fix this so that the first overload can just handle dynamic shapes. |
| cls, |
| outer_shape, |
| inner.stride(), |
| None, |
| None, |
| inner.dtype, |
| inner.layout, |
| inner.device, |
| False, |
| inner.requires_grad, |
| ) |
| |
| def __init__(self, inner): |
| self.inner_elem = inner |
| |
| def __tensor_flatten__(self): |
| return ["inner_elem"], None |
| |
| @staticmethod |
| def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride): |
| return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"]) |
| |
| def __repr__(self): |
| return f"DoubleSizeMayberAddGeThreeTensor({repr(self.inner_elem)})" |
| |
| @classmethod |
| def __torch_dispatch__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| |
| args_inner = torch.utils._pytree.tree_map_only( |
| DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args |
| ) |
| out_inner = func(*args_inner, **kwargs) |
| |
| # Add guards on the inner tensor's sizes |
| if args_inner[0].shape[0] > 3: |
| out_inner += 2 |
| |
| return DoubleSizeMaybeAddGeThreeTensor(out_inner) |
| |
| curr_var_to_val = None |
| curr_var_to_sources = None |
| guards = None |
| |
| def backend(gm, args): |
| context = torch._guards.TracingContext.get() |
| |
| # Grab info on sources and guards from the shapeenv |
| nonlocal curr_var_to_val |
| nonlocal curr_var_to_sources |
| nonlocal guards |
| |
| guards = [str(g.expr) for g in context.fake_mode.shape_env.guards] |
| curr_var_to_val = { |
| str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items() |
| } |
| curr_var_to_sources = { |
| str(k): v[0].name() |
| for k, v in context.fake_mode.shape_env.var_to_sources.items() |
| } |
| return gm |
| |
| @torch.compile(backend=backend) |
| def fn(x): |
| if x.shape[0] < 10: |
| return torch.mul(x, x) |
| else: |
| return torch.div(x, x) |
| |
| inp = torch.ones(4, 4) |
| |
| x = DoubleSizeMaybeAddGeThreeTensor(inp) |
| torch._dynamo.mark_dynamic(x, 0) |
| res = fn(x) |
| # During fakeifying, we end up allocating a separate symint |
| # for the outer and inner tensor (in this test, s0 is unused). |
| expected_var_to_val = { |
| "s0": 8, |
| "s1": 4, |
| } |
| expected_var_to_sources = { |
| "s0": "L['x'].size()[0]", |
| "s1": "L['x'].inner_elem.size()[0]", |
| } |
| self.assertEqual(curr_var_to_val, expected_var_to_val) |
| self.assertEqual(curr_var_to_sources, expected_var_to_sources) |
| self.assertExpectedInline( |
| "\n".join(guards), |
| """\ |
| Eq(2*s1, s0) |
| 2*s1 < 10 |
| s1 > 3""", |
| ) |
| |
| def test_wrapper_subclass_with_same_sized_inner_tensor(self): |
| # shouldn't recompile for different sizes when dynamic=True |
| sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) |
| sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7)) |
| self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True)) |
| |
| # should recompile for different data size when dynamic=False |
| sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) |
| sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6)) |
| self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) |
| |
| # avoid recompile using manual mark_dynamic() for different data size |
| sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) |
| # NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size |
| torch._dynamo.mark_dynamic(sub1, 0) |
| torch._dynamo.mark_dynamic(sub1, 1) |
| sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6)) |
| self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) |
| |
| def test_wrapper_subclass_with_differently_sized_inner_tensor(self): |
| # should recompile for different scale size when dynamic=False |
| sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3)) |
| sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5)) |
| self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) |
| |
| # still recompiles using manual mark_dynamic() on outer for different scale size |
| sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3)) |
| # NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size |
| torch._dynamo.mark_dynamic(sub1, 0) |
| torch._dynamo.mark_dynamic(sub1, 1) |
| sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5)) |
| self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) |
| |
| def test_torch_dispatch_subclass_guard_recompile(self): |
| x = torch.ones(2, 2) |
| x_two = TwoTensor(x.clone(), x.clone()) |
| |
| def fn(w): |
| return torch.add(w, 1.0) |
| |
| fn_opt = torch.compile(backend="eager")(fn) |
| |
| ref = fn(x_two) |
| res = fn_opt(x_two) |
| self.assertEqual(ref, res) |
| |
| # ensure no recompilation on same input type |
| with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): |
| fn_opt(TwoTensor(x + 1, x + 2)) |
| |
| # recompile! |
| ref = fn(x) |
| res = fn_opt(x) |
| self.assertEqual(ref, res) |
| |
| def test_torch_function_subclass_survives_into_aot_autograd(self): |
| # If you have a tensor subclass that relies on dispatch into the same op |
| # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(), |
| # the torch function-ness will survive into AOTAutograd. Today, NestedTensor |
| # actually relies on this behavior! Because that torch function logic |
        # runs during AOTAutograd, this test checks that no logic below relies
        # on torch function in a way that gets unexpectedly disabled after we
        # redispatch from the subclass's torch function.
| class SubTensor(torch.Tensor): |
| @staticmethod |
| def __new__(cls, t): |
| return torch.Tensor._make_wrapper_subclass( |
| cls, |
| t.shape, |
| t.stride(), |
| t.storage_offset(), |
| torch.contiguous_format, |
| t.dtype, |
| torch.strided, |
| t.device, |
| False, |
| t.requires_grad, |
| "sizes", |
| False, |
| False, |
| None, |
| ) |
| |
| def __init__(self, t): |
| super().__init__() |
| self._t = t |
| |
| def __tensor_flatten__(self): |
| return ["_t"], {} |
| |
| @staticmethod |
| def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): |
| t = inner_tensors["_t"] |
| return SubTensor(t) |
| |
| def __repr__(self): |
| return f"SubTensor({self._t})" |
| |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| |
| with torch._C.DisableTorchFunctionSubclass(): |
| return func(*args, **kwargs) |
| |
| @classmethod |
| def __torch_dispatch__(cls, func, types, args=(), kwargs=None): |
| kwargs = {} if kwargs is None else kwargs |
| new_args = pytree.tree_map_only(SubTensor, lambda s: s._t, args) |
| output = func(*new_args, **kwargs) |
| output = pytree.tree_map_only( |
| torch.Tensor, lambda t: SubTensor(t), output |
| ) |
| return output |
| |
| @torch.compile(dynamic=True) |
| def f(x): |
| return x.unflatten(-1, [2, 5]) |
| |
| s = SubTensor(torch.randn(3, 10)) |
| f(s) |
| |
| def test_recompile_with_symbool_inputs(self): |
| def f(pred: bool): |
| if pred: |
| return torch.ones([3, 4]) |
| else: |
| return torch.ones([4, 3]) |
| |
| def test_recompilation( |
| f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards |
| ): |
| torch._dynamo.reset() |
| shape_env = ShapeEnv() |
| backend = torch._dynamo.testing.EagerAndRecordGraphs() |
| cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) |
| f_cond = torch.compile(f, backend=cnt, fullgraph=True) |
| with torch._subclasses.fake_tensor.FakeTensorMode( |
| shape_env=shape_env |
| ) as fake_mode: |
| fake_inp = fake_mode.from_tensor( |
| x, |
| symbolic_context=StatelessSymbolicContext( |
| dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())] |
| ), |
| ) |
| for i, size in enumerate(sizes): |
| pred = fake_inp.size(0) == size |
| f_cond(pred) |
| actual = normalize_gm( |
| backend.graphs[exp_frame_count[i] - 1].print_readable( |
| print_output=False |
| ) |
| ) |
| actual_guard_str = [str(guard.expr) for guard in shape_env.guards] |
| self.assertExpectedInline(actual, exp_graphs[i]) |
| self.assertEqual(cnt.frame_count, exp_frame_count[i]) |
| self.assertEqual(actual_guard_str, exp_shape_env_guards[i]) |
| |
| true_graph = """\ |
| class GraphModule(torch.nn.Module): |
| def forward(self): |
| ones = torch.ones([3, 4]) |
| return (ones,) |
| """ |
| false_graph = """\ |
| class GraphModule(torch.nn.Module): |
| def forward(self): |
| ones = torch.ones([4, 3]) |
| return (ones,) |
| """ |
| test_recompilation( |
| f, |
| torch.randn([3, 4]), |
| [3, 3, 4, 5], |
| exp_graphs=[true_graph, true_graph, false_graph, false_graph], |
| exp_frame_count=[1, 1, 2, 2], |
| exp_shape_env_guards=[ |
| [], |
                # s0 is specialized and guarded in the outer shape_env when dynamo checks the guards
| ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"], |
| [ |
| "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", |
| "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)", |
| ], |
| [ |
| "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", |
| "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)", |
| "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", |
| ], |
| ], |
| ) |
| |
| test_recompilation( |
| f, |
| torch.randn([3, 4]), |
| [4, 5, 3, 3], |
| exp_graphs=[false_graph, false_graph, true_graph, true_graph], |
| exp_frame_count=[1, 1, 2, 2], |
| exp_shape_env_guards=[ |
| [], |
                # s0 is specialized and guarded in the outer shape_env when dynamo checks the guards
| ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"], |
| [ |
| "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", |
| "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", |
| ], |
| [ |
| "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", |
| "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", |
| "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", |
| ], |
| ], |
| ) |
| |
| def test_wrapper_subclass_dynamo_attribute_access_on_intermediate(self): |
| def f(x_subclass): |
            tmp_subclass = torch.add(x_subclass, 1)
| return torch.mul(tmp_subclass._scale, tmp_subclass._constant) |
| |
| x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2) |
| out_ref = f(x) |
| out_test = torch.compile(f, backend="aot_eager", fullgraph=True)(x) |
| self.assertEqual(out_ref, out_test) |
| |
| def test_support_bases(self): |
| import abc |
| |
| import torch.fx._symbolic_trace |
| |
| class Meta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta): |
| def __new__(cls, name, bases, dct): |
| x = super().__new__(cls, name, bases, dct) |
| x.attr = 100 |
| return x |
| |
| class Multistreamable(abc.ABC): # noqa: B024 |
| pass |
| |
| class Foo(Multistreamable, metaclass=Meta): |
| pass |
| |
| @torch.compile(backend="eager", fullgraph=True) |
| def f(x): |
| typ = type(Foo()) |
| typ.__bases__ |
| return typ.__bases__ |
| |
| self.assertEqual(f(torch.randn(1)), (Multistreamable,)) |
| |
| @parametrize("dynamic", [False, True]) |
| def test_subclass_views(self, dynamic): |
| def _get_views(t): |
| # Note that any closed-over SymInts will be symbolicized during fake-ification. |
| yield t.narrow(dim=-1, start=3, length=8) |
| yield t.split(5, -1) |
| yield t.split_with_sizes([9, 6], -1) |
| yield t.unsqueeze(-1).expand(4, 15, 10) |
| yield t.select(-1, 6) |
| yield t[2:3, 5:9] |
| |
| def f(x): |
| return x * 2 |
| |
| compiled_f = torch.compile( |
| f, backend="aot_eager", fullgraph=True, dynamic=dynamic |
| ) |
| |
| # Take a view of a subclass to pass as input. |
| t = TwoTensor(torch.randn(4, 15), torch.randn(4, 15)) |
| for view in _get_views(t): |
| out_ref = f(view) |
| out_test = compiled_f(view) |
| self.assertEqual(out_ref, out_test) |
| |
| |
| instantiate_parametrized_tests(SubclassTests) |
| |
| |
| class TestNestedTensor(torch._dynamo.test_case.TestCase): |
| def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True): |
| # Makes a jagged tensor with N constituent tensors with size |
| # as specified ((S0, S1, S2), D) |
| D = nested_size[1] |
| out = [] |
| for s in nested_size[0]: |
| out.append( |
| torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64) |
| ) |
| return jagged_from_list(out, offsets) |
| |
| def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True): |
        # Makes a non-contiguous jagged tensor from a dense values buffer of shape
        # (len(starts), max(starts + lengths), inner_dim), using the given starts and lengths
| max_dim = (starts + lengths).max() |
| values_tensor = torch.randn( |
| starts.shape[0], |
| max_dim.item(), |
| inner_dim, |
| requires_grad=requires_grad, |
| dtype=torch.float64, |
| ) |
| return jagged_from_tensor_and_lengths(values_tensor, starts, lengths) |
| |
| def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): |
| actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2) |
| self.assertEqual(actual_recompiles, expected_recompiles) |
| |
| def test_unary_does_not_recompile(self): |
| nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None) |
| nt2, _ = self._get_jagged_tensor(((3, 4, 5, 6), 4), None) |
| self._check_recompiles(lambda nt1: nt1.sin(), (nt1,), (nt2,), False) |
| |
| def test_binary_does_not_recompile(self): |
| def binary(nt1, nt2): |
| if nt1.shape == nt2.shape: |
| return nt1 + nt2 |
| else: |
| return nt1.sin() |
| |
| # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). |
| # This causes a recompile later on when it realizes the batch and last dim |
| # should not always be equal. To avoid that, we use (3, j0, 5) here. |
| nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) |
| nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets) |
| nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None) |
| nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets) |
| self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False) |
| |
| def test_binary_recompiles(self): |
| def binary(nt1, nt2): |
| if nt1.shape == nt2.shape: |
| return nt1 + nt2 |
| else: |
| return nt1.sin() |
| |
| # Binary recompiles because singleton ints no longer match |
| nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) |
| nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets) |
| nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) |
| self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True) |
| |
| # TODO: cannot parametrize this test class with device for some reason |
| def _test_autograd(self, backend): |
| a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64) |
| b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64) |
| c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64) |
| nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) |
| # TODO: Switch to public API when it exists |
| nt2, _ = jagged_from_list([a, b, c], nt.offsets()) |
| |
| def fn1(nt1, nt2): |
| return (nt1 + nt2).sin().cos() |
| |
| compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True) |
| out = compiled_f(nt, nt2) |
| out_buffer = out.values() |
| ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c)) |
| |
| out_ref = fn1(nt, nt2) |
| out_buffer_ref = out_ref.values() |
| ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c)) |
| |
| self.assertTrue(torch.allclose(ga, ga_ref)) |
| self.assertTrue(torch.allclose(gb, gb_ref)) |
| self.assertTrue(torch.allclose(gc, gc_ref)) |
| |
| def test_basic_autograd(self): |
| self._test_autograd("aot_eager") |
| |
| @requires_cuda |
| def test_basic_autograd_inductor(self): |
| self._test_autograd("inductor") |
| |
| def test_subclass_with_mutation_in_graph(self): |
        # In this graph, we have an in-graph mutation, i.e. a mutation that is normally
        # allowed to remain in the graph, but is not allowed to remain if the graph
        # handles subclasses at all.
| # Whether the mutation is allowed or not allowed in the graph alters the number |
| # of outputs from the forward graph. Previously, a bug in this handling meant |
| # that sometimes the expected number and actual number of outputs from the |
| # joint graph did not match, causing assertion failures. |
| def fn(x, y): |
| z = x.sin() |
| y.sin_() |
| return z.cos(), y.cos() |
| |
| fn_c = torch.compile(fn, backend="inductor") |
| |
| values = [torch.rand((i, 8), requires_grad=True) for i in range(1, 6)] |
| values_copy = [x.detach().clone().requires_grad_(True) for x in values] |
| |
| nt, offsets = jagged_from_list(values, None) |
| nt_copy, offsets = jagged_from_list(values_copy, offsets) |
| y = torch.rand((4, 8)) |
| y_copy = y.clone() |
| |
| ret = fn_c(nt, y)[0] |
| ref = fn(nt_copy, y_copy)[0] |
| |
| self.assertEqual(ret.values(), ref.values()) |
| |
| ret.values().sum().backward() |
| ref.values().sum().backward() |
| for ref_v, res_v in zip(values_copy, values): |
| self.assertEqual(ref_v.grad, res_v.grad) |
| |
| def test_unbind(self): |
| # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). |
| # This causes a recompile later on when it realizes the batch and last dim |
| # should not always be equal. To avoid that, we use (3, j0, 5) here. |
| nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) |
| nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None) |
| nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None) |
| |
| def fn(x): |
| return x.unbind() |
| |
| compiled_f = torch.compile(fn, fullgraph=True, backend="eager", dynamic=True) |
| out = compiled_f(nt) |
| |
| out_ref = fn(nt) |
| |
| # correctness |
| self.assertEqual(len(out), len(out_ref)) |
| for x, x_ref in zip(out, out_ref): |
| self.assertTrue(torch.allclose(x, x_ref)) |
| |
        # We specialize on the length of offsets, i.e. (1) we recompile if the
        # length of the offsets is different, and (2) we don't recompile if the
        # length of the offsets is the same, even if the sizes of the constituent
        # tensors are different.
| self._check_recompiles(fn, (nt,), (nt2,), False) |
| self._check_recompiles(fn, (nt,), (nt3,), True) |
| |
| def _get_views(self): |
| # Test all cases with both an NT base and a dense base |
| # Subclass -> Subclass |
| # Dense -> Subclass |
| for base_is_nt in [False, True]: |
| # There are three cases to consider here based on the logic in |
| # meta_utils.py |
| # |
| # (1) basic case: |
            # view is not a leaf and has the same requires_grad as its base
| x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True) |
| x = x.clone() if base_is_nt else x |
| self.assertEqual(x.is_leaf, False) |
| yield x.unsqueeze(-1) |
| |
| # (2) leaf view case: |
| # the view has to be a leaf (w/ requires_grad True or requires_grad False) |
| # base w/ requires_grad True or requires_grad False |
| for requires_grad_1, requires_grad_2 in itertools.product( |
| [True, False], repeat=2 |
| ): |
| x, _ = self._get_jagged_tensor( |
| ((2, 3, 4), 3), None, requires_grad=requires_grad_1 |
| ) |
| x = x.clone() if base_is_nt else x |
| with torch.no_grad(): |
| x_view = x.unsqueeze(-1) |
| # The issue is this doesn't quite work |
| x_view.requires_grad_(requires_grad_2) |
| yield x_view |
| |
| # (3) obscure case: |
| # view is not a leaf (implies requires_grad True) |
| # base w/ requires_grad False) |
| x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False) |
| x = x.clone() if base_is_nt else x |
| # intermediate leaf view |
| with torch.no_grad(): |
| x_view = x.unsqueeze(-1) |
| x_view.requires_grad_(True) |
| x_view_view = x_view.unsqueeze(-1) |
| yield x_view_view |
| |
| # Subclass -> Dense |
| x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone() |
| yield x.values() |
| |
| # Dense -> Subclass -> Dense -> Subclass |
| values = torch.randn(10, 5) |
| offsets = torch.tensor([0, 3, 6, 10]) |
| offsets2 = offsets.clone().detach() |
| yield nested_view_from_values_offsets( |
| nested_view_from_values_offsets(values, offsets).values(), offsets |
| ) |
| |
| def _input_view_test(self, nt_view): |
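        # Compile a simple op on the given view input, check correctness against eager, and
        # verify the guards added during fake-ification.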
| def fn(x): |
| return x.sin() |
| |
| out_ref = fn(nt_view) |
| torch._dynamo.reset() |
| compile_fn = torch.compile( |
| fn, fullgraph=True, backend="aot_eager", dynamic=True |
| ) |
| out = compile_fn(nt_view) |
| |
| # Check metadata and values are correct |
| self.assertTrue(out.size() == out_ref.size()) |
| self.assertTrue(out.stride() == out_ref.stride()) |
| if out.is_nested: |
| self.assertTrue(torch.allclose(out.values(), out_ref.values())) |
| else: |
| self.assertTrue(torch.allclose(out, out_ref)) |
| |
| # Check that no upper/lower bound guards are incurred |
| def backend(gm, args): |
| context = torch._guards.TracingContext.get() |
| guards = [str(g.expr) for g in context.fake_mode.shape_env.guards] |
| |
| # varies based on the type of view |
| guard_str = "\n".join(guards) |
| if isinstance(nt_view._base, NestedTensor): |
| self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""") |
| else: |
| self.assertExpectedInline(guard_str, """""") |
| return gm |
| |
| torch._dynamo.reset() |
| compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True) |
| out = compile_fn(nt_view) |
| |
| def test_inputs_to_compiled_fn_are_views(self): |
| for nt_view in self._get_views(): |
| self._input_view_test(nt_view) |
| |
| # NJT1 -> Dense -> NJT2 -> Dense view |
| # During view replay, the Dense -> NJT2 part will construct an intermediate, |
| # symbolically-sized NJT that is immediately deconstructed to return the final dense |
| # view. To construct this intermediate properly, we need the associated nested int |
| # to be symbolic. This view is expected to fail compilation until symbolic nested ints |
| # are cached onto fake offsets to solve this problem. |
| @unittest.expectedFailure |
| def test_subclass_dense_subclass_dense_view(self): |
| x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone() |
| offsets2 = x.offsets().clone().detach() |
| nt_view = nested_view_from_values_offsets(x.values(), offsets2).values() |
| self._input_view_test(nt_view) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |