Switched over to using FakeTensor in ProxyTensor (#79634)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79634
Approved by: https://github.com/albanD
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 40bec21..76bd4ae 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -7,6 +7,7 @@
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_methods_invocations import DecorateInfo
 from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
+from torch._subclasses.fake_tensor import DynamicOutputShapeException
 
 from torch.testing._internal.common_device_type import ops
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -59,7 +60,7 @@
 
 
 class TestProxyTensor(TestCase):
-    def test_make_fx(self, device):
+    def test_make_fx_simple(self, device):
         def f(x):
             return torch.sin(x)
         inp = torch.randn(3)
@@ -110,6 +111,17 @@
             assert inp.grad is None
             torch.testing.assert_close(traced_graph_out, f(inp))
 
+    def test_inplace_metadata(self):
+        def f(x):
+            x = x.clone()
+            x.unsqueeze_(-1)  # in-place op that changes x's shape metadata
+            assert x.shape[-1] == 1  # must hold while tracing as well as eagerly
+            return x
+
+        sample = torch.randn(5)
+        traced = make_fx(f)(sample)
+        self.assertEqual(traced(sample), f(sample))
+
     def test_mode_tracing_factory_function(self):
         def f(x):
             return x + torch.randn(x.shape)
@@ -136,6 +148,7 @@
         )
 
 make_fx_failures = {
+    # Failure cause unknown
     xfail('allclose'),
     xfail('equal'),
     xfail('linalg.eigvals'),
@@ -150,6 +163,7 @@
     skip('nn.functional.max_unpool2d', '', device_type='cpu'),
     skip('nn.functional.max_unpool3d', '', device_type='cpu'),
     skip('linalg.lstsq'),  # flaky, probably just a precision issue
+
     # data-dependent control flow
     xfail('cov'),
     xfail('istft'),
@@ -182,35 +196,67 @@
     # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
     xfail('sparse.sampled_addmm'),
 
-    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
+    # Failure cause unknown (previously attributed to an uncaptured sparse tensor)
     xfail('nn.functional.ctc_loss'),
+    # Sparse tensors are not supported by FakeTensor yet
+    xfail('to_sparse'),
+    # segfaults
+    skip('block_diag'),
+}
+
+fake_tensor_failures = {  # extra xfails/skips applied only to test_make_fx_fake_exhaustive
+    # Needs complex-valued tensor support
+    xfail('polar'),
+    xfail('complex'),
+    xfail('linalg.eig'),
+    # FakeTensor fallback doesn't work
+    xfail('linalg.matrix_power'),
+    xfail('segment_reduce', 'lengths'),
+    xfail('multinomial'),
+    xfail('mvlgamma', 'mvlgamma_p_1'),
+    xfail('mvlgamma', 'mvlgamma_p_3'),
+    xfail('mvlgamma', 'mvlgamma_p_5'),
+    xfail('cholesky'),
+    xfail('cholesky_inverse'),
+    # ASAN failures due to divide by 0
+    skip('nn.functional.nll_loss'),
 }
 
 
+def _test_make_fx_helper(self, device, dtype, op, use_fake):
+    """Trace each OpInfo sample of ``op`` with make_fx and compare against eager."""
+    def f(args, kwargs):
+        return op.op(*args, **kwargs)
+    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
+    for sample_input in sample_inputs_itr:
+        args = [sample_input.input] + list(sample_input.args)
+        kwargs = sample_input.kwargs
+
+        try:
+            new_f = make_fx(f, use_fake=use_fake)(args, kwargs)
+        except DynamicOutputShapeException:
+            self.skipTest("Dynamic output shape operation in trace")
+        # Re-randomize float inputs so the graph is checked on values it was not traced with.
+        for arg in args:
+            if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
+                arg.uniform_(0, 1)
+        try:
+            old_out = f(args, kwargs)
+        except Exception:  # eager op rejects this sample; nothing to compare against
+            continue
+        new_out = wrapper_set_seed(new_f, args, kwargs)
+        self.assertEqual(new_out, old_out)
+
 class TestProxyTensorOpInfo(TestCase):
     @ops(op_db, allowed_dtypes=(torch.float,))
-    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures
-             )
+    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
     def test_make_fx_exhaustive(self, device, dtype, op):
+        _test_make_fx_helper(self, device, dtype, op, use_fake=False)
 
-        def f(args, kwargs):
-            return op.op(*args, **kwargs)
-        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
-        new_f = None
-        for sample_input in sample_inputs_itr:
-            args = [sample_input.input] + list(sample_input.args)
-            kwargs = sample_input.kwargs
-
-            new_f = make_fx(f, trace_factory_functions=True)(args, kwargs)
-            for arg in args:
-                if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
-                    arg.uniform_(0, 1)
-            try:
-                old_out = f(args, kwargs)
-            except Exception:
-                continue
-            new_out = wrapper_set_seed(new_f, args, kwargs)
-            self.assertEqual(new_out, old_out)
+    @ops(op_db, allowed_dtypes=(torch.float,))
+    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures | fake_tensor_failures)
+    def test_make_fx_fake_exhaustive(self, device, dtype, op):
+        _test_make_fx_helper(self, device, dtype, op, use_fake=True)