Handle redispatch correctly with tensor subclasses in ProxyTensor mode (#83122)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83122
Approved by: https://github.com/ezyang
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index b2e125d..f51248f 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -13,6 +13,7 @@
 
 from torch._decomp import decomposition_table
 from torch.testing._internal.common_device_type import ops
+from torch._C import _disabled_torch_function_impl
 from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
 from torch.utils._pytree import tree_map
 from torch import nn
@@ -132,6 +133,42 @@
     else:
         return torch.rand_like(x)
 
+"""
+Delays a cos on the wrapped tensor until the UnwrapTensor is used in an op. Simulates how a CommTensor is used.
+"""
+class UnwrapTensor(torch.Tensor):  # wrapper subclass; applies cos to the wrapped tensor on use
+    @staticmethod
+    def __new__(cls, tensor: torch.Tensor):
+        r = torch.Tensor._make_wrapper_subclass(
+            cls,
+            tensor.size(),
+            dtype=tensor.dtype,
+            device=tensor.device,
+            layout=tensor.layout,
+            requires_grad=tensor.requires_grad,
+        )
+        r._tensor = tensor  # keep the wrapped tensor; __torch_dispatch__ reads it back
+        return r
+
+    def __repr__(self):
+        # Show the wrapped tensor so instances are easy to identify when debugging.
+        return f"UnwrapTensor({self._tensor})"
+
+    __torch_function__ = _disabled_torch_function_impl  # disable __torch_function__ for this subclass
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        def unwrap(e):  # swap each UnwrapTensor for cos() of its wrapped tensor
+            ret = e
+            if isinstance(e, UnwrapTensor):
+                ret = e._tensor.cos()  # the delayed cos runs here, at use time
+
+            return ret
+
+        args = tree_map(unwrap, args)
+        kwargs = tree_map(unwrap, kwargs or {})  # kwargs defaults to None; ** needs a mapping
+        return func(*args, **kwargs)
+
 class TestGenericProxyTensor(TestCase):
     # WARNING: if any of your inputs are index tensors, DO NOT use this
     # function
@@ -467,6 +504,15 @@
             torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
         )
 
+    def test_trace_subclasses(self):
+        def f(x):
+            x = UnwrapTensor(x)  # wrap inside the traced function
+            y = x * 2  # the multiply dispatches through UnwrapTensor.__torch_dispatch__
+            return y
+
+        inp = [torch.randn(5)]
+        self._test(f, inp)  # pass the prepared input instead of building a second tensor
+
 
 class TestGenericProxyTensorReal(TestGenericProxyTensor):
     tracing_mode = "real"
@@ -501,6 +547,7 @@
     "test_make_fx_model_fwd_bwd",
     "test_proxy_tensor",
     "test_resnet18_backward_trace",
+    "test_trace_subclasses",
 ])
 class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
     tracing_mode = "symbolic"