Switched over to using FakeTensor in ProxyTensor (#79634)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79634
Approved by: https://github.com/albanD
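
For context, a minimal usage sketch of the new use_fake flag on make_fx (the example function f and the input below are illustrative, not part of this PR):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return torch.sin(x) + x

    # With use_fake=True the inputs are wrapped in FakeTensors before tracing,
    # so building the graph only propagates metadata (shape/dtype/device)
    # rather than running the real kernels.
    gm = make_fx(f, use_fake=True)(torch.randn(3))
    print(gm.graph)

The default stays use_fake=False, so existing callers keep tracing on real tensors.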
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 40bec21..76bd4ae 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -7,6 +7,7 @@
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_methods_invocations import DecorateInfo
 from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
+from torch._subclasses.fake_tensor import DynamicOutputShapeException
 
 from torch.testing._internal.common_device_type import ops
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -59,7 +60,7 @@
 
 
 class TestProxyTensor(TestCase):
-    def test_make_fx(self, device):
+    def test_make_fx_simple(self, device):
         def f(x):
             return torch.sin(x)
         inp = torch.randn(3)
@@ -110,6 +111,17 @@
             assert inp.grad is None
             torch.testing.assert_close(traced_graph_out, f(inp))
 
+    def test_inplace_metadata(self):
+        def f(x):
+            x = x.clone()
+            x.unsqueeze_(-1)
+            assert x.shape[-1] == 1
+            return x
+
+        inps = [torch.randn(5)]
+        fx_f = make_fx(f)(*inps)
+        self.assertEqual(fx_f(*inps), f(*inps))
+
     def test_mode_tracing_factory_function(self):
         def f(x):
             return x + torch.randn(x.shape)
@@ -136,6 +148,7 @@
         )
 
 make_fx_failures = {
+    # unknown
     xfail('allclose'),
     xfail('equal'),
     xfail('linalg.eigvals'),
@@ -150,6 +163,7 @@
     skip('nn.functional.max_unpool2d', '', device_type='cpu'),
     skip('nn.functional.max_unpool3d', '', device_type='cpu'),
     skip('linalg.lstsq'),  # flaky, probably just a precision issue
+
     # data-dependent control flow
     xfail('cov'),
     xfail('istft'),
@@ -182,35 +196,67 @@
     # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
     xfail('sparse.sampled_addmm'),
 
-    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
+    # unknown
     xfail('nn.functional.ctc_loss'),
+    # Sparse tensors are not supported with fake tensors for now
+    xfail('to_sparse'),
+    # segfaults
+    skip('block_diag'),
+}
+
+fake_tensor_failures = {
+    # Needs complex-value support
+    xfail('polar'),
+    xfail('complex'),
+    xfail('linalg.eig'),
+    # FakeTensor fallback doesn't work
+    xfail('linalg.matrix_power'),
+    xfail('segment_reduce', 'lengths'),
+    xfail('multinomial'),
+    xfail('mvlgamma', 'mvlgamma_p_1'),
+    xfail('mvlgamma', 'mvlgamma_p_3'),
+    xfail('mvlgamma', 'mvlgamma_p_5'),
+    xfail('cholesky'),
+    xfail('cholesky_inverse'),
+    # ASAN failures due to divide by 0
+    skip('nn.functional.nll_loss'),
 }
 
 
+def _test_make_fx_helper(self, device, dtype, op, use_fake):
+    def f(args, kwargs):
+        return op.op(*args, **kwargs)
+    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
+    new_f = None
+    for sample_input in sample_inputs_itr:
+        args = [sample_input.input] + list(sample_input.args)
+        kwargs = sample_input.kwargs
+
+        try:
+            new_f = make_fx(f, use_fake=use_fake)(args, kwargs)
+        except DynamicOutputShapeException:
+            self.skipTest("Dynamic output shape operation in trace")
+
+        for arg in args:
+            if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
+                arg.uniform_(0, 1)
+        try:
+            old_out = f(args, kwargs)
+        except Exception:
+            continue
+        new_out = wrapper_set_seed(new_f, args, kwargs)
+        self.assertEqual(new_out, old_out)
+
 class TestProxyTensorOpInfo(TestCase):
     @ops(op_db, allowed_dtypes=(torch.float,))
-    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures
-             )
+    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
     def test_make_fx_exhaustive(self, device, dtype, op):
+        _test_make_fx_helper(self, device, dtype, op, False)
 
-        def f(args, kwargs):
-            return op.op(*args, **kwargs)
-        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
-        new_f = None
-        for sample_input in sample_inputs_itr:
-            args = [sample_input.input] + list(sample_input.args)
-            kwargs = sample_input.kwargs
-
-            new_f = make_fx(f, trace_factory_functions=True)(args, kwargs)
-            for arg in args:
-                if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
-                    arg.uniform_(0, 1)
-            try:
-                old_out = f(args, kwargs)
-            except Exception:
-                continue
-            new_out = wrapper_set_seed(new_f, args, kwargs)
-            self.assertEqual(new_out, old_out)
+    @ops(op_db, allowed_dtypes=(torch.float,))
+    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures))
+    def test_make_fx_fake_exhaustive(self, device, dtype, op):
+        _test_make_fx_helper(self, device, dtype, op, True)
 
 
 
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 441fe4a..82559ec 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -9,12 +9,13 @@
 from torch._C import _disabled_torch_function_impl
 import torch.utils._pytree as pytree
 from torch.fx import Tracer, GraphModule
+from torch._subclasses.fake_tensor import FakeTensorMode
 import torch.fx as fx
 from torch.utils._mode_utils import no_dispatch
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 
-from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode
+from torch.utils._python_dispatch import TorchDispatchMode
 
 __all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict"]
 aten = torch.ops.aten
@@ -22,7 +23,6 @@
 CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
 
 
-
 @contextmanager
 def decompose(decomposition_table):
     global CURRENT_DECOMPOSITION_TABLE
@@ -39,9 +39,9 @@
     global IS_STRICT
     IS_STRICT = val
 
-def wrap_output(real_out, proxy_out):
+def wrap_output(inner_res, proxy_res):
     def wrap_with_proxy(e, proxy):
-        if type(e) == torch.Tensor:
+        if isinstance(e, torch.Tensor):
             with no_dispatch():
                 return ProxyTensor(e, proxy)
         else:
@@ -50,17 +50,19 @@
     # Unfortunately, tree_map cannot directly be used here. As the resulting
     # object may be a proxy that represents a tuple, we may need to
     # explicitly unwrap the proxy by simulating the flattening operations.
-    if isinstance(real_out, tuple):
-        return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
-    elif isinstance(real_out, list):
-        return list([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)])
-    elif isinstance(real_out, torch.Tensor):
-        return wrap_with_proxy(real_out, proxy_out)
+    if isinstance(inner_res, tuple):
+        return tuple(wrap_with_proxy(e, proxy_res[idx]) for idx, e in enumerate(inner_res))
+    elif isinstance(inner_res, list):
+        return list([wrap_with_proxy(e, proxy_res[idx]) for idx, e in enumerate(inner_res)])
+    elif isinstance(inner_res, torch.Tensor):
+        return wrap_with_proxy(inner_res, proxy_res)
     else:
-        return real_out
+        return inner_res
 
 
 def proxy_call(func_overload, args, kwargs=None):
+    if kwargs is None:
+        kwargs = {}
     func = func_overload.overloadpacket
     if func_overload in CURRENT_DECOMPOSITION_TABLE:
         return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
@@ -72,48 +74,61 @@
     def unwrap_proxy(e):
         return e.proxy if isinstance(e, ProxyTensor) else e
 
+    def unwrap_elem(e):
+        if isinstance(e, ProxyTensor):
+            return e.elem
+        return e
+
     proxy_args = pytree.tree_map(unwrap_proxy, args)
     proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
 
-    proxy_out = func_overload(*proxy_args, **proxy_kwargs)
+    proxy_res = func_overload(*proxy_args, **proxy_kwargs)
 
     # Kind of a hacky way to test if an op is in-place or not
     if func.__name__[-1] == "_" and func.__name__[0] != "_":
-        args[0].proxy = proxy_out
-        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
+        args[0].proxy = proxy_res
+        proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
 
-    with no_dispatch():
-        real_out = func_overload(*args, **kwargs)
+    inner_res = func_overload(*pytree.tree_map(unwrap_elem, args), **pytree.tree_map(unwrap_elem, kwargs))
+    # Needed to sync up the wrapper's metadata for in-place view ops (e.g. unsqueeze_) that change it
+    if torch.Tag.inplace_view in func_overload.tags:  # type: ignore[attr-defined]
+        with no_dispatch():
+            func_overload(*args, **kwargs)
 
-    return wrap_output(real_out, proxy_out)
+    # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
+    # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
+    return wrap_output(inner_res, proxy_res)
+
 
 class ProxyTensor(torch.Tensor):
     proxy: fx.Proxy
+    elem: torch.Tensor
+
 
     @staticmethod
     def __new__(cls, elem, proxy, *, requires_grad=None):
-        # Hack to deal with super().__new__ not working for sparse tensors
-        if elem.is_sparse or requires_grad is not None:
-            if requires_grad is None:
-                requires_grad = False
-            r = torch.Tensor._make_subclass(cls, elem, requires_grad)
-        else:
-            r = super().__new__(cls, elem)  # type: ignore[call-arg]
+        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            cls,
+            elem.shape, dtype=elem.dtype, layout=elem.layout, device=elem.device,
+            requires_grad=requires_grad if requires_grad is not None else False, strides=elem.stride(),
+            storage_offset=elem.storage_offset()
+        )
+        return r
 
+    def __init__(self, elem, proxy, *, requires_grad=None):
         if elem.is_sparse:
             proxy.node.meta['tensor_meta'] = {}
         else:
-            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
-        r.proxy = proxy  # type: ignore[attr-defined]
-
-        return r
+            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(self)
+        self.elem = elem
+        self.proxy = proxy
 
     def __deepcopy__(self, memo):
         return self.clone()
 
     def __repr__(self):
         with no_dispatch():
-            return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})"  # type: ignore[arg-type]
+            return f"ProxyTensor({self.elem}, proxy={self.proxy})"
 
     __torch_function__ = _disabled_torch_function_impl
 
@@ -156,15 +171,10 @@
 
 def dispatch_trace(
         root: Union[torch.nn.Module, Callable],
+        tracer: Tracer,
         concrete_args: Optional[Tuple[Any, ...]] = None,
-        trace_factory_functions: bool = False,
 ) -> GraphModule:
-    tracer = PythonKeyTracer()
-    if trace_factory_functions:
-        with push_torch_dispatch_mode(functools.partial(ProxyTorchDispatchMode, tracer)):
-            graph = tracer.trace(root, concrete_args)
-    else:
-        graph = tracer.trace(root, concrete_args)
+    graph = tracer.trace(root, concrete_args)
     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
     return GraphModule(tracer.root, graph, name)
 
@@ -179,9 +189,11 @@
         for idx, arg in enumerate(flat_args):
             if isinstance(flat_inps[idx], torch.Tensor):
                 with no_dispatch():
-                    flat_args[idx] = ProxyTensor(flat_inps[idx], arg, requires_grad=(
-                        flat_inps[idx].is_leaf and flat_inps[idx].requires_grad
-                    ))
+                    flat_args[idx] = ProxyTensor(
+                        flat_inps[idx],
+                        arg,
+                        requires_grad=(flat_inps[idx].is_leaf and flat_inps[idx].requires_grad)
+                    )
             else:
                 flat_args[idx] = flat_inps[idx]
 
@@ -205,25 +217,36 @@
         if any(tuple(isinstance(arg, ProxyTensor) for arg in pytree.tree_flatten(args)[0])):
             return proxy_call(func_overload, args, kwargs)
         else:
-            proxy_out = self.tracer.create_proxy('call_function', func, args, kwargs,
+            proxy_res = self.tracer.create_proxy('call_function', func, args, kwargs,
                                                  name=self.tracer.graph._target_to_str(func.__name__))
 
-            with no_dispatch():
-                real_out = func_overload(*args, **kwargs)
+            inner_res = func_overload(*args, **kwargs)
 
-            return wrap_output(real_out, proxy_out)
+            return wrap_output(inner_res, proxy_res)
 
 
-def make_fx(f, decomposition_table=None, trace_factory_functions=True):
+def make_fx(f, decomposition_table=None, trace_factory_functions=True, use_fake=False):
     if decomposition_table is None:
         decomposition_table = {}
 
     @functools.wraps(f)
     def wrapped(*args):
-        phs = pytree.tree_map(lambda x: fx.PH, args)  # type: ignore[attr-defined]
-        with decompose(decomposition_table):
-            t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs),
-                               trace_factory_functions=trace_factory_functions)
+        phs = pytree.tree_map(lambda _: fx.PH, args)  # type: ignore[attr-defined]
+        fx_tracer = PythonKeyTracer()
+        fake_tensor_mode = FakeTensorMode() if use_fake else nullcontext()
+        proxy_mode = ProxyTorchDispatchMode(fx_tracer) if trace_factory_functions else nullcontext()
+
+        def wrap_fake(x):
+            if isinstance(x, torch.Tensor):
+                return fake_tensor_mode.from_tensor(x)  # type: ignore[attr-defined]
+
+            return x
+
+        if use_fake:  # type: ignore[attr-defined]
+            args = pytree.tree_map(wrap_fake, args)
+
+        with decompose(decomposition_table), fake_tensor_mode, proxy_mode:  # type: ignore[attr-defined]
+            t = dispatch_trace(wrap_key(f, args), tracer=fx_tracer, concrete_args=tuple(phs))
         return t
 
     return wrapped
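
One behavioral note worth calling out: when tracing with fake tensors, ops whose output shape depends on the input data raise DynamicOutputShapeException, which the new _test_make_fx_helper turns into a test skip. A small sketch of that failure mode (torch.nonzero is used here only as an illustrative data-dependent op):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._subclasses.fake_tensor import DynamicOutputShapeException

    def g(x):
        # The output shape depends on the values in x, which a fake tensor cannot know.
        return torch.nonzero(x)

    try:
        make_fx(g, use_fake=True)(torch.randn(4))
    except DynamicOutputShapeException:
        # The exhaustive OpInfo test skips such samples via self.skipTest(...).
        print("dynamic output shape; cannot trace with fake tensors")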