Delete ProxyTensor wrapper subclass (#83330)

I was working on https://github.com/pytorch/torchdynamo/issues/80 and my
working hypothesis for what was causing the error was that proxy tensor
was not advertising the correct dispatch keys, causing AMP to behave
differently under tracing.  I could have fixed this directly by
replicating fake tensor's dispatch-key fix for proxy tensor as well,
but I was like, "Why must I repeat myself?"

This PR is the result.  It completely deletes the ProxyTensor wrapper
subclass, so that when we are tracing, the tensors flowing through the
program are the *original* real or fake tensors, depending on what the
user requested in the top-level API.  There is no more wrapping.  To
store the Proxy objects necessary for actually doing tracing, I record
them in a map keyed by the tensors themselves.  (Note: at the moment I
never clean up old entries from the map; this is easily fixed by using
a weak map.)
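
For concreteness, the bookkeeping looks roughly like the sketch below
(purely illustrative: `set_proxy`/`get_proxy` and the `id()`-keyed dict are
hypothetical names, not the helpers this PR actually adds):

```python
# Hypothetical sketch only; not the PR's actual bookkeeping.
import torch
from torch.fx import Proxy

_proxy_map = {}  # id(tensor) -> Proxy; never pruned (a weak map would fix that)

def set_proxy(tensor: torch.Tensor, proxy: Proxy) -> None:
    # Associate an FX Proxy with the (unwrapped) real or fake tensor.
    _proxy_map[id(tensor)] = proxy

def get_proxy(tensor: torch.Tensor):
    # Look up the Proxy recorded for this tensor, if any.
    return _proxy_map.get(id(tensor))
```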

Benefits of doing this:

* No more tiptoeing around no_dispatch() creation of new ProxyTensors;
  we never create new tensors (except when we call the underlying func),
  so you don't have to worry about accidentally tracing them.

* No more syncing up metadata from in-place operators.  In particular,
  https://github.com/pytorch/pytorch/issues/81526 is mooted.

* This fixes https://github.com/pytorch/torchdynamo/issues/519, as we no longer need to teach proxy tensor to support sparse tensors.

* No more schlepping symbolic integers from the inner fake tensor to the
  outer proxy tensor.  If you can make a fake tensor with symbolic ints,
  you're done; there is nothing else to do.

To avoid having to rewrite all of the guts, when I get to the actual
proxy tensor handler, I first "fetch" the stored Proxy data from the
map via a tree_map, and then operate on the resulting data as before.
A more optimized implementation is possible.
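
Roughly, that fetch step looks like the sketch below (again hedged and
illustrative: `fetch_proxy`, `get_proxy`, and `set_proxy` are the
hypothetical helpers from the sketch above, and the real handler also has
to deal with multiple outputs, non-tensor returns, and so on):

```python
# Hedged sketch of the dispatch handler; not the PR's actual proxy_call.
import torch
from torch.utils._pytree import tree_map

def proxy_call(func, args, kwargs=None):
    kwargs = kwargs or {}

    # 1. Swap every tensor argument for its recorded Proxy; leave other
    #    leaves (ints, None, ...) untouched.
    def fetch_proxy(x):
        if isinstance(x, torch.Tensor):
            proxy = get_proxy(x)
            if proxy is not None:
                return proxy
        return x

    proxy_args = tree_map(fetch_proxy, args)
    proxy_kwargs = tree_map(fetch_proxy, kwargs)

    # 2. Record the op into the FX graph by invoking it on the proxies,
    #    then run it for real on the original real/fake tensors.
    proxy_out = func(*proxy_args, **proxy_kwargs)
    out = func(*args, **kwargs)

    # 3. Remember the output's proxy so downstream ops can fetch it.
    #    (A single tensor output is assumed for brevity.)
    if isinstance(out, torch.Tensor):
        set_proxy(out, proxy_out)
    return out
```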

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83330
Approved by: https://github.com/Chillee
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 82997ca..64b9d94 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -175,7 +175,9 @@
     def _test(self, f, inps):
         fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
         new_inps = tree_map(_create_new_input, inps)
-        self.assertEqual(fx_f(*new_inps), f(*new_inps))
+        r1 = fx_f(*new_inps)
+        r2 = f(*new_inps)
+        self.assertEqual(r1, r2)
 
     def test_make_fx_simple(self):
         def f(x):
@@ -284,11 +286,10 @@
             self.assertTrue(is_any_sigmoid(gm))
             return torch.digamma(x)
 
-        with self.assertRaisesRegex(AssertionError, "ProxyTensor is wrapped with another Tensor subclass"):
-            traced = make_fx(f2_logging_tensor)(torch.randn(3))
-            self.assertFalse(is_any_sum(traced))
-            self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
-            self.assertTrue(is_any_digamma(traced))
+        traced = make_fx(f2_logging_tensor)(torch.randn(3))
+        self.assertFalse(is_any_sum(traced))
+        self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
+        self.assertTrue(is_any_digamma(traced))
 
     def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
         def f(x):
@@ -514,6 +515,8 @@
         model = Foo()
 
         def f(args, params, buffers):
+            for p in params.values():
+                p.grad = None
             if not isinstance(args, Iterable):
                 args = [args]
             params_and_buffers = {**params, **buffers}