Switched over to using FakeTensor in ProxyTensor (#79634)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79634
Approved by: https://github.com/albanD
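
In short, `make_fx` can now trace through FakeTensors instead of executing ops on real data. A minimal sketch of the new path, assuming the `use_fake` keyword exercised by the tests in this diff:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.sin(x) + x

# With use_fake=True, traced ops run on FakeTensors, which propagate
# output metadata (shape/dtype/device) without touching real data.
traced = make_fx(f, use_fake=True)(torch.randn(3))
print(traced.graph)

# The resulting GraphModule still executes on real tensors afterwards.
torch.testing.assert_close(traced(torch.ones(3)), f(torch.ones(3)))
```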
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 40bec21..76bd4ae 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -7,6 +7,7 @@
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
+from torch._subclasses.fake_tensor import DynamicOutputShapeException
from torch.testing._internal.common_device_type import ops
from torch.fx.experimental.proxy_tensor import make_fx
@@ -59,7 +60,7 @@
class TestProxyTensor(TestCase):
- def test_make_fx(self, device):
+ def test_make_fx_simple(self, device):
def f(x):
return torch.sin(x)
inp = torch.randn(3)
@@ -110,6 +111,17 @@
assert inp.grad is None
torch.testing.assert_close(traced_graph_out, f(inp))
+ def test_inplace_metadata(self):
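+        # unsqueeze_ mutates only metadata; tracing must propagate the new
+        # shape so the assert holds both while tracing and on replay.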
+ def f(x):
+ x = x.clone()
+ x.unsqueeze_(-1)
+ assert x.shape[-1] == 1
+ return x
+
+ inps = [torch.randn(5)]
+ fx_f = make_fx(f)(*inps)
+ self.assertEqual(fx_f(*inps), f(*inps))
+
def test_mode_tracing_factory_function(self):
def f(x):
return x + torch.randn(x.shape)
@@ -136,6 +148,7 @@
)
make_fx_failures = {
+    # failure reason unknown
xfail('allclose'),
xfail('equal'),
xfail('linalg.eigvals'),
@@ -150,6 +163,7 @@
skip('nn.functional.max_unpool2d', '', device_type='cpu'),
skip('nn.functional.max_unpool3d', '', device_type='cpu'),
skip('linalg.lstsq'), # flaky, probably just a precision issue
+
# data-dependent control flow
xfail('cov'),
xfail('istft'),
@@ -182,35 +196,67 @@
# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
- # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
+    # failure cause not yet diagnosed
xfail('nn.functional.ctc_loss'),
+ # Sparse tensors are not supported with faketensors for now
+ xfail('to_sparse'),
+ # segfaults
+ skip('block_diag'),
+}
+
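+# Extra failures that only show up when tracing through FakeTensor; the
+# fake variant below skips these on top of make_fx_failures.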
+fake_tensor_failures = {
+ # Needs complex-value support
+ xfail('polar'),
+ xfail('complex'),
+ xfail('linalg.eig'),
+ # FakeTensor fallback doesn't work
+ xfail('linalg.matrix_power'),
+ xfail('segment_reduce', 'lengths'),
+ xfail('multinomial'),
+ xfail('mvlgamma', 'mvlgamma_p_1'),
+ xfail('mvlgamma', 'mvlgamma_p_3'),
+ xfail('mvlgamma', 'mvlgamma_p_5'),
+ xfail('cholesky'),
+ xfail('cholesky_inverse'),
+ # ASAN failures due to divide by 0
+ skip('nn.functional.nll_loss'),
}
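+
+# Shared driver for the exhaustive tests below: trace each OpInfo sample
+# with make_fx (optionally through FakeTensor) and compare the traced
+# graph's output against eager execution.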
+def _test_make_fx_helper(self, device, dtype, op, use_fake):
+ def f(args, kwargs):
+ return op.op(*args, **kwargs)
+ sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
+ new_f = None
+ for sample_input in sample_inputs_itr:
+ args = [sample_input.input] + list(sample_input.args)
+ kwargs = sample_input.kwargs
+
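+        # Ops whose output shape depends on input values can't be
+        # represented by FakeTensor; skip such samples instead of failing.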
+ try:
+ new_f = make_fx(f, use_fake=use_fake)(args, kwargs)
+        except DynamicOutputShapeException:
+ self.skipTest("Dynamic output shape operation in trace")
+
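+        # Re-randomize float inputs so the traced graph is exercised on
+        # fresh values rather than the exact data it was traced with.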
+ for arg in args:
+ if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
+ arg.uniform_(0, 1)
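+        # Only compare against eager when the eager call itself succeeds;
+        # randomized inputs may fall outside an op's valid domain.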
+ try:
+ old_out = f(args, kwargs)
+ except Exception:
+ continue
+ new_out = wrapper_set_seed(new_f, args, kwargs)
+ self.assertEqual(new_out, old_out)
+
class TestProxyTensorOpInfo(TestCase):
@ops(op_db, allowed_dtypes=(torch.float,))
- @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures
- )
+ @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
def test_make_fx_exhaustive(self, device, dtype, op):
+ _test_make_fx_helper(self, device, dtype, op, False)
- def f(args, kwargs):
- return op.op(*args, **kwargs)
- sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
- new_f = None
- for sample_input in sample_inputs_itr:
- args = [sample_input.input] + list(sample_input.args)
- kwargs = sample_input.kwargs
-
- new_f = make_fx(f, trace_factory_functions=True)(args, kwargs)
- for arg in args:
- if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
- arg.uniform_(0, 1)
- try:
- old_out = f(args, kwargs)
- except Exception:
- continue
- new_out = wrapper_set_seed(new_f, args, kwargs)
- self.assertEqual(new_out, old_out)
+ @ops(op_db, allowed_dtypes=(torch.float,))
+ @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures))
+ def test_make_fx_fake_exhaustive(self, device, dtype, op):
+ _test_make_fx_helper(self, device, dtype, op, True)