Delete ProxyTensor wrapper subclass (#83330)
I was working on https://github.com/pytorch/torchdynamo/issues/80 and my
working hypothesis for what was causing the error was that ProxyTensor
was not advertising the correct dispatch keys, causing AMP to behave
differently under tracing. I could have fixed this directly by
replicating fake tensor's fix for setting dispatch keys so that it also
applies to proxy tensor, but I was like, "Why must I repeat myself?"
This PR is the result. It completely deletes the ProxyTensor wrapper
subclass, so that when we are tracing, the tensors flowing through the
program are the *original* real or fake tensors, depending on what the
user requested in the top-level API. There is no more wrapping. To
store the Proxy objects necessary for actually doing tracing, I keep
them in a map keyed by the tensor itself. (Note: at the moment I never
clean up old entries from the map; this is easily fixed by using a
weak map.)
Benefits of doing this:
* No more tip-toeing around no_dispatch() when creating new ProxyTensors;
we never create new tensors (except when we call the underlying func),
so you don't have to worry about accidentally tracing them.
* No more syncing up metadata from in-place operators. In particular,
https://github.com/pytorch/pytorch/issues/81526 is mooted.
* This fixes https://github.com/pytorch/torchdynamo/issues/519, as we no longer need to teach proxy tensor to support sparse tensors.
* No more schlepping symbolic integers from the inner fake tensor to the
outer proxy tensor. If you can make a fake tensor with symbolic ints,
you're done; there is nothing else to do.
To avoid having to rewrite all of the guts, when I get to the actual
proxy tensor handler, I first "fetch" the stored Proxy data from
the map via a tree_map, and then operate on the resulting data as
before. A more optimized implementation is possible.
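A sketch of that handler shape (again using the hypothetical `get_proxy_slot`/`set_proxy_slot` helpers from above; the real handler also has to deal with multiple outputs, non-tensor returns, and so on):

```python
import torch
from torch.fx import Tracer
from torch.utils._pytree import tree_map


def proxy_call(tracer: Tracer, func, args=(), kwargs=None):
    kwargs = kwargs or {}

    # "Fetch" step: swap every tensor argument for its recorded Proxy so the
    # pre-existing proxy-handling logic can keep operating on proxies.
    def fetch(a):
        return get_proxy_slot(a, a) if isinstance(a, torch.Tensor) else a

    proxy_args = tree_map(fetch, args)
    proxy_kwargs = tree_map(fetch, kwargs)
    proxy_out = tracer.create_proxy("call_function", func, proxy_args, proxy_kwargs)

    # Run the op on the original real/fake tensors; no wrapper subclass is involved.
    out = func(*args, **kwargs)

    # Remember the output's proxy so later ops can fetch it (single-tensor case only).
    if isinstance(out, torch.Tensor):
        set_proxy_slot(out, proxy_out)
    return out
```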
Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83330
Approved by: https://github.com/Chillee
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 82997ca..64b9d94 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -175,7 +175,9 @@
def _test(self, f, inps):
fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
new_inps = tree_map(_create_new_input, inps)
- self.assertEqual(fx_f(*new_inps), f(*new_inps))
+ r1 = fx_f(*new_inps)
+ r2 = f(*new_inps)
+ self.assertEqual(r1, r2)
def test_make_fx_simple(self):
def f(x):
@@ -284,11 +286,10 @@
self.assertTrue(is_any_sigmoid(gm))
return torch.digamma(x)
- with self.assertRaisesRegex(AssertionError, "ProxyTensor is wrapped with another Tensor subclass"):
- traced = make_fx(f2_logging_tensor)(torch.randn(3))
- self.assertFalse(is_any_sum(traced))
- self.assertFalse(is_any_sigmoid(traced)) # this fails, sigmoid is traced with LoggingTensor
- self.assertTrue(is_any_digamma(traced))
+ traced = make_fx(f2_logging_tensor)(torch.randn(3))
+ self.assertFalse(is_any_sum(traced))
+ self.assertFalse(is_any_sigmoid(traced)) # this fails, sigmoid is traced with LoggingTensor
+ self.assertTrue(is_any_digamma(traced))
def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
def f(x):
@@ -514,6 +515,8 @@
model = Foo()
def f(args, params, buffers):
+ for p in params.values():
+ p.grad = None
if not isinstance(args, Iterable):
args = [args]
params_and_buffers = {**params, **buffers}