Sparse fake tensor support (#82172)
Add support for sparse fake tensors.
- The testing strategy is to run a fake tensor cross ref test on `test_sparse.py`. This is necessary because OpInfo sparse coverage is completely nonexistent. We could have tried to turn on cross ref testing globally for all files, but that would be very time consuming and the tests I'm interested in are mostly in this file. There are some exclusions in testing for things that don't work.
- I make the fake tensor converter raise an UnsupportedFakeTensorException if the meta converter fails to do a conversion (which can happen in a relatively large number of situations).
- I relax fake tensor invariants so that you can make a fake tensor from a meta tensor. This is useful because in the cross ref test sometimes we operate on meta tensors.
- Fake tensor wrapping is improved to handle the case where a function doesn't return any tensors.
- Meta converter is taught how to convert sparse tensors to meta
There's still a little more cleanup that needs to be done, but this is good for review.
Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82172
Approved by: https://github.com/eellison
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 6ea0936..e989a66 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -291,7 +291,7 @@
for ten in out:
if i == 1:
self.assertTrue(isinstance(ten, FakeTensor))
- self.assertTrue(ten.device.type == 'cuda')
+ self.assertEqual(ten.device.type, 'cuda')
@skipIfRocm
@unittest.skipIf(not RUN_CUDA, "requires cuda")
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index a4503b1..ab33104 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -556,8 +556,8 @@
# ???
xfail('nn.functional.ctc_loss'),
- # Sparse tensors are not supported with faketensors for now
- xfail('to_sparse'),
+ # proxy tensor doesn't support sparse correctly right now
+ skip('to_sparse'),
# segfaults
skip('block_diag'),
}
diff --git a/test/test_sparse.py b/test/test_sparse.py
index e0b50e1..30bb6f3 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -9,7 +9,7 @@
from torch.testing import make_tensor
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
- DeterministicGuard, first_sample
+ DeterministicGuard, first_sample, TEST_WITH_CROSSREF
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
from numbers import Number
from typing import Dict, Any
@@ -25,6 +25,7 @@
all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
floating_and_complex_types_and, integral_types, floating_types_and,
)
+from torch.utils._python_dispatch import TorchDispatchMode
if TEST_SCIPY:
import scipy.sparse
@@ -40,7 +41,53 @@
IS_WINDOWS and torch.version.cuda and LooseVersion(torch.version.cuda) > "11.2"
) or (not IS_WINDOWS and CUDA11OrLater)
-class TestSparse(TestCase):
+class CrossRefSparseFakeMode(TorchDispatchMode):
+ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+ kwargs = kwargs or {}
+
+ def on_tensor(f):
+ def go(t):
+ if isinstance(t, torch.Tensor):
+ return f(t)
+ else:
+ return t
+ return go
+
+ # empty_like excluded for now due to sparse complex
+ # aten._to_dense.default this one is getting called with csc
+ if (
+ func not in [
+ torch.ops.aten.lift_fresh.default,
+ torch.ops.aten.empty_like.default,
+ torch.ops.aten.set_.source_Storage_storage_offset,
+ torch.ops.aten.sspaddmm.out,
+ torch.ops.aten._spdiags.default,
+ torch.ops.aten._to_dense.default
+ ]
+ and torch.Tag.dynamic_output_shape not in func.tags
+ and torch.Tag.inplace_view not in func.tags
+ ):
+ from torch._subclasses.fake_tensor import FakeTensorMode, UnsupportedFakeTensorException
+ from torch.utils._pytree import tree_map
+ try:
+ with FakeTensorMode(allow_meta=True) as fake_mode:
+ fake_args, fake_kwargs = tree_map(on_tensor(fake_mode.from_tensor), (args, kwargs))
+ fake_r = func(*fake_args, **fake_kwargs)
+ except UnsupportedFakeTensorException:
+ pass
+
+ r = func(*args, **kwargs)
+ return r
+
+class TestSparseBase(TestCase):
+ def run(self, result=None):
+ if TEST_WITH_CROSSREF:
+ with CrossRefSparseFakeMode():
+ return super().run(result)
+ else:
+ return super().run(result)
+
+class TestSparse(TestSparseBase):
def setUp(self):
TestCase.setUp(self)
@@ -1641,6 +1688,7 @@
@coalescedonoff
@dtypes(torch.double)
+ @unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error")
def test_sparse_sum(self, device, dtype, coalesced):
def run_tests(S, td=None):
@@ -3413,6 +3461,7 @@
*[torch.bfloat16] if CUDA11OrLater and SM80OrLater else [],
*[torch.complex64] if CUDA11OrLater else [],
*[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
+ @unittest.skipIf(TEST_WITH_CROSSREF, "not working with fake tensor")
@precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2, torch.complex64: 1e-2, torch.float32: 1e-2})
def test_sparse_matmul(self, device, dtype, coalesced):
"""