Handle redispatch correctly with tensor subclasses in ProxyTensor mode (#83122)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83122
Approved by: https://github.com/ezyang
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index b2e125d..f51248f 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -13,6 +13,7 @@
 
 from torch._decomp import decomposition_table
 from torch.testing._internal.common_device_type import ops
+from torch._C import _disabled_torch_function_impl
 from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
 from torch.utils._pytree import tree_map
 from torch import nn
@@ -132,6 +133,42 @@
     else:
         return torch.rand_like(x)
 
+class UnwrapTensor(torch.Tensor):
+    """
+    Wrapper subclass that defers running cos() on the wrapped tensor until the
+    wrapper is passed to an op; the cos() is recomputed on every such use.
+    Simulates how a CommTensor is used.
+    """
+    @staticmethod
+    def __new__(cls, tensor: torch.Tensor):
+        r = torch.Tensor._make_wrapper_subclass(
+            cls,
+            tensor.size(),
+            dtype=tensor.dtype,
+            device=tensor.device,
+            layout=tensor.layout,
+            requires_grad=tensor.requires_grad,
+        )
+        r._tensor = tensor
+        return r
+
+    def __repr__(self):
+        return f"UnwrapTensor({self._tensor})"
+
+    __torch_function__ = _disabled_torch_function_impl
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        def unwrap(e):
+            # Materialize the deferred cos() only once the wrapper reaches an op.
+            if isinstance(e, UnwrapTensor):
+                return e._tensor.cos()
+            return e
+
+        args = tree_map(unwrap, args)
+        kwargs = tree_map(unwrap, kwargs or {})
+        return func(*args, **kwargs)
+
 class TestGenericProxyTensor(TestCase):
     # WARNING: if any of your inputs are index tensors, DO NOT use this
     # function
@@ -467,6 +504,15 @@
             torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
         )
 
+    def test_trace_subclasses(self):
+        def f(x):
+            x = UnwrapTensor(x)
+            y = x * 2
+            return y
+
+        inp = [torch.randn(5)]
+        self._test(f, inp)
+
 
 class TestGenericProxyTensorReal(TestGenericProxyTensor):
     tracing_mode = "real"
@@ -501,6 +547,7 @@
     "test_make_fx_model_fwd_bwd",
     "test_proxy_tensor",
     "test_resnet18_backward_trace",
+    "test_trace_subclasses",
 ])
 class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
     tracing_mode = "symbolic"